추론 실행 추가
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
package com.kamco.cd.training.train;
|
||||
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.train.service.TrainJobService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.media.Content;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
|
||||
@RequiredArgsConstructor
|
||||
@RestController
|
||||
@RequestMapping("/api/train")
|
||||
public class TrainApiController {
|
||||
|
||||
private final TrainJobService trainJobService;
|
||||
|
||||
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "실행 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@RequestMapping("/run/{uuid}")
|
||||
public ApiResponseDto<String> run(
|
||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
trainJobService.enqueue(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
/** 학습 실행이 예약되었음을 알리는 이벤트 객체 */
|
||||
public class ModelTrainJobQueuedEvent {
|
||||
|
||||
private final Long jobId;
|
||||
|
||||
public ModelTrainJobQueuedEvent(Long jobId) {
|
||||
this.jobId = jobId;
|
||||
}
|
||||
|
||||
public Long getJobId() {
|
||||
return jobId;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class TrainRunRequest {
|
||||
|
||||
// ========================
|
||||
// 기본
|
||||
// ========================
|
||||
private String datasetFolder;
|
||||
private String outputFolder;
|
||||
private String inputSize;
|
||||
private String cropSize;
|
||||
private Integer batchSize;
|
||||
private String gpuIds;
|
||||
private Integer gpus;
|
||||
private Double learningRate;
|
||||
private String backbone;
|
||||
private Integer epochs;
|
||||
|
||||
// ========================
|
||||
// Data
|
||||
// ========================
|
||||
private Integer trainNumWorkers;
|
||||
private Integer valNumWorkers;
|
||||
private Integer testNumWorkers;
|
||||
private Boolean trainShuffle;
|
||||
private Boolean trainPersistent;
|
||||
private Boolean valPersistent;
|
||||
|
||||
// ========================
|
||||
// Model Architecture
|
||||
// ========================
|
||||
private Double dropPathRate;
|
||||
private Integer frozenStages;
|
||||
private String neckPolicy;
|
||||
private String classWeight;
|
||||
private String decoderChannels;
|
||||
|
||||
// ========================
|
||||
// Loss & Optimization
|
||||
// ========================
|
||||
private Double weightDecay;
|
||||
private Double layerDecayRate;
|
||||
private Integer ignoreIndex;
|
||||
private Boolean ddpFindUnusedParams;
|
||||
private Integer numLayers;
|
||||
|
||||
// ========================
|
||||
// Evaluation
|
||||
// ========================
|
||||
private String metrics;
|
||||
private String saveBest;
|
||||
private String saveBestRule;
|
||||
private Integer valInterval;
|
||||
private Integer logInterval;
|
||||
private Integer visInterval;
|
||||
|
||||
// ========================
|
||||
// Augmentation
|
||||
// ========================
|
||||
private Double rotProb;
|
||||
private String rotDegree;
|
||||
private Double flipProb;
|
||||
private Double exchangeProb;
|
||||
private Integer brightnessDelta;
|
||||
private String contrastRange;
|
||||
private String saturationRange;
|
||||
private Integer hueDelta;
|
||||
|
||||
// ========================
|
||||
// 실행 타임아웃
|
||||
// ========================
|
||||
private Integer timeoutSeconds;
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
/** 학습 실행 결과 반환 객체 */
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class TrainRunResult {
|
||||
|
||||
private String jobId;
|
||||
private String containerName;
|
||||
private int exitCode;
|
||||
private String status;
|
||||
private String logs;
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DockerTrainService {
|
||||
|
||||
// 실행할 Docker 이미지명
|
||||
@Value("${train.docker.image}")
|
||||
private String image;
|
||||
|
||||
// 학습 요청 데이터가 위치한 호스트 디렉토리
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
// 컨테이너 이름 prefix
|
||||
@Value("${train.docker.containerPrefix}")
|
||||
private String containerPrefix;
|
||||
|
||||
// 공유메모리 사이즈 설정 (대용량 학습시 필요)
|
||||
@Value("${train.docker.shmSize:16g}")
|
||||
private String shmSize;
|
||||
|
||||
// IPC host 사용 여부
|
||||
@Value("${train.docker.ipcHost:true}")
|
||||
private boolean ipcHost;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception {
|
||||
|
||||
// 실행 식별용 jobId 생성
|
||||
String jobId = UUID.randomUUID().toString().substring(0, 8);
|
||||
|
||||
// 컨테이너 이름 생성 (중복 방지 목적)
|
||||
String containerName = containerPrefix + "-" + jobId;
|
||||
|
||||
// docker run 명령어 조립
|
||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||
|
||||
// 프로세스 실행
|
||||
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||
|
||||
// stderr를 stdout으로 합쳐서 한 스트림으로 처리
|
||||
pb.redirectErrorStream(true);
|
||||
|
||||
Process p = pb.start();
|
||||
|
||||
// 실행 로그 수집
|
||||
StringBuilder log = new StringBuilder();
|
||||
|
||||
try (BufferedReader br =
|
||||
new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
log.append(line).append('\n');
|
||||
}
|
||||
}
|
||||
|
||||
// 지정된 timeout 내에 종료 대기
|
||||
int timeoutSeconds = 7200; // 기본 2시간
|
||||
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
|
||||
|
||||
if (!finished) {
|
||||
// 타임아웃 발생 시 컨테이너 강제 제거
|
||||
killContainer(containerName);
|
||||
|
||||
return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString());
|
||||
}
|
||||
|
||||
// 종료 코드 확인 (0=정상)
|
||||
int exit = p.exitValue();
|
||||
|
||||
return new TrainRunResult(
|
||||
jobId, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", log.toString());
|
||||
}
|
||||
|
||||
/**
|
||||
* docker run 명령어 리스트 구성 - 환경변수 설정 - GPU 옵션 설정 - 볼륨 마운트 - 컨테이너 내부 python 실행 명령 구성 - 요청값이
|
||||
* null/blank면 해당 옵션은 "아예 생략"
|
||||
*/
|
||||
private List<String> buildDockerRunCommand(String containerName, TrainRunRequest req) {
|
||||
|
||||
List<String> c = new ArrayList<>();
|
||||
|
||||
c.add("docker");
|
||||
c.add("run");
|
||||
|
||||
// 컨테이너 이름 지정
|
||||
c.add("--name");
|
||||
c.add(containerName);
|
||||
|
||||
// 실행 종료 시 자동 삭제
|
||||
c.add("--rm");
|
||||
|
||||
// 환경변수 설정
|
||||
c.add("-e");
|
||||
c.add("OPENCV_LOG_LEVEL=ERROR");
|
||||
c.add("-e");
|
||||
c.add("NCCL_DEBUG=INFO");
|
||||
c.add("-e");
|
||||
c.add("NCCL_IB_DISABLE=1");
|
||||
c.add("-e");
|
||||
c.add("NCCL_P2P_DISABLE=0");
|
||||
c.add("-e");
|
||||
c.add("NCCL_SOCKET_IFNAME=eth0");
|
||||
|
||||
// GPU 전체 사용
|
||||
c.add("--gpus");
|
||||
c.add("all");
|
||||
|
||||
// IPC host 사용 여부
|
||||
if (ipcHost) {
|
||||
c.add("--ipc=host");
|
||||
}
|
||||
|
||||
// 공유메모리 설정
|
||||
c.add("--shm-size=" + shmSize);
|
||||
|
||||
// 메모리 관련 ulimit 설정
|
||||
c.add("--ulimit");
|
||||
c.add("memlock=-1");
|
||||
c.add("--ulimit");
|
||||
c.add("stack=67108864");
|
||||
|
||||
// 요청/결과 디렉토리 볼륨 마운트
|
||||
c.add("-v");
|
||||
c.add(requestDir + ":/data");
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
// 표준입력 유지 (-it 대신 -i만 사용)
|
||||
c.add("-i");
|
||||
|
||||
// 사용할 이미지
|
||||
c.add(image);
|
||||
|
||||
// ===== 컨테이너 내부 실행 명령 =====
|
||||
c.add("python");
|
||||
c.add("/workspace/change-detection-code/train_wrapper.py");
|
||||
|
||||
// ===== 기본 파라미터 =====
|
||||
addArg(c, "--dataset-folder", req.getDatasetFolder());
|
||||
addArg(c, "--output-folder", req.getOutputFolder());
|
||||
addArg(c, "--input-size", req.getInputSize());
|
||||
addArg(c, "--crop-size", req.getCropSize());
|
||||
addArg(c, "--batch-size", req.getBatchSize());
|
||||
addArg(c, "--gpu-ids", req.getGpuIds());
|
||||
// addArg(c, "--gpus", req.getGpus());
|
||||
addArg(c, "--lr", req.getLearningRate());
|
||||
addArg(c, "--backbone", req.getBackbone());
|
||||
addArg(c, "--epochs", req.getEpochs());
|
||||
|
||||
// ===== Data =====
|
||||
addArg(c, "--train-num-workers", req.getTrainNumWorkers());
|
||||
addArg(c, "--val-num-workers", req.getValNumWorkers());
|
||||
addArg(c, "--test-num-workers", req.getTestNumWorkers());
|
||||
addArg(c, "--train-shuffle", req.getTrainShuffle());
|
||||
addArg(c, "--train-persistent", req.getTrainPersistent());
|
||||
addArg(c, "--val-persistent", req.getValPersistent());
|
||||
|
||||
// ===== Model Architecture =====
|
||||
addArg(c, "--drop-path-rate", req.getDropPathRate());
|
||||
addArg(c, "--frozen-stages", req.getFrozenStages());
|
||||
addArg(c, "--neck-policy", req.getNeckPolicy());
|
||||
addArg(c, "--class-weight", req.getClassWeight());
|
||||
addArg(c, "--decoder-channels", req.getDecoderChannels());
|
||||
|
||||
// ===== Loss & Optimization =====
|
||||
addArg(c, "--weight-decay", req.getWeightDecay());
|
||||
addArg(c, "--layer-decay-rate", req.getLayerDecayRate());
|
||||
addArg(c, "--ignore-index", req.getIgnoreIndex());
|
||||
addArg(c, "--ddp-find-unused-params", req.getDdpFindUnusedParams());
|
||||
addArg(c, "--num-layers", req.getNumLayers());
|
||||
|
||||
// ===== Evaluation =====
|
||||
addArg(c, "--metrics", req.getMetrics());
|
||||
addArg(c, "--save-best", req.getSaveBest());
|
||||
addArg(c, "--save-best-rule", req.getSaveBestRule());
|
||||
addArg(c, "--val-interval", req.getValInterval());
|
||||
addArg(c, "--log-interval", req.getLogInterval());
|
||||
addArg(c, "--vis-interval", req.getVisInterval());
|
||||
|
||||
// ===== Augmentation =====
|
||||
addArg(c, "--rot-prob", req.getRotProb());
|
||||
addArg(c, "--rot-degree", req.getRotDegree());
|
||||
addArg(c, "--flip-prob", req.getFlipProb());
|
||||
addArg(c, "--exchange-prob", req.getExchangeProb());
|
||||
addArg(c, "--brightness-delta", req.getBrightnessDelta());
|
||||
addArg(c, "--contrast-range", req.getContrastRange());
|
||||
addArg(c, "--saturation-range", req.getSaturationRange());
|
||||
addArg(c, "--hue-delta", req.getHueDelta());
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
/** 인자 추가(키 + 값) - null / blank면 아예 추가 안 함 */
|
||||
private void addArg(List<String> c, String key, Object value) {
|
||||
if (value == null) return;
|
||||
String s = String.valueOf(value).trim();
|
||||
if (s.isEmpty()) return;
|
||||
c.add(key);
|
||||
c.add(s);
|
||||
}
|
||||
|
||||
/** 컨테이너 강제 종료 및 제거 */
|
||||
private void killContainer(String containerName) {
|
||||
try {
|
||||
new ProcessBuilder("docker", "rm", "-f", containerName)
|
||||
.redirectErrorStream(true)
|
||||
.start()
|
||||
.waitFor(10, TimeUnit.SECONDS);
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional(readOnly = true)
|
||||
public class TrainJobService {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
|
||||
public Long getModelIdByUuid(UUID uuid) {
|
||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||
}
|
||||
|
||||
/** 실행 예약 (QUEUE 등록) */
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId) {
|
||||
|
||||
// 마스터 존재 확인(없으면 예외)
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
|
||||
|
||||
if (trainRunRequest == null) {
|
||||
throw new IllegalArgumentException("Model not found: " + modelId);
|
||||
}
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId, nextAttemptNo, paramsMap, ZonedDateTime.now());
|
||||
|
||||
modelTrainMngCoreService.clearLastError(modelId);
|
||||
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
||||
|
||||
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 재시작 버튼
|
||||
*
|
||||
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
|
||||
*/
|
||||
@Transactional
|
||||
public Long restart(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
if (TrainStatusType.IN_PROGRESS.getId().equals(master.getStatusCd())) {
|
||||
throw new IllegalStateException("이미 진행중입니다.");
|
||||
}
|
||||
|
||||
var lastJob =
|
||||
modelTrainJobCoreService
|
||||
.findLatestByModelId(modelId)
|
||||
.orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다."));
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId,
|
||||
nextAttemptNo,
|
||||
lastJob.getParamsJson(), // Map<String,Object> 그대로 재사용
|
||||
ZonedDateTime.now());
|
||||
|
||||
modelTrainMngCoreService.clearLastError(modelId);
|
||||
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
||||
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 중단 버튼
|
||||
*
|
||||
* <p>- job 상태 CANCELED - master 상태 STOPPED
|
||||
*
|
||||
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
|
||||
*/
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
Long attemptId = master.getCurrentAttemptId();
|
||||
if (attemptId == null) {
|
||||
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||
}
|
||||
|
||||
modelTrainJobCoreService.markCanceled(attemptId);
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
import java.util.Map;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.transaction.event.TransactionPhase;
|
||||
import org.springframework.transaction.event.TransactionalEventListener;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class TrainJobWorker {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Async
|
||||
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
||||
public void handle(ModelTrainJobQueuedEvent event) {
|
||||
|
||||
Long jobId = event.getJobId(); // record면 event.jobId()
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
Long modelId = job.getModelId();
|
||||
|
||||
// enqueue에서 params_json 저장해놨으니 그걸로 TrainRunRequest 복원하는게 제일 일관적
|
||||
TrainRunRequest req = toTrainRunRequest(job.getParamsJson());
|
||||
// req가 null이면 실패 처리
|
||||
if (req == null) {
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)");
|
||||
modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패");
|
||||
return;
|
||||
}
|
||||
|
||||
// 컨테이너 이름은 "jobId 기반"으로 고정하는 게 cancel/restart에 유리
|
||||
String containerName = "train-" + jobId; // prefix 쓰고싶으면 @Value 받아서 붙이면 됨
|
||||
|
||||
// logPath/lockedBy는 너 환경에 맞게
|
||||
String logPath = null;
|
||||
String lockedBy = "TRAIN_WORKER";
|
||||
|
||||
// RUNNING 표시
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy);
|
||||
|
||||
try {
|
||||
// DockerTrainService가 내부에서 컨테이너 이름을 랜덤으로 만들고 있어서
|
||||
// markRunning에서 저장한 containerName과 실제 컨테이너명이 달라질 수 있음.
|
||||
// 아래 "추천 수정" 참고.
|
||||
TrainRunResult result = dockerTrainService.runTrainSync(req);
|
||||
|
||||
if (result.getExitCode() == 0) {
|
||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||
modelTrainMngCoreService.markSuccess(modelId); // 너 modelTrainMngCoreService에 있는 이름으로 맞춰
|
||||
} else {
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
modelTrainJobCoreService.markFailed(jobId, null, e.toString());
|
||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private TrainRunRequest toTrainRunRequest(Map<String, Object> paramsJson) {
|
||||
if (paramsJson == null || paramsJson.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return objectMapper.convertValue(paramsJson, TrainRunRequest.class);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user