diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java index f2440a6..350a248 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java @@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.core; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository; +import com.kamco.cd.training.train.dto.ModelTrainJobDto; import java.time.ZonedDateTime; import java.util.Map; import java.util.Optional; @@ -20,12 +21,12 @@ public class ModelTrainJobCoreService { return modelTrainJobRepository.findMaxAttemptNo(modelId); } - public Optional findLatestByModelId(Long modelId) { - return modelTrainJobRepository.findLatestByModelId(modelId); + public Optional findLatestByModelId(Long modelId) { + return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto); } - public Optional findById(Long jobId) { - return modelTrainJobRepository.findById(jobId); + public Optional findById(Long jobId) { + return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto); } /** QUEUED Job 생성 */ @@ -95,7 +96,7 @@ public class ModelTrainJobCoreService { .findById(jobId) .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); - job.setStatusCd("CANCELED"); + job.setStatusCd("STOPPED"); job.setFinishedDttm(ZonedDateTime.now()); } } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index d3028fb..1098bcb 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -36,6 +36,7 @@ import org.springframework.transaction.annotation.Transactional; @Service @RequiredArgsConstructor public class ModelTrainMngCoreService { + private final ModelMngRepository modelMngRepository; private final ModelDatasetRepository modelDatasetRepository; private final ModelDatasetMappRepository modelDatasetMapRepository; @@ -323,7 +324,7 @@ public class ModelTrainMngCoreService { master.setStatusCd(TrainStatusType.COMPLETED.getId()); } - /** 오류 처리(옵션) - Worker가 실패 시 호출 */ + /** step 1오류 처리(옵션) - Worker가 실패 시 호출 */ @Transactional public void markError(Long modelId, String errorMessage) { ModelMasterEntity master = @@ -332,7 +333,25 @@ public class ModelTrainMngCoreService { .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); master.setStatusCd(TrainStatusType.ERROR.getId()); + master.setStep1State(TrainStatusType.ERROR.getId()); master.setLastError(errorMessage); + master.setUpdatedUid(userUtil.getId()); + master.setUpdatedDttm(ZonedDateTime.now()); + } + + /** step 2오류 처리(옵션) - Worker가 실패 시 호출 */ + @Transactional + public void markStep2Error(Long modelId, String errorMessage) { + ModelMasterEntity master = + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + + master.setStatusCd(TrainStatusType.ERROR.getId()); + master.setStep2State(TrainStatusType.ERROR.getId()); + master.setLastError(errorMessage); + master.setUpdatedUid(userUtil.getId()); + master.setUpdatedDttm(ZonedDateTime.now()); } @Transactional @@ -358,4 +377,58 @@ public class ModelTrainMngCoreService { public TrainRunRequest findTrainRunRequest(Long modelId) { return modelMngRepository.findTrainRunRequest(modelId); } + + public void markStep1InProgress(Long modelId, Long jobId) { + ModelMasterEntity entity = + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + + entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); + entity.setStep1StrtDttm(ZonedDateTime.now()); + entity.setStep1State(TrainStatusType.IN_PROGRESS.getId()); + entity.setCurrentAttemptId(jobId); + entity.setUpdatedDttm(ZonedDateTime.now()); + entity.setUpdatedUid(userUtil.getId()); + } + + public void markStep2InProgress(Long modelId, Long jobId) { + ModelMasterEntity entity = + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + + entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); + entity.setStep2StrtDttm(ZonedDateTime.now()); + entity.setStep2State(TrainStatusType.IN_PROGRESS.getId()); + entity.setCurrentAttemptId(jobId); + entity.setUpdatedDttm(ZonedDateTime.now()); + entity.setUpdatedUid(userUtil.getId()); + } + + public void markStep1Success(Long modelId) { + ModelMasterEntity entity = + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + + entity.setStatusCd(TrainStatusType.COMPLETED.getId()); + entity.setStep1State(TrainStatusType.COMPLETED.getId()); + entity.setStep1EndDttm(ZonedDateTime.now()); + entity.setUpdatedDttm(ZonedDateTime.now()); + entity.setUpdatedUid(userUtil.getId()); + } + + public void markStep2Success(Long modelId) { + ModelMasterEntity entity = + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + + entity.setStatusCd(TrainStatusType.COMPLETED.getId()); + entity.setStep2State(TrainStatusType.COMPLETED.getId()); + entity.setStep2EndDttm(ZonedDateTime.now()); + entity.setUpdatedDttm(ZonedDateTime.now()); + entity.setUpdatedUid(userUtil.getId()); + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java index 504d2c4..23c11e0 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java @@ -1,5 +1,6 @@ package com.kamco.cd.training.postgres.entity; +import com.kamco.cd.training.train.dto.ModelTrainJobDto; import jakarta.persistence.Column; import jakarta.persistence.Entity; import jakarta.persistence.GeneratedValue; @@ -76,4 +77,19 @@ public class ModelTrainJobEntity { @Size(max = 100) @Column(name = "locked_by", length = 100) private String lockedBy; + + public ModelTrainJobDto toDto() { + return new ModelTrainJobDto( + this.id, + this.modelId, + this.attemptNo, + this.statusCd, + this.exitCode, + this.errorMessage, + this.containerName, + this.paramsJson, + this.queuedDttm, + this.startedDttm, + this.finishedDttm); + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java index 94a578e..c540a70 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java @@ -134,7 +134,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom { modelHyperParamEntity.contrastRange, modelHyperParamEntity.saturationRange, modelHyperParamEntity.hueDelta, - Expressions.nullExpression(Integer.class))) + Expressions.nullExpression(Integer.class), + Expressions.nullExpression(String.class))) .from(modelMasterEntity) .leftJoin(modelHyperParamEntity) .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId)) diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java index 7b9cf1d..ee79523 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java @@ -7,6 +7,4 @@ public interface ModelTrainJobRepositoryCustom { int findMaxAttemptNo(Long modelId); Optional findLatestByModelId(Long modelId); - - Optional pickQueuedForUpdate(); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java index 9fa924a..cf74017 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java @@ -1,34 +1,43 @@ package com.kamco.cd.training.postgres.repository.train; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; +import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity; import com.querydsl.jpa.impl.JPAQueryFactory; import jakarta.persistence.EntityManager; import java.util.Optional; -import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Repository; @Repository -@RequiredArgsConstructor public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom { - private final EntityManager em; + private final JPAQueryFactory queryFactory; - private JPAQueryFactory queryFactory() { - return new JPAQueryFactory(em); + public ModelTrainJobRepositoryImpl(EntityManager em) { + this.queryFactory = new JPAQueryFactory(em); } + /** modelId의 attempt_no 최대값. (없으면 0) */ @Override public int findMaxAttemptNo(Long modelId) { - return 0; + QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity; + + Integer max = + queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne(); + + return max != null ? max : 0; } + /** + * modelId의 최신 job 1건 (보통 id desc / queuedDttm desc 등) - attemptNo 기준으로도 가능하지만, 여기선 id desc가 가장 + * 단순. + */ @Override public Optional findLatestByModelId(Long modelId) { - return Optional.empty(); - } + QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity; - @Override - public Optional pickQueuedForUpdate() { - return Optional.empty(); + ModelTrainJobEntity job = + queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst(); + + return Optional.ofNullable(job); } } diff --git a/src/main/java/com/kamco/cd/training/train/TrainApiController.java b/src/main/java/com/kamco/cd/training/train/TrainApiController.java index 8c8bb94..4e888fe 100644 --- a/src/main/java/com/kamco/cd/training/train/TrainApiController.java +++ b/src/main/java/com/kamco/cd/training/train/TrainApiController.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.train; import com.kamco.cd.training.config.api.ApiResponseDto; +import com.kamco.cd.training.train.service.TestJobService; import com.kamco.cd.training.train.service.TrainJobService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -12,6 +13,7 @@ import io.swagger.v3.oas.annotations.tags.Tag; import java.util.UUID; import lombok.RequiredArgsConstructor; import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -22,6 +24,7 @@ import org.springframework.web.bind.annotation.RestController; public class TrainApiController { private final TrainJobService trainJobService; + private final TestJobService testJobService; @Operation(summary = "학습 실행", description = "학습 실행 API") @ApiResponses( @@ -36,7 +39,7 @@ public class TrainApiController { @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) - @RequestMapping("/run/{uuid}") + @PostMapping("/run/{uuid}") public ApiResponseDto run( @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") @PathVariable @@ -45,4 +48,120 @@ public class TrainApiController { trainJobService.enqueue(modelId); return ApiResponseDto.ok("ok"); } + + @Operation(summary = "학습 재실행", description = "학습 재실행 API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "재실행 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PostMapping("/restart/{uuid}") + public ApiResponseDto restart( + @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @PathVariable + UUID uuid) { + Long modelId = trainJobService.getModelIdByUuid(uuid); + Long jobId = trainJobService.restart(modelId); + return ApiResponseDto.ok("ok"); + } + + @Operation(summary = "학습 이어하기", description = "학습 이어하기 API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "이어하기 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PostMapping("/resume/{uuid}") + public ApiResponseDto resume( + @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @PathVariable + UUID uuid) { + Long modelId = trainJobService.getModelIdByUuid(uuid); + Long jobId = trainJobService.resume(modelId); + return ApiResponseDto.ok("ok"); + } + + @Operation(summary = "학습 취소", description = "학습 취소 API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "취소 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PostMapping("/cancel/{uuid}") + public ApiResponseDto cancel( + @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @PathVariable + UUID uuid) { + Long modelId = trainJobService.getModelIdByUuid(uuid); + trainJobService.cancel(modelId); + return ApiResponseDto.ok("ok"); + } + + @Operation(summary = "test 실행", description = "test 실행 API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "test 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PostMapping("/test/run/{epoch}/{uuid}") + public ApiResponseDto run( + @Parameter(description = "best 에폭", example = "1") @PathVariable int epoch, + @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @PathVariable + UUID uuid) { + Long modelId = trainJobService.getModelIdByUuid(uuid); + testJobService.enqueue(modelId, uuid, epoch); + return ApiResponseDto.ok("ok"); + } + + @Operation(summary = "학습 취소", description = "학습 취소 API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "취소 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PostMapping("/test/cancel/{uuid}") + public ApiResponseDto cancelTest( + @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @PathVariable + UUID uuid) { + Long modelId = trainJobService.getModelIdByUuid(uuid); + testJobService.cancel(modelId); + return ApiResponseDto.ok("ok"); + } } diff --git a/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java b/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java new file mode 100644 index 0000000..fe621eb --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java @@ -0,0 +1,16 @@ +package com.kamco.cd.training.train.dto; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@AllArgsConstructor +@NoArgsConstructor +public class EvalRunRequest { + private String uuid; + private int epoch; // best_changed_fscore_epoch_1.pth + private Integer timeoutSeconds; +} diff --git a/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobDto.java b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobDto.java new file mode 100644 index 0000000..f9d0004 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobDto.java @@ -0,0 +1,23 @@ +package com.kamco.cd.training.train.dto; + +import java.time.ZonedDateTime; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class ModelTrainJobDto { + + private Long id; + private Long modelId; + private Integer attemptNo; + private String statusCd; + private Integer exitCode; + private String errorMessage; + private String containerName; + private Map paramsJson; + private ZonedDateTime queuedDttm; + private ZonedDateTime startedDttm; + private ZonedDateTime finishedDttm; +} diff --git a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java index 4fc7a70..a3b63d8 100644 --- a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java +++ b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java @@ -79,4 +79,5 @@ public class TrainRunRequest { // 실행 타임아웃 // ======================== private Integer timeoutSeconds; + private String resumeFrom; } diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index cfd6364..02a1443 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -1,5 +1,6 @@ package com.kamco.cd.training.train.service; +import com.kamco.cd.training.train.dto.EvalRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunResult; import java.io.BufferedReader; @@ -7,7 +8,6 @@ import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; -import java.util.UUID; import java.util.concurrent.TimeUnit; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @@ -40,53 +40,72 @@ public class DockerTrainService { private boolean ipcHost; /** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */ - public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception { + public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception { - // 실행 식별용 jobId 생성 - String jobId = UUID.randomUUID().toString().substring(0, 8); - - // 컨테이너 이름 생성 (중복 방지 목적) - String containerName = containerPrefix + "-" + jobId; - - // docker run 명령어 조립 List cmd = buildDockerRunCommand(containerName, req); - // 프로세스 실행 ProcessBuilder pb = new ProcessBuilder(cmd); - - // stderr를 stdout으로 합쳐서 한 스트림으로 처리 pb.redirectErrorStream(true); Process p = pb.start(); - // 실행 로그 수집 + // 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게) StringBuilder log = new StringBuilder(); + Thread logThread = + new Thread( + () -> { + try (BufferedReader br = + new BufferedReader( + new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { + String line; + while ((line = br.readLine()) != null) { + synchronized (log) { + log.append(line).append('\n'); + } + } + } catch (Exception ignored) { + } + }, + "train-log-" + containerName); - try (BufferedReader br = - new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { + logThread.setDaemon(true); + logThread.start(); - String line; - while ((line = br.readLine()) != null) { - log.append(line).append('\n'); - } - } - - // 지정된 timeout 내에 종료 대기 - int timeoutSeconds = 7200; // 기본 2시간 + int timeoutSeconds = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200; boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS); if (!finished) { - // 타임아웃 발생 시 컨테이너 강제 제거 + // docker run 프로세스도 같이 끊어야 readLine이 풀림 + p.destroy(); + if (!p.waitFor(2, TimeUnit.SECONDS)) { + p.destroyForcibly(); + } killContainer(containerName); - return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString()); + String logs; + synchronized (log) { + logs = log.toString(); + } + + return new TrainRunResult( + null, // jobId (없으면 null) + containerName, + -1, + "TIMEOUT", + logs); } - // 종료 코드 확인 (0=정상) int exit = p.exitValue(); - return new TrainRunResult( - jobId, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", log.toString()); + // 로그 스레드가 마무리할 시간을 조금 줌(없어도 되지만 로그 누락 방지용) + logThread.join(500); + + String logs; + synchronized (log) { + logs = log.toString(); + } + + return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs); } /** @@ -205,6 +224,7 @@ public class DockerTrainService { addArg(c, "--saturation-range", req.getSaturationRange()); addArg(c, "--hue-delta", req.getHueDelta()); + addArg(c, "--resume-from", req.getResumeFrom()); return c; } @@ -218,7 +238,7 @@ public class DockerTrainService { } /** 컨테이너 강제 종료 및 제거 */ - private void killContainer(String containerName) { + public void killContainer(String containerName) { try { new ProcessBuilder("docker", "rm", "-f", containerName) .redirectErrorStream(true) @@ -227,4 +247,100 @@ public class DockerTrainService { } catch (Exception ignored) { } } + + public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception { + + List cmd = buildDockerEvalCommand(containerName, req); + + ProcessBuilder pb = new ProcessBuilder(cmd); + pb.redirectErrorStream(true); + + Process p = pb.start(); + + StringBuilder log = new StringBuilder(); + Thread logThread = + new Thread( + () -> { + try (BufferedReader br = + new BufferedReader( + new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { + String line; + while ((line = br.readLine()) != null) { + synchronized (log) { + log.append(line).append('\n'); + } + } + } catch (Exception ignored) { + } + }); + + logThread.setDaemon(true); + logThread.start(); + + int timeout = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200; + boolean finished = p.waitFor(timeout, TimeUnit.SECONDS); + + if (!finished) { + p.destroyForcibly(); + killContainer(containerName); + + String logs; + synchronized (log) { + logs = log.toString(); + } + + return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs); + } + + int exit = p.exitValue(); + logThread.join(500); + + String logs; + synchronized (log) { + logs = log.toString(); + } + + return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs); + } + + private List buildDockerEvalCommand(String containerName, EvalRunRequest req) { + + String uuid = req.getUuid(); + Integer epoch = req.getEpoch(); + if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required"); + if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0"); + + String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth"; + + List c = new ArrayList<>(); + + c.add("docker"); + c.add("run"); + c.add("--name"); + c.add(containerName); + c.add("--rm"); + + c.add("--gpus"); + c.add("all"); + if (ipcHost) c.add("--ipc=host"); + c.add("--shm-size=" + shmSize); + + c.add("-v"); + c.add(requestDir + ":/data"); + c.add("-v"); + c.add(responseDir + ":/checkpoints"); + + c.add(image); + + c.add("python"); + c.add("/workspace/change-detection-code/run_evaluation_pipeline.py"); + + c.add("--dataset_dir"); + c.add("/data/" + uuid); + + c.add("--model"); + c.add("/checkpoints/" + uuid + "/" + modelFile); + + return c; + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java new file mode 100644 index 0000000..e90cf1c --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java @@ -0,0 +1,76 @@ +package com.kamco.cd.training.train.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.kamco.cd.training.model.dto.ModelTrainMngDto; +import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; +import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; +import java.time.ZonedDateTime; +import java.util.Map; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +@Transactional(readOnly = true) +public class TestJobService { + private final ModelTrainJobCoreService modelTrainJobCoreService; + private final ModelTrainMngCoreService modelTrainMngCoreService; + private final DockerTrainService dockerTrainService; + private final ObjectMapper objectMapper; + private final ApplicationEventPublisher eventPublisher; + + @Transactional + public Long enqueue(Long modelId, UUID uuid, int epoch) { + + // 마스터 확인 + modelTrainMngCoreService.findModelById(modelId); + + Map params = new java.util.LinkedHashMap<>(); + params.put("jobType", "EVAL"); + params.put("uuid", uuid); + params.put("epoch", epoch); + + int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; + + Long jobId = + modelTrainJobCoreService.createQueuedJob( + modelId, nextAttemptNo, params, ZonedDateTime.now()); + + // step2 시작으로 마킹 + modelTrainMngCoreService.markStep2InProgress(modelId, jobId); + + eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); + return jobId; + } + + @Transactional + public void cancel(Long modelId) { + + ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); + + Long jobId = master.getCurrentAttemptId(); + if (jobId == null) { + throw new IllegalStateException("실행중인 작업이 없습니다."); + } + + var job = + modelTrainJobCoreService + .findById(jobId) + .orElseThrow(() -> new IllegalStateException("Job not found")); + + String containerName = job.getContainerName(); + + // 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게) + if (containerName != null && !containerName.isBlank()) { + dockerTrainService.killContainer(containerName); + } + + // 2) 상태 업데이트 (항상 수행) + modelTrainJobCoreService.markCanceled(jobId); + modelTrainMngCoreService.markStopped(modelId); + } +} diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java index b5e1bd2..e6f57b6 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java @@ -7,10 +7,14 @@ import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; import com.kamco.cd.training.train.dto.TrainRunRequest; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.time.ZonedDateTime; import java.util.Map; import java.util.UUID; import lombok.RequiredArgsConstructor; +import org.springframework.beans.factory.annotation.Value; import org.springframework.context.ApplicationEventPublisher; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -22,9 +26,14 @@ public class TrainJobService { private final ModelTrainJobCoreService modelTrainJobCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService; + private final DockerTrainService dockerTrainService; private final ObjectMapper objectMapper; private final ApplicationEventPublisher eventPublisher; + // 학습 결과가 저장될 호스트 디렉토리 + @Value("${train.docker.responseDir}") + private String responseDir; + public Long getModelIdByUuid(UUID uuid) { return modelTrainMngCoreService.findModelIdByUuid(uuid); } @@ -36,6 +45,7 @@ public class TrainJobService { // 마스터 존재 확인(없으면 예외) modelTrainMngCoreService.findModelById(modelId); + // 파라미터 조회 TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId); if (trainRunRequest == null) { @@ -46,6 +56,7 @@ public class TrainJobService { @SuppressWarnings("unchecked") Map paramsMap = objectMapper.convertValue(trainRunRequest, Map.class); + paramsMap.put("jobType", "TRAIN"); Long jobId = modelTrainJobCoreService.createQueuedJob( @@ -57,16 +68,66 @@ public class TrainJobService { // 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함) eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); + modelTrainMngCoreService.markStep1InProgress(modelId, jobId); return jobId; } /** - * 재시작 버튼 + * 재시작 * *

- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성 */ @Transactional public Long restart(Long modelId) { + return createNextAttempt(modelId, ResumeMode.NONE); + } + + /** + * 이어하기 + * + * @param modelId + * @return + */ + @Transactional + public Long resume(Long modelId) { + return createNextAttempt(modelId, ResumeMode.REQUIRE); + } + + /** + * 중단 + * + *

- job 상태 CANCELED - master 상태 STOPPED + * + *

※ 실제 docker stop은 Worker/Runner가 수행(운영 안정) + */ + @Transactional + public void cancel(Long modelId) { + + ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); + + Long jobId = master.getCurrentAttemptId(); + if (jobId == null) { + throw new IllegalStateException("실행중인 작업이 없습니다."); + } + + var job = + modelTrainJobCoreService + .findById(jobId) + .orElseThrow(() -> new IllegalStateException("Job not found")); + + String containerName = job.getContainerName(); + + // 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게) + if (containerName != null && !containerName.isBlank()) { + dockerTrainService.killContainer(containerName); + } + + // 2) 상태 업데이트 (항상 수행) + modelTrainJobCoreService.markCanceled(jobId); + modelTrainMngCoreService.markStopped(modelId); + } + + private Long createNextAttempt(Long modelId, ResumeMode mode) { ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); @@ -81,39 +142,72 @@ public class TrainJobService { int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; + // 이전 params_json 재사용 (재현성) + Map params = lastJob.getParamsJson(); + if (params == null || params.isEmpty()) { + throw new IllegalStateException("이전 실행 params_json이 없습니다."); + } + + // mode에 따라 resume 옵션 주입/제거 + Map nextParams = new java.util.LinkedHashMap<>(params); + + if (mode == ResumeMode.NONE) { + // 이어하기 관련 키가 있다면 제거 (완전 새로 시작 보장) + nextParams.remove("resumeFrom"); + nextParams.remove("resume"); + } else if (mode == ResumeMode.REQUIRE) { + // 체크포인트 탐지해서 resumeFrom 세팅 + String resumeFrom = findResumeFromOrNull(nextParams); + if (resumeFrom == null) { + throw new IllegalStateException("이어하기 체크포인트가 없습니다."); + } + nextParams.put("resumeFrom", resumeFrom); + nextParams.put("resume", true); + } + Long jobId = modelTrainJobCoreService.createQueuedJob( - modelId, - nextAttemptNo, - lastJob.getParamsJson(), // Map 그대로 재사용 - ZonedDateTime.now()); + modelId, nextAttemptNo, nextParams, ZonedDateTime.now()); modelTrainMngCoreService.clearLastError(modelId); modelTrainMngCoreService.markInProgress(modelId, jobId); eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); - return jobId; } - /** - * 중단 버튼 - * - *

- job 상태 CANCELED - master 상태 STOPPED - * - *

※ 실제 docker stop은 Worker/Runner가 수행(운영 안정) - */ - @Transactional - public void cancel(Long modelId) { + private enum ResumeMode { + NONE, // 새로 시작 + REQUIRE // 이어하기 + } - ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); + public String findResumeFromOrNull(Map paramsJson) { + if (paramsJson == null) return null; - Long attemptId = master.getCurrentAttemptId(); - if (attemptId == null) { - throw new IllegalStateException("실행중인 작업이 없습니다."); + Object out = paramsJson.get("outputFolder"); + if (out == null) return null; + + String outputFolder = String.valueOf(out).trim(); // uuid + if (outputFolder.isEmpty()) return null; + + // 호스트 기준 경로 + Path outDir = Paths.get(responseDir, outputFolder); + + Path last = outDir.resolve("last_checkpoint"); + if (!Files.isRegularFile(last)) return null; + + try { + String ckptFile = Files.readString(last).trim(); // epoch_10.pth + if (ckptFile.isEmpty()) return null; + + Path ckptHost = outDir.resolve(ckptFile); + if (!Files.isRegularFile(ckptHost)) return null; + + // 컨테이너 경로 반환 + return "/checkpoints/" + outputFolder + "/" + ckptFile; + + } catch (Exception e) { + return null; } - - modelTrainJobCoreService.markCanceled(attemptId); - modelTrainMngCoreService.markStopped(modelId); } } diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java index 8acb348..4eedc9a 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java @@ -1,9 +1,11 @@ package com.kamco.cd.training.train.service; import com.fasterxml.jackson.databind.ObjectMapper; +import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; -import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; +import com.kamco.cd.training.train.dto.EvalRunRequest; +import com.kamco.cd.training.train.dto.ModelTrainJobDto; import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunResult; @@ -27,53 +29,80 @@ public class TrainJobWorker { @TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT) public void handle(ModelTrainJobQueuedEvent event) { - Long jobId = event.getJobId(); // record면 event.jobId() + Long jobId = event.getJobId(); - ModelTrainJobEntity job = + ModelTrainJobDto job = modelTrainJobCoreService .findById(jobId) .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); - Long modelId = job.getModelId(); - - // enqueue에서 params_json 저장해놨으니 그걸로 TrainRunRequest 복원하는게 제일 일관적 - TrainRunRequest req = toTrainRunRequest(job.getParamsJson()); - // req가 null이면 실패 처리 - if (req == null) { - modelTrainJobCoreService.markFailed( - jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)"); - modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패"); + if (TrainStatusType.STOPPED.getId().equals(job.getStatusCd())) { return; } - // 컨테이너 이름은 "jobId 기반"으로 고정하는 게 cancel/restart에 유리 - String containerName = "train-" + jobId; // prefix 쓰고싶으면 @Value 받아서 붙이면 됨 + Long modelId = job.getModelId(); + Map params = job.getParamsJson(); - // logPath/lockedBy는 너 환경에 맞게 - String logPath = null; - String lockedBy = "TRAIN_WORKER"; + String jobType = params != null ? String.valueOf(params.get("jobType")) : null; - // RUNNING 표시 - modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy); + boolean isEval = "EVAL".equals(jobType); + + String containerName = (isEval ? "eval-" : "train-") + jobId; + + modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER"); try { - // DockerTrainService가 내부에서 컨테이너 이름을 랜덤으로 만들고 있어서 - // markRunning에서 저장한 containerName과 실제 컨테이너명이 달라질 수 있음. - // 아래 "추천 수정" 참고. - TrainRunResult result = dockerTrainService.runTrainSync(req); + TrainRunResult result; + + if (isEval) { + String uuid = String.valueOf(params.get("uuid")); + int epoch = (int) params.get("epoch"); + + EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null); + result = dockerTrainService.runEvalSync(evalReq, containerName); + + } else { + TrainRunRequest trainReq = toTrainRunRequest(params); + result = dockerTrainService.runTrainSync(trainReq, containerName); + } + + ModelTrainJobDto latest = + modelTrainJobCoreService + .findById(jobId) + .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); + + if (TrainStatusType.STOPPED.getId().equals(latest.getStatusCd())) { + return; + } if (result.getExitCode() == 0) { modelTrainJobCoreService.markSuccess(jobId, result.getExitCode()); - modelTrainMngCoreService.markSuccess(modelId); // 너 modelTrainMngCoreService에 있는 이름으로 맞춰 + + if (isEval) { + modelTrainMngCoreService.markStep2Success(modelId); + } else { + modelTrainMngCoreService.markStep1Success(modelId); + } + } else { modelTrainJobCoreService.markFailed( jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs()); - modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode()); + + if (isEval) { + modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode()); + } else { + modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode()); + } } } catch (Exception e) { modelTrainJobCoreService.markFailed(jobId, null, e.toString()); - modelTrainMngCoreService.markError(modelId, e.getMessage()); + + if ("EVAL".equals(params.get("jobType"))) { + modelTrainMngCoreService.markStep2Error(modelId, e.getMessage()); + } else { + modelTrainMngCoreService.markError(modelId, e.getMessage()); + } } }