추론 실행 추가

This commit is contained in:
2026-02-11 20:21:25 +09:00
parent 35767adba1
commit 1249a80da5
21 changed files with 1049 additions and 3 deletions

View File

@@ -0,0 +1,48 @@
package com.kamco.cd.training.train;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.train.service.TrainJobService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
/**
 * REST controller exposing the "run training" endpoint under {@code /api/train}.
 *
 * <p>The controller only resolves the model and queues a job through {@link TrainJobService};
 * it does not wait for the training itself.
 */
@RestController
@RequestMapping("/api/train")
@RequiredArgsConstructor
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
public class TrainApiController {

  private final TrainJobService trainJobService;

  /**
   * Queues a training run for the model identified by the given UUID.
   *
   * <p>NOTE(review): the mapping declares no HTTP method, so every method (including GET) reaches
   * this state-changing handler. Restricting it to POST looks desirable but would break any
   * existing GET caller, so confirm with API consumers first.
   *
   * @param uuid public identifier of the model to train
   * @return a response wrapping the literal {@code "ok"} once the job has been queued
   */
  @Operation(summary = "학습 실행", description = "학습 실행 API")
  @ApiResponses({
    @ApiResponse(
        responseCode = "200",
        description = "실행 성공",
        content =
            @Content(
                mediaType = "application/json",
                schema = @Schema(implementation = String.class))),
    @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
    @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
  })
  @RequestMapping("/run/{uuid}")
  public ApiResponseDto<String> run(
      @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
          @PathVariable
          UUID uuid) {
    // Resolve the internal model id first, then hand it to the queueing service.
    trainJobService.enqueue(trainJobService.getModelIdByUuid(uuid));
    return ApiResponseDto.ok("ok");
  }
}

View File

@@ -0,0 +1,15 @@
package com.kamco.cd.training.train.dto;
/**
 * Event signalling that a training run has been queued.
 *
 * <p>Instances are immutable: the job id is fixed at construction time.
 */
public class ModelTrainJobQueuedEvent {

  /** Identifier of the queued training job. */
  private final Long jobId;

  /**
   * Creates an event for the given job.
   *
   * @param jobId identifier of the queued job; stored as-is (no null check is performed)
   */
  public ModelTrainJobQueuedEvent(Long jobId) {
    this.jobId = jobId;
  }

  /**
   * @return the job id supplied at construction time
   */
  public Long getJobId() {
    return this.jobId;
  }
}

View File

@@ -0,0 +1,82 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
/**
 * Parameter set for one training run.
 *
 * <p>In this code base the object is converted to and from a {@code Map} with Jackson and its
 * getters are read by {@code DockerTrainService} to build command-line flags; a {@code null}
 * field means "omit the flag".
 *
 * <p>Field declaration order defines the argument order of the Lombok-generated all-args
 * constructor, so do not reorder fields without checking constructor callers.
 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class TrainRunRequest {
// ========================
// Basics
// ========================
private String datasetFolder;
private String outputFolder;
// NOTE(review): sizes are carried as strings; the expected format is not visible here.
private String inputSize;
private String cropSize;
private Integer batchSize;
private String gpuIds;
private Integer gpus;
private Double learningRate;
private String backbone;
private Integer epochs;
// ========================
// Data
// ========================
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// ========================
// Model Architecture
// ========================
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String classWeight;
private String decoderChannels;
// ========================
// Loss & Optimization
// ========================
private Double weightDecay;
private Double layerDecayRate;
private Integer ignoreIndex;
private Boolean ddpFindUnusedParams;
private Integer numLayers;
// ========================
// Evaluation
// ========================
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// ========================
// Augmentation
// ========================
private Double rotProb;
private String rotDegree;
private Double flipProb;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// ========================
// Execution timeout
// ========================
// Presumably the maximum run time in seconds (per the field name) — confirm with the runner.
private Integer timeoutSeconds;
}

View File

@@ -0,0 +1,20 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
/**
 * Result object returned from a training run.
 *
 * <p>Field declaration order defines the argument order of the Lombok-generated all-args
 * constructor, which {@code DockerTrainService} calls positionally; do not reorder fields.
 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class TrainRunResult {
// Identifier of this run as generated by the runner (a String, not the numeric DB job id).
private String jobId;
// Name of the Docker container the run was started under.
private String containerName;
// Process exit code; DockerTrainService reports -1 when the run timed out.
private int exitCode;
// Outcome label; DockerTrainService sets "SUCCESS", "FAILED" or "TIMEOUT".
private String status;
// Captured process output (stdout and stderr merged by the runner).
private String logs;
}

View File

@@ -0,0 +1,230 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
/**
 * Runs a training job by launching a Docker container through the {@code docker} CLI and waiting
 * for it to finish.
 *
 * <p>Image, host directories, container-name prefix and shared-memory settings are injected from
 * {@code train.docker.*} properties.
 */
@Service
public class DockerTrainService {

  /** Wait limit (2 hours) used when the request does not carry a positive timeout. */
  private static final int DEFAULT_TIMEOUT_SECONDS = 7200;

  /** Maximum time to wait for the log-reader thread once the process has ended or been killed. */
  private static final long LOG_DRAIN_WAIT_MILLIS = 10_000L;

  // Docker image to run
  @Value("${train.docker.image}")
  private String image;

  // Host directory holding the training request data (mounted at /data)
  @Value("${train.docker.requestDir}")
  private String requestDir;

  // Host directory receiving the training results (mounted at /checkpoints)
  @Value("${train.docker.responseDir}")
  private String responseDir;

  // Container name prefix
  @Value("${train.docker.containerPrefix}")
  private String containerPrefix;

  // Shared-memory size passed to --shm-size
  @Value("${train.docker.shmSize:16g}")
  private String shmSize;

  // Whether to add --ipc=host
  @Value("${train.docker.ipcHost:true}")
  private boolean ipcHost;

  /**
   * Runs the training container synchronously on the calling thread.
   *
   * <p>The container output (stdout and stderr merged) is drained on a separate daemon thread so
   * that the timeout below is actually enforced; reading the stream to EOF on the calling thread
   * would block until the container exits and make the timeout unreachable.
   *
   * <p>The wait limit is {@code req.getTimeoutSeconds()} when it is a positive number, otherwise
   * {@link #DEFAULT_TIMEOUT_SECONDS}. On timeout, or when the waiting thread is interrupted, the
   * container is force-removed.
   *
   * <p>NOTE(review): the whole output is accumulated in memory; for very chatty or long runs a
   * bounded tail or a log file may be preferable.
   *
   * @param req training parameters; {@code null}/blank fields are omitted from the command line
   * @return run result with status {@code "SUCCESS"} (exit code 0), {@code "FAILED"} (non-zero
   *     exit code) or {@code "TIMEOUT"} (exit code reported as -1)
   * @throws Exception if the process cannot be started or the waiting thread is interrupted
   */
  public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception {
    // Short random id used to identify this run
    String jobId = UUID.randomUUID().toString().substring(0, 8);
    // Container name (prefix + id) to avoid name collisions
    String containerName = containerPrefix + "-" + jobId;

    List<String> cmd = buildDockerRunCommand(containerName, req);

    ProcessBuilder pb = new ProcessBuilder(cmd);
    // Merge stderr into stdout so a single stream has to be drained
    pb.redirectErrorStream(true);
    Process p = pb.start();

    // StringBuffer: on timeout the reader thread may still be appending while we take a snapshot
    StringBuffer log = new StringBuffer();
    Thread logReader = new Thread(() -> drainOutput(p, log), "train-log-" + jobId);
    logReader.setDaemon(true);
    logReader.start();

    boolean finished = false;
    try {
      finished = p.waitFor(resolveTimeoutSeconds(req), TimeUnit.SECONDS);
    } finally {
      if (!finished) {
        // Timeout or interruption: do not leave the container running
        killContainer(containerName);
        p.destroyForcibly();
      }
    }

    // Give the reader a bounded amount of time to pick up the remaining output
    logReader.join(LOG_DRAIN_WAIT_MILLIS);

    if (!finished) {
      return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString());
    }

    // Exit code 0 means success
    int exit = p.exitValue();
    return new TrainRunResult(
        jobId, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", log.toString());
  }

  /** Returns the request timeout when it is a positive number, otherwise the default. */
  private long resolveTimeoutSeconds(TrainRunRequest req) {
    Integer requested = req.getTimeoutSeconds();
    return (requested != null && requested > 0) ? requested : DEFAULT_TIMEOUT_SECONDS;
  }

  /** Copies the process output line by line into {@code sink} until the stream ends. */
  private void drainOutput(Process p, StringBuffer sink) {
    try (BufferedReader br =
        new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
      String line;
      while ((line = br.readLine()) != null) {
        sink.append(line).append('\n');
      }
    } catch (java.io.IOException ignored) {
      // The stream is closed underneath us when the process is killed; keep what was read so far.
    }
  }

  /**
   * Builds the {@code docker run} argument list: container name, environment variables, GPU and
   * memory options, volume mounts, the image, and the in-container python command with its flags.
   * Request values that are {@code null} or blank are left out entirely.
   */
  private List<String> buildDockerRunCommand(String containerName, TrainRunRequest req) {
    List<String> c = new ArrayList<>();
    c.add("docker");
    c.add("run");

    // Container name
    c.add("--name");
    c.add(containerName);

    // Remove the container automatically when it exits
    c.add("--rm");

    // Environment variables
    c.add("-e");
    c.add("OPENCV_LOG_LEVEL=ERROR");
    c.add("-e");
    c.add("NCCL_DEBUG=INFO");
    c.add("-e");
    c.add("NCCL_IB_DISABLE=1");
    c.add("-e");
    c.add("NCCL_P2P_DISABLE=0");
    c.add("-e");
    c.add("NCCL_SOCKET_IFNAME=eth0");

    // Expose all GPUs
    c.add("--gpus");
    c.add("all");

    // Optional host IPC namespace
    if (ipcHost) {
      c.add("--ipc=host");
    }

    // Shared-memory size
    c.add("--shm-size=" + shmSize);

    // Memory-related ulimits
    c.add("--ulimit");
    c.add("memlock=-1");
    c.add("--ulimit");
    c.add("stack=67108864");

    // Request / result directory mounts
    c.add("-v");
    c.add(requestDir + ":/data");
    c.add("-v");
    c.add(responseDir + ":/checkpoints");

    // Keep stdin open (-i only, no TTY)
    c.add("-i");

    // Image to run
    c.add(image);

    // ===== In-container command =====
    c.add("python");
    c.add("/workspace/change-detection-code/train_wrapper.py");

    // ===== Basics =====
    addArg(c, "--dataset-folder", req.getDatasetFolder());
    addArg(c, "--output-folder", req.getOutputFolder());
    addArg(c, "--input-size", req.getInputSize());
    addArg(c, "--crop-size", req.getCropSize());
    addArg(c, "--batch-size", req.getBatchSize());
    addArg(c, "--gpu-ids", req.getGpuIds());
    // NOTE(review): forwarding of req.getGpus() as "--gpus" was disabled by the author.
    // addArg(c, "--gpus", req.getGpus());
    addArg(c, "--lr", req.getLearningRate());
    addArg(c, "--backbone", req.getBackbone());
    addArg(c, "--epochs", req.getEpochs());

    // ===== Data =====
    addArg(c, "--train-num-workers", req.getTrainNumWorkers());
    addArg(c, "--val-num-workers", req.getValNumWorkers());
    addArg(c, "--test-num-workers", req.getTestNumWorkers());
    addArg(c, "--train-shuffle", req.getTrainShuffle());
    addArg(c, "--train-persistent", req.getTrainPersistent());
    addArg(c, "--val-persistent", req.getValPersistent());

    // ===== Model Architecture =====
    addArg(c, "--drop-path-rate", req.getDropPathRate());
    addArg(c, "--frozen-stages", req.getFrozenStages());
    addArg(c, "--neck-policy", req.getNeckPolicy());
    addArg(c, "--class-weight", req.getClassWeight());
    addArg(c, "--decoder-channels", req.getDecoderChannels());

    // ===== Loss & Optimization =====
    addArg(c, "--weight-decay", req.getWeightDecay());
    addArg(c, "--layer-decay-rate", req.getLayerDecayRate());
    addArg(c, "--ignore-index", req.getIgnoreIndex());
    addArg(c, "--ddp-find-unused-params", req.getDdpFindUnusedParams());
    addArg(c, "--num-layers", req.getNumLayers());

    // ===== Evaluation =====
    addArg(c, "--metrics", req.getMetrics());
    addArg(c, "--save-best", req.getSaveBest());
    addArg(c, "--save-best-rule", req.getSaveBestRule());
    addArg(c, "--val-interval", req.getValInterval());
    addArg(c, "--log-interval", req.getLogInterval());
    addArg(c, "--vis-interval", req.getVisInterval());

    // ===== Augmentation =====
    addArg(c, "--rot-prob", req.getRotProb());
    addArg(c, "--rot-degree", req.getRotDegree());
    addArg(c, "--flip-prob", req.getFlipProb());
    addArg(c, "--exchange-prob", req.getExchangeProb());
    addArg(c, "--brightness-delta", req.getBrightnessDelta());
    addArg(c, "--contrast-range", req.getContrastRange());
    addArg(c, "--saturation-range", req.getSaturationRange());
    addArg(c, "--hue-delta", req.getHueDelta());

    return c;
  }

  /** Appends {@code key} and the trimmed string form of {@code value}; skips null/blank values. */
  private void addArg(List<String> c, String key, Object value) {
    if (value == null) {
      return;
    }
    String s = String.valueOf(value).trim();
    if (s.isEmpty()) {
      return;
    }
    c.add(key);
    c.add(s);
  }

  /**
   * Force-removes the container ({@code docker rm -f}), waiting at most 10 seconds. Best effort:
   * failures are not propagated because callers are already on a failure/timeout path.
   */
  private void killContainer(String containerName) {
    try {
      new ProcessBuilder("docker", "rm", "-f", containerName)
          .redirectErrorStream(true)
          // Nobody reads this output; discard it so the child can never block on a full pipe
          .redirectOutput(ProcessBuilder.Redirect.DISCARD)
          .start()
          .waitFor(10, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      // Preserve the interrupt for the caller instead of silently clearing it
      Thread.currentThread().interrupt();
    } catch (Exception ignored) {
      // Best effort: the container may already be gone or the docker CLI may be unavailable.
    }
  }
}

View File

@@ -0,0 +1,119 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
/**
 * Application service that queues, restarts and cancels training jobs.
 *
 * <p>The class is read-only transactional by default; the state-changing methods open a regular
 * transaction. Actual execution is triggered by publishing a {@link ModelTrainJobQueuedEvent}.
 */
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class TrainJobService {

  private final ModelTrainJobCoreService modelTrainJobCoreService;
  private final ModelTrainMngCoreService modelTrainMngCoreService;
  private final ObjectMapper objectMapper;
  private final ApplicationEventPublisher eventPublisher;

  /**
   * Looks up the internal model id for a public UUID.
   *
   * @param uuid public model identifier
   * @return whatever {@code findModelIdByUuid} yields for that UUID
   */
  public Long getModelIdByUuid(UUID uuid) {
    return modelTrainMngCoreService.findModelIdByUuid(uuid);
  }

  /**
   * Queues a new training attempt for the model.
   *
   * <p>NOTE(review): unlike {@link #restart(Long)}, this method does not reject a model that is
   * already IN_PROGRESS — confirm whether double-queueing is intended. The exception message
   * below says "Model not found" although the missing piece is the run request.
   *
   * @param modelId id of the model to train
   * @return id of the newly created job
   * @throws IllegalArgumentException if no run request can be loaded for the model
   */
  @Transactional
  public Long enqueue(Long modelId) {
    // Existence check on the master record (the comment in the original states it throws if absent)
    modelTrainMngCoreService.findModelById(modelId);

    TrainRunRequest runRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
    if (runRequest == null) {
      throw new IllegalArgumentException("Model not found: " + modelId);
    }

    int attemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;

    @SuppressWarnings("unchecked")
    Map<String, Object> params = objectMapper.convertValue(runRequest, Map.class);

    Long newJobId =
        modelTrainJobCoreService.createQueuedJob(modelId, attemptNo, params, ZonedDateTime.now());
    return activateQueuedJob(modelId, newJobId);
  }

  /**
   * Restart: creates a new attempt that reuses the parameters of the latest job.
   *
   * @param modelId id of the model to retrain
   * @return id of the newly created job
   * @throws IllegalStateException if the model is currently IN_PROGRESS or has no previous job
   */
  @Transactional
  public Long restart(Long modelId) {
    ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(modelId);
    boolean alreadyRunning = TrainStatusType.IN_PROGRESS.getId().equals(model.getStatusCd());
    if (alreadyRunning) {
      throw new IllegalStateException("이미 진행중입니다.");
    }

    var previousJob =
        modelTrainJobCoreService
            .findLatestByModelId(modelId)
            .orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다."));

    int attemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;

    // The stored parameter map of the previous job is passed through unchanged
    Long newJobId =
        modelTrainJobCoreService.createQueuedJob(
            modelId, attemptNo, previousJob.getParamsJson(), ZonedDateTime.now());
    return activateQueuedJob(modelId, newJobId);
  }

  /**
   * Cancel: marks the current attempt as canceled and the model as stopped.
   *
   * <p>Only database state is changed here; per the original author's note, stopping the actual
   * container is left to the worker/runner.
   *
   * @param modelId id of the model whose current attempt should be canceled
   * @throws IllegalStateException if the model has no current attempt
   */
  @Transactional
  public void cancel(Long modelId) {
    Long currentAttemptId = modelTrainMngCoreService.findModelById(modelId).getCurrentAttemptId();
    if (currentAttemptId == null) {
      throw new IllegalStateException("실행중인 작업이 없습니다.");
    }
    modelTrainJobCoreService.markCanceled(currentAttemptId);
    modelTrainMngCoreService.markStopped(modelId);
  }

  /**
   * Shared tail of enqueue/restart: clears the last error, marks the model in progress for the
   * given job, then publishes the queued event. The worker listens AFTER_COMMIT, so it only runs
   * once the surrounding transaction has committed.
   */
  private Long activateQueuedJob(Long modelId, Long jobId) {
    modelTrainMngCoreService.clearLastError(modelId);
    modelTrainMngCoreService.markInProgress(modelId, jobId);
    eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
    return jobId;
  }
}

View File

@@ -0,0 +1,87 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
/**
 * Background worker that executes a queued training job.
 *
 * <p>It listens for {@link ModelTrainJobQueuedEvent} after the publishing transaction commits,
 * runs the training through {@link DockerTrainService} and records the outcome on both the job
 * and the model master record.
 */
@Component
@RequiredArgsConstructor
public class TrainJobWorker {

  private final ModelTrainJobCoreService modelTrainJobCoreService;
  private final ModelTrainMngCoreService modelTrainMngCoreService;
  private final DockerTrainService dockerTrainService;
  private final ObjectMapper objectMapper;

  /**
   * Runs the job referenced by the event and persists the result.
   *
   * <p>Any failure after the job has been loaded — including a failure to mark it RUNNING — is
   * recorded as job FAILED / model ERROR, so the model is not left IN_PROGRESS forever.
   *
   * <p>NOTE(review): a job canceled while it is running is not detected here; the final
   * success/failure marks below are written regardless of a concurrent cancel.
   *
   * @param event event carrying the id of the queued job
   * @throws IllegalArgumentException if the job cannot be found
   */
  @Async
  @TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
  public void handle(ModelTrainJobQueuedEvent event) {
    Long jobId = event.getJobId();

    ModelTrainJobEntity job =
        modelTrainJobCoreService
            .findById(jobId)
            .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));

    Long modelId = job.getModelId();

    // Rebuild the request from the params stored with the job when it was queued
    TrainRunRequest req = toTrainRunRequest(job.getParamsJson());
    if (req == null) {
      modelTrainJobCoreService.markFailed(
          jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)");
      modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패");
      return;
    }

    // NOTE(review): DockerTrainService generates its own container name, so the name stored
    // here may differ from the real container name. Aligning them requires the runner to
    // accept a caller-supplied name.
    String containerName = "train-" + jobId;
    String logPath = null;
    String lockedBy = "TRAIN_WORKER";

    try {
      // Inside the try so that a failure to mark RUNNING is also recorded as an error
      modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy);

      TrainRunResult result = dockerTrainService.runTrainSync(req);

      if (result.getExitCode() == 0) {
        modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
        modelTrainMngCoreService.markSuccess(modelId);
      } else {
        modelTrainJobCoreService.markFailed(
            jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
        modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
      }
    } catch (Exception e) {
      // Some exceptions carry no message; fall back to the class name
      String reason = e.getMessage() != null ? e.getMessage() : e.toString();
      modelTrainJobCoreService.markFailed(jobId, null, e.toString());
      modelTrainMngCoreService.markError(modelId, reason);
      if (e instanceof InterruptedException) {
        // Restore the interrupt flag so the executor can observe the interruption
        Thread.currentThread().interrupt();
      }
    }
  }

  /**
   * Converts the stored parameter map back into a {@link TrainRunRequest}.
   *
   * @return the request, or {@code null} when the map is null, empty, or cannot be converted
   */
  private TrainRunRequest toTrainRunRequest(Map<String, Object> paramsJson) {
    if (paramsJson == null || paramsJson.isEmpty()) {
      return null;
    }
    try {
      return objectMapper.convertValue(paramsJson, TrainRunRequest.class);
    } catch (IllegalArgumentException e) {
      // convertValue signals conversion failure by throwing; the caller treats null as
      // "invalid params" and records the failure on the job and the model.
      return null;
    }
  }
}