From 13d9d9176b6bf35575e0c92b345b04074c950e72 Mon Sep 17 00:00:00 2001 From: teddy Date: Mon, 6 Apr 2026 20:40:42 +0900 Subject: [PATCH] =?UTF-8?q?=EC=83=81=ED=83=9C=EB=B3=80=EA=B2=BD=20?= =?UTF-8?q?=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/ModelTestMetricsJobCoreService.java | 4 + .../core/ModelTrainJobCoreService.java | 11 + .../core/ModelTrainMetricsJobCoreService.java | 4 + .../ModelTestMetricsJobRepositoryCustom.java | 2 + .../ModelTestMetricsJobRepositoryImpl.java | 19 + .../ModelTrainMetricsJobRepositoryCustom.java | 2 + .../ModelTrainMetricsJobRepositoryImpl.java | 19 + .../cd/training/train/TrainApiController.java | 23 ++ .../train/dto/DockerInspectState.java | 8 + .../cd/training/train/dto/OutputResult.java | 20 ++ .../service/JobRecoveryOnStartupService.java | 333 +----------------- .../service/ModelTestMetricsJobService.java | 200 ++++++----- .../service/ModelTrainMetricsJobService.java | 232 ++++++------ .../train/service/TrainJobService.java | 137 +++++++ .../train/service/TrainUtilService.java | 228 ++++++++++++ 15 files changed, 718 insertions(+), 524 deletions(-) create mode 100644 src/main/java/com/kamco/cd/training/train/dto/DockerInspectState.java create mode 100644 src/main/java/com/kamco/cd/training/train/dto/OutputResult.java create mode 100644 src/main/java/com/kamco/cd/training/train/service/TrainUtilService.java diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java index a8d7b2c..3be3328 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java @@ -26,6 +26,10 @@ public class ModelTestMetricsJobCoreService { return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds(); } + public 
ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId) { + return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelId(modelId); + } + public void insertModelMetricsTest(List batchArgs) { modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs); } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java index 8cc1fe8..87321ff 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java @@ -80,6 +80,17 @@ public class ModelTrainJobCoreService { } } + /** 실행 시작 처리 수정 */ + @Transactional + public void updateJobStatus(Long jobId, String jobStatus) { + ModelTrainJobEntity job = + modelTrainJobRepository + .findById(jobId) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + + job.setStatusCd(jobStatus); + } + /** * 성공 처리 * diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java index 3f592d8..25c29cf 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java @@ -17,6 +17,10 @@ public class ModelTrainMetricsJobCoreService { return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds(); } + public ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId) { + return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelId(modelId); + } + public void insertModelMetricsTrain(List batchArgs) { modelTrainMetricsJobRepository.insertModelMetricsTrain(batchArgs); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java 
b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java index 4c55743..fd7a25f 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java @@ -12,6 +12,8 @@ public interface ModelTestMetricsJobRepositoryCustom { List getTestMetricSaveNotYetModelIds(); + ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId); + void insertModelMetricsTest(List batchArgs); ModelMetricJsonDto getTestMetricPackingInfo(Long modelId); diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java index fa28fb1..687dbfb 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java @@ -63,6 +63,25 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport .fetch(); } + @Override + public ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId) { + return queryFactory + .select( + Projections.constructor( + ResponsePathDto.class, + modelMasterEntity.id, + modelMasterEntity.responsePath, + modelMasterEntity.uuid)) + .from(modelMasterEntity) + .where( + modelMasterEntity.id.eq(modelId), + modelMasterEntity + .step2MetricSaveYn + .isNull() + .or(modelMasterEntity.step2MetricSaveYn.isFalse())) + .fetchOne(); + } + @Override public void insertModelMetricsTest(List batchArgs) { // AS-IS diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java index f4031bf..22cd2dc 100644 --- 
a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java @@ -7,6 +7,8 @@ public interface ModelTrainMetricsJobRepositoryCustom { List getTrainMetricSaveNotYetModelIds(); + ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId); + void insertModelMetricsTrain(List batchArgs); void updateModelMetricsTrainSaveYn(Long modelId, String stepNo); diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java index 1323d40..d36586d 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java @@ -44,6 +44,25 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor .fetch(); } + @Override + public ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId) { + return queryFactory + .select( + Projections.constructor( + ResponsePathDto.class, + modelMasterEntity.id, + modelMasterEntity.responsePath, + modelMasterEntity.uuid)) + .from(modelMasterEntity) + .where( + modelMasterEntity.id.eq(modelId), + modelMasterEntity + .step1MetricSaveYn + .isNull() + .or(modelMasterEntity.step1MetricSaveYn.isFalse())) + .fetchOne(); + } + @Override public void insertModelMetricsTrain(List batchArgs) { String sql = diff --git a/src/main/java/com/kamco/cd/training/train/TrainApiController.java b/src/main/java/com/kamco/cd/training/train/TrainApiController.java index 1c5f9d4..e500a90 100644 --- a/src/main/java/com/kamco/cd/training/train/TrainApiController.java +++ b/src/main/java/com/kamco/cd/training/train/TrainApiController.java @@ -213,4 +213,27 @@ public class TrainApiController { Long 
modelId = trainJobService.getModelIdByUuid(uuid); return ApiResponseDto.ok(dataSetCountersService.getCount(modelId)); } + + @Operation(summary = "학습 상태 확인", description = "학습 상태 확인") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "학습 상태 변경", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PostMapping(path = "/status/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE) + public ApiResponseDto status( + @Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49") + @PathVariable + UUID uuid) { + Long modelId = trainJobService.getModelIdByUuid(uuid); + trainJobService.status(uuid, modelId); + return ApiResponseDto.ok("ok"); + } } diff --git a/src/main/java/com/kamco/cd/training/train/dto/DockerInspectState.java b/src/main/java/com/kamco/cd/training/train/dto/DockerInspectState.java new file mode 100644 index 0000000..fd93dad --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/dto/DockerInspectState.java @@ -0,0 +1,8 @@ +package com.kamco.cd.training.train.dto; + +public record DockerInspectState(boolean exists, boolean running, Integer exitCode, String status) { + + public static DockerInspectState missing() { + return new DockerInspectState(false, false, null, "missing"); + } +} diff --git a/src/main/java/com/kamco/cd/training/train/dto/OutputResult.java b/src/main/java/com/kamco/cd/training/train/dto/OutputResult.java new file mode 100644 index 0000000..23742ae --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/dto/OutputResult.java @@ -0,0 +1,20 @@ +package com.kamco.cd.training.train.dto; + +public class OutputResult { + + private final boolean completed; + private final String reason; + + public OutputResult(boolean completed, String reason) { + 
this.completed = completed; + this.reason = reason; + } + + public boolean completed() { + return completed; + } + + public String reason() { + return reason; + } +} diff --git a/src/main/java/com/kamco/cd/training/train/service/JobRecoveryOnStartupService.java b/src/main/java/com/kamco/cd/training/train/service/JobRecoveryOnStartupService.java index ffff705..395a33b 100644 --- a/src/main/java/com/kamco/cd/training/train/service/JobRecoveryOnStartupService.java +++ b/src/main/java/com/kamco/cd/training/train/service/JobRecoveryOnStartupService.java @@ -1,25 +1,16 @@ package com.kamco.cd.training.train.service; -import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import com.kamco.cd.training.train.dto.DockerInspectState; import com.kamco.cd.training.train.dto.ModelTrainJobDto; -import java.io.BufferedReader; +import com.kamco.cd.training.train.dto.OutputResult; import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.charset.StandardCharsets; -import java.nio.file.DirectoryStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.TimeUnit; -import java.util.stream.Stream; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; -import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.context.annotation.Profile; import org.springframework.context.event.EventListener; @@ -44,14 +35,7 @@ public class JobRecoveryOnStartupService { private final ModelTrainJobCoreService modelTrainJobCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService; private final ModelTrainMetricsJobService modelTrainMetricsJobService; - - /** - * Docker 컨테이너가 쓰는 
response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response - * - *

컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다. - */ - @Value("${train.docker.response_dir}") - private String responseDir; + private final TrainUtilService trainUtilService; /** * 스프링 부팅 완료 시점(빈 생성/초기화 모두 끝난 뒤)에 복구 로직 실행. @@ -77,7 +61,7 @@ public class JobRecoveryOnStartupService { try { // 2-1) docker inspect로 컨테이너 상태 조회 - DockerInspectState state = inspectContainer(containerName); + DockerInspectState state = trainUtilService.inspectContainer(containerName); // 3) 컨테이너가 "없음" // - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음 @@ -88,7 +72,7 @@ public class JobRecoveryOnStartupService { containerName); // 3-1) 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정" - OutputResult out = probeOutputs(job); + OutputResult out = trainUtilService.probeOutputs(job); // 3-2) 산출물이 충분하면 성공 처리 if (out.completed()) { @@ -109,11 +93,9 @@ public class JobRecoveryOnStartupService { job.getId(), out.reason()); - Integer modelId = job.getModelId() == null ? null : Math.toIntExact(job.getModelId()); - // PAUSED/STOP modelTrainJobCoreService.markPaused( - job.getId(), modelId, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE"); + job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE"); // 모델도 에러가 아니라 STOP으로 markStepStopByJobType( @@ -152,7 +134,7 @@ public class JobRecoveryOnStartupService { // ============================================================ // 2) kill 후 실제로 죽었는지 확인 // ============================================================ - DockerInspectState after = inspectContainer(containerName); + DockerInspectState after = trainUtilService.inspectContainer(containerName); if (after.exists() && after.running()) { throw new IOException("docker kill returned 0 but container still running"); } @@ -162,10 +144,8 @@ public class JobRecoveryOnStartupService { // ============================================================ // 3) job 상태를 PAUSED로 변경 (서버 재기동으로 강제 중단) // ============================================================ - Integer 
modelId = job.getModelId() == null ? null : Math.toIntExact(job.getModelId()); - modelTrainJobCoreService.markPaused( - job.getId(), modelId, "AUTO_KILLED_ON_SERVER_RESTART"); + modelTrainJobCoreService.markPaused(job.getId(), -1, "AUTO_KILLED_ON_SERVER_RESTART"); log.info("job = {}", job); markStepStopByJobType(job, "AUTO_KILLED_ON_SERVER_RESTART"); @@ -264,301 +244,4 @@ public class JobRecoveryOnStartupService { modelTrainMngCoreService.markError(job.getModelId(), msg); } } - - /** - * docker inspect를 사용해서 컨테이너 상태를 조회합니다. - * - *

사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}} - * - *

예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137" - * - *

주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초 - * 타임아웃을 둠 - */ - private DockerInspectState inspectContainer(String containerName) - throws IOException, InterruptedException { - - ProcessBuilder pb = - new ProcessBuilder( - "docker", - "inspect", - "-f", - "{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}", - containerName); - - // stderr를 stdout으로 합쳐서 한 스트림으로 읽기(에러 메시지도 함께 받음) - pb.redirectErrorStream(true); - - Process p = pb.start(); - - // inspect 출력은 1줄이면 충분하므로 readLine()만 수행 - String output; - try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) { - output = br.readLine(); - } - - // 무한대기 방지: 5초 내에 종료되지 않으면 강제 종료 - boolean finished = p.waitFor(5, TimeUnit.SECONDS); - if (!finished) { - p.destroyForcibly(); - throw new IOException("docker inspect timeout"); - } - - // docker inspect 자체의 프로세스 exit code - int code = p.exitValue(); - - // 실패(코드 !=0) 또는 출력이 없으면 "컨테이너 없음"으로 간주 - if (code != 0 || output == null || output.isBlank()) { - return DockerInspectState.missing(); - } - - // "status running exitCode" 형태로 split - String[] parts = output.trim().split("\\s+"); - - // status: running/exited/dead 등 - String status = parts.length > 0 ? parts[0] : "unknown"; - - // running: true/false - boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]); - - // exitCode: 정수 파싱(파싱 실패하면 null) - Integer exitCode = null; - if (parts.length > 2) { - try { - exitCode = Integer.parseInt(parts[2]); - } catch (Exception ignore) { - // ignore - } - } - - return new DockerInspectState(true, running, exitCode, status); - } - - /** - * docker inspect 결과를 담는 레코드. - * - *

exists: - true : docker inspect 성공 (컨테이너 존재) - false : 컨테이너 없음(또는 inspect 실패를 missing으로 간주) - */ - private record DockerInspectState( - boolean exists, boolean running, Integer exitCode, String status) { - static DockerInspectState missing() { - return new DockerInspectState(false, false, null, "missing"); - } - } - - // ============================================================================================ - // 컨테이너가 "없을 때" 파일 기반으로 완료/미완료를 판정하는 로직 - // ============================================================================================ - - /** - * 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다. - * - *

판정 규칙(보수적으로 설계): 1) total_epoch가 paramsJson에 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수 - * >= total_epoch 이어야 함 3) *.pth 파일이 total_epoch 이상 존재하거나, best*.pth(또는 *best*.pth)가 존재해야 함 - * - *

왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를 - * 실패로 오판할 수 있음. - */ - private OutputResult probeOutputs(ModelTrainJobDto job) { - try { - - log.info( - "[RECOVERY] probeOutputs start. jobId={}, modelId={}", job.getId(), job.getModelId()); - - // 1) 출력 디렉토리 확인 - Path outDir = resolveOutputDir(job); - if (outDir == null || !Files.isDirectory(outDir)) { - log.warn("[RECOVERY] output directory missing. jobId={}, path={}", job.getId(), outDir); - return new OutputResult(false, "output-dir-missing"); - } - - log.info("[RECOVERY] output directory found. jobId={}, path={}", job.getId(), outDir); - - // 2) totalEpoch 확인 - Integer totalEpoch = extractTotalEpoch(job).orElse(null); - if (totalEpoch == null || totalEpoch <= 0) { - log.warn( - "[RECOVERY] totalEpoch missing or invalid. jobId={}, totalEpoch={}", - job.getId(), - totalEpoch); - return new OutputResult(false, "total-epoch-missing"); - } - - Integer valInterval = extractValInterval(job).orElse(null); - if (valInterval == null || valInterval <= 0) { - log.warn( - "[RECOVERY] valInterval missing or invalid. jobId={}, valInterval={}", - job.getId(), - valInterval); - return new OutputResult(false, "val-interval-missing"); - } - - log.info( - "[RECOVERY] totalEpoch={}. valInterval={}. jobId={}", - totalEpoch, - valInterval, - job.getId()); - - // 3) val.csv 존재 확인 - Path valCsv = outDir.resolve("val.csv"); - if (!Files.exists(valCsv)) { - log.warn("[RECOVERY] val.csv missing. jobId={}, path={}", job.getId(), valCsv); - return new OutputResult(false, "val.csv-missing"); - } - - // 4) val.csv 라인 수 확인 - long lines = countNonHeaderLines(valCsv); - - // expected = 실제 val 실행 횟수 - int expectedLines = totalEpoch / valInterval; - - log.info( - "[RECOVERY] val.csv lines counted. jobId={}, lines={}, expected={}", - job.getId(), - lines, - expectedLines); - - // 5) 완료 판정 - if (lines >= expectedLines) { - log.info("[RECOVERY] outputs look COMPLETE. 
jobId={}", job.getId()); - return new OutputResult(true, "ok"); - } - - log.warn( - "[RECOVERY] val.csv line mismatch. jobId={}, lines={}, expected={}", - job.getId(), - lines, - expectedLines); - - return new OutputResult( - false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch); - - } catch (Exception e) { - - log.error("[RECOVERY] probeOutputs error. jobId={}", job.getId(), e); - - return new OutputResult(false, "probe-error"); - } - } - - /** - * responseDir 아래에서 job 산출물 디렉토리를 찾습니다. - * - *

가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다. - * - *

현재 기본 탐색 순서: 1) {responseDir}/{jobId} 2) {responseDir}/{modelId} 3) - * {responseDir}/{containerName} 4) 마지막 fallback: responseDir 자체 - * - *

추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId) - * 형태를 1순위로 두세요. - */ - private Path resolveOutputDir(ModelTrainJobDto job) { - ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId()); - - Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics"); - - return Files.isDirectory(base) ? base : null; - } - - /** - * paramsJson에서 total_epoch 값을 추출합니다. - * - *

키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase) - * - *

예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...} - */ - private Optional extractTotalEpoch(ModelTrainJobDto job) { - Map params = job.getParamsJson(); - if (params == null) return Optional.empty(); - - Object v = params.get("total_epoch"); - if (v == null) v = params.get("totalEpoch"); - if (v == null) return Optional.empty(); - - try { - return Optional.of(Integer.parseInt(String.valueOf(v))); - } catch (Exception ignore) { - return Optional.empty(); - } - } - - /** - * CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다. - * - *

가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록) - * - *

주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외 - */ - private long countNonHeaderLines(Path csv) throws IOException { - try (Stream lines = Files.lines(csv, StandardCharsets.UTF_8)) { - return lines.skip(1).filter(s -> s != null && !s.isBlank()).count(); - } - } - - /** - * 디렉토리에서 glob 패턴에 맞는 파일 수를 셉니다. - * - *

예: - "*.pth" - "best*.pth" - */ - private long countFilesByGlob(Path dir, String glob) throws IOException { - try (DirectoryStream ds = Files.newDirectoryStream(dir, glob)) { - long cnt = 0; - for (Path p : ds) { - if (Files.isRegularFile(p)) cnt++; - } - return cnt; - } - } - - /** 디렉토리에서 glob 패턴에 맞는 파일이 "하나라도" 존재하는지 체크합니다. */ - private boolean existsByGlob(Path dir, String glob) throws IOException { - try (DirectoryStream ds = Files.newDirectoryStream(dir, glob)) { - return ds.iterator().hasNext(); - } - } - - // ============================================================================================ - // probeOutputs() 결과 객체 - // ============================================================================================ - - /** - * 컨테이너가 없을 때(responseDir 기반) 완료 여부 판정 결과. - * - *

completed: - true : 산출물이 완료로 보임(성공 처리 가능) - false : 산출물이 부족/불명확(실패 또는 유예 판단) - * - *

reason: - 실패/미완료 사유(로그/DB 메시지로 남기기 용도) - */ - private static final class OutputResult { - - private final boolean completed; - private final String reason; - - private OutputResult(boolean completed, String reason) { - this.completed = completed; - this.reason = reason; - } - - boolean completed() { - return completed; - } - - String reason() { - return reason; - } - } - - /** paramsJson에서 valInterval 추출 */ - private Optional extractValInterval(ModelTrainJobDto job) { - Map params = job.getParamsJson(); - if (params == null) return Optional.empty(); - - Object v = params.get("valInterval"); - if (v == null) return Optional.empty(); - - try { - return Optional.of(Integer.parseInt(String.valueOf(v))); - } catch (Exception ignore) { - return Optional.empty(); - } - } } diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java index 22c2d58..a0c1447 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java @@ -59,119 +59,133 @@ public class ModelTestMetricsJobService { } for (ResponsePathDto modelInfo : modelIds) { + createFile(modelInfo); + } + } - String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv"; - try (BufferedReader reader = - Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) { + /** 단건 결과 csv 파일 정보 등록 */ + public void testValidMetricCsvFiles(Long modelId) { - CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); + ResponsePathDto model = modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelId(modelId); - List batchArgs = new ArrayList<>(); + if (model == null) { + return; + } - for (CSVRecord record : parser) { + createFile(model); + } - String model = record.get("model"); - long TP = Long.parseLong(record.get("TP")); - long FP = 
Long.parseLong(record.get("FP")); - long FN = Long.parseLong(record.get("FN")); - float precision = Float.parseFloat(record.get("precision")); - float recall = Float.parseFloat(record.get("recall")); - float f1_score = Float.parseFloat(record.get("f1_score")); - float accuracy = Float.parseFloat(record.get("accuracy")); - float iou = Float.parseFloat(record.get("iou")); - long detection_count = Long.parseLong(record.get("detection_count")); - long gt_count = Long.parseLong(record.get("gt_count")); + /** + * 베스트 에폭 zip파일 생성, 테스트결과 db등록 + * + * @param modelInfo + */ + private void createFile(ResponsePathDto modelInfo) { - batchArgs.add( - new Object[] { - modelInfo.getModelId(), - model, - TP, - FP, - FN, - precision, - recall, - f1_score, - accuracy, - iou, - detection_count, - gt_count - }); - } + String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv"; + try (BufferedReader reader = + Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) { - modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs); + CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); - // test.csv 파일 읽어서 저장한 여부로만 사용하기 - modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn( - modelInfo.getModelId(), "step2"); + List batchArgs = new ArrayList<>(); - } catch (IOException e) { - throw new RuntimeException(e); + for (CSVRecord record : parser) { + + String model = record.get("model"); + long TP = Long.parseLong(record.get("TP")); + long FP = Long.parseLong(record.get("FP")); + long FN = Long.parseLong(record.get("FN")); + float precision = Float.parseFloat(record.get("precision")); + float recall = Float.parseFloat(record.get("recall")); + float f1_score = Float.parseFloat(record.get("f1_score")); + float accuracy = Float.parseFloat(record.get("accuracy")); + float iou = Float.parseFloat(record.get("iou")); + long detection_count = Long.parseLong(record.get("detection_count")); + long gt_count = 
Long.parseLong(record.get("gt_count")); + + batchArgs.add( + new Object[] { + modelInfo.getModelId(), + model, + TP, + FP, + FN, + precision, + recall, + f1_score, + accuracy, + iou, + detection_count, + gt_count + }); } - // 패키징할 파일 만들기 - modelTestMetricsJobCoreService.updatePackingStart( - modelInfo.getModelId(), ZonedDateTime.now()); + modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs); - ModelMetricJsonDto jsonDto = - modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId()); - try { - writeJsonFile( - jsonDto, - Paths.get( - responseDir - + "/" - + modelInfo.getUuid() - + "/" - + jsonDto.getModelVersion() - + ".json")); - } catch (IOException e) { - throw new RuntimeException(e); - } + // test.csv 파일 읽어서 저장한 여부로만 사용하기 + modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2"); - Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid()); + } catch (IOException e) { + throw new RuntimeException(e); + } - ModelTestFileName fileInfo = - modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId()); + // 패키징할 파일 만들기 + modelTestMetricsJobCoreService.updatePackingStart(modelInfo.getModelId(), ZonedDateTime.now()); - Path zipPath = + ModelMetricJsonDto jsonDto = + modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId()); + try { + writeJsonFile( + jsonDto, Paths.get( - responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip"); - Set targetNames = - Set.of( - "model_config.py", - fileInfo.getBestEpochFileName() + ".pth", - fileInfo.getModelVersion() + ".json"); + responseDir + "/" + modelInfo.getUuid() + "/" + jsonDto.getModelVersion() + ".json")); + } catch (IOException e) { + throw new RuntimeException(e); + } - List files = new ArrayList<>(); - try (Stream s = Files.list(responsePath)) { - files.addAll( - s.filter(Files::isRegularFile) - .filter(p -> targetNames.contains(p.getFileName().toString())) - 
.collect(Collectors.toList())); - } catch (IOException e) { - throw new RuntimeException(e); - } + Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid()); - try (Stream s = Files.list(Path.of(ptPathDir))) { - files.addAll( - s.filter(Files::isRegularFile) - .limit(1) // yolov8_6th-6m.pt 파일 1개만 - .collect(Collectors.toList())); - } catch (IOException e) { - throw new RuntimeException(e); - } + ModelTestFileName fileInfo = + modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId()); - try { - zipFiles(files, zipPath); + Path zipPath = + Paths.get( + responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip"); + Set targetNames = + Set.of( + "model_config.py", + fileInfo.getBestEpochFileName() + ".pth", + fileInfo.getModelVersion() + ".json"); - modelTestMetricsJobCoreService.updatePackingEnd( - modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId()); - } catch (IOException e) { - modelTestMetricsJobCoreService.updatePackingEnd( - modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId()); - throw new RuntimeException(e); - } + List files = new ArrayList<>(); + try (Stream s = Files.list(responsePath)) { + files.addAll( + s.filter(Files::isRegularFile) + .filter(p -> targetNames.contains(p.getFileName().toString())) + .collect(Collectors.toList())); + } catch (IOException e) { + throw new RuntimeException(e); + } + + try (Stream s = Files.list(Path.of(ptPathDir))) { + files.addAll( + s.filter(Files::isRegularFile) + .limit(1) // yolov8_6th-6m.pt 파일 1개만 + .collect(Collectors.toList())); + } catch (IOException e) { + throw new RuntimeException(e); + } + + try { + zipFiles(files, zipPath); + + modelTestMetricsJobCoreService.updatePackingEnd( + modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId()); + } catch (IOException e) { + modelTestMetricsJobCoreService.updatePackingEnd( + modelInfo.getModelId(), ZonedDateTime.now(), 
TrainStatusType.ERROR.getId()); + throw new RuntimeException(e); } } diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java index 6566cf7..3a5181d 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java @@ -48,115 +48,135 @@ public class ModelTrainMetricsJobService { for (ResponsePathDto modelInfo : modelIds) { - String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv"; - try (BufferedReader reader = - Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) { - - CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); - - List batchArgs = new ArrayList<>(); - - for (CSVRecord record : parser) { - - int epoch = Integer.parseInt(record.get("Epoch")); - long iteration = Long.parseLong(record.get("Iteration")); - double Loss = Double.parseDouble(record.get("Loss")); - double LR = Double.parseDouble(record.get("LR")); - float time = Float.parseFloat(record.get("Time")); - - batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time}); - } - - modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs); - - } catch (IOException e) { - throw new RuntimeException(e); - } - - String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv"; - try (BufferedReader reader = - Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) { - - CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); - - List batchArgs = new ArrayList<>(); - - for (CSVRecord record : parser) { - - int epoch = Integer.parseInt(record.get("Epoch")); - - Float aAcc = parseFloatSafe(record.get("aAcc")); - Float mFscore = parseFloatSafe(record.get("mFscore")); - Float mPrecision = 
parseFloatSafe(record.get("mPrecision")); - Float mRecall = parseFloatSafe(record.get("mRecall")); - Float mIoU = parseFloatSafe(record.get("mIoU")); - Float mAcc = parseFloatSafe(record.get("mAcc")); - - Float changed_fscore = parseFloatSafe(record.get("changed_fscore")); - Float changed_precision = parseFloatSafe(record.get("changed_precision")); - Float changed_recall = parseFloatSafe(record.get("changed_recall")); - - Float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore")); - Float unchanged_precision = parseFloatSafe(record.get("unchanged_precision")); - Float unchanged_recall = parseFloatSafe(record.get("unchanged_recall")); - - batchArgs.add( - new Object[] { - modelInfo.getModelId(), - epoch, - aAcc, - mFscore, - mPrecision, - mRecall, - mIoU, - mAcc, - changed_fscore, - changed_precision, - changed_recall, - unchanged_fscore, - unchanged_precision, - unchanged_recall - }); - } - - modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs); - - } catch (IOException e) { - throw new RuntimeException(e); - } - - Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid()); - Integer epoch = null; - boolean exists; - Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth"); - - try (Stream s = Files.list(responsePath)) { - epoch = - s.filter(Files::isRegularFile) - .map( - p -> { - Matcher matcher = pattern.matcher(p.getFileName().toString()); - if (matcher.matches()) { - return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출 - } - return null; - }) - .filter(Objects::nonNull) - .findFirst() - .orElse(null); - - } catch (IOException e) { - throw new RuntimeException(e); - } - - // best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기 - modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch); - - modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn( - modelInfo.getModelId(), "step1"); + createFile(modelInfo); } } + /** 단건 결과 csv 파일 정보 등록 */ + 
public void trainValidMetricCsvFile(Long modelId) { + + ResponsePathDto modelInfo = + modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelId(modelId); + + if (modelInfo == null) { + return; + } + createFile(modelInfo); + } + + /** + * 학습 csv 파일 db 등록 + * + * @param modelInfo + */ + private void createFile(ResponsePathDto modelInfo) { + String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv"; + try (BufferedReader reader = + Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) { + + CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); + + List batchArgs = new ArrayList<>(); + + for (CSVRecord record : parser) { + + int epoch = Integer.parseInt(record.get("Epoch")); + long iteration = Long.parseLong(record.get("Iteration")); + double Loss = Double.parseDouble(record.get("Loss")); + double LR = Double.parseDouble(record.get("LR")); + float time = Float.parseFloat(record.get("Time")); + + batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time}); + } + + modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv"; + try (BufferedReader reader = + Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) { + + CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); + + List batchArgs = new ArrayList<>(); + + for (CSVRecord record : parser) { + + int epoch = Integer.parseInt(record.get("Epoch")); + + Float aAcc = parseFloatSafe(record.get("aAcc")); + Float mFscore = parseFloatSafe(record.get("mFscore")); + Float mPrecision = parseFloatSafe(record.get("mPrecision")); + Float mRecall = parseFloatSafe(record.get("mRecall")); + Float mIoU = parseFloatSafe(record.get("mIoU")); + Float mAcc = parseFloatSafe(record.get("mAcc")); + + Float changed_fscore = 
parseFloatSafe(record.get("changed_fscore")); + Float changed_precision = parseFloatSafe(record.get("changed_precision")); + Float changed_recall = parseFloatSafe(record.get("changed_recall")); + + Float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore")); + Float unchanged_precision = parseFloatSafe(record.get("unchanged_precision")); + Float unchanged_recall = parseFloatSafe(record.get("unchanged_recall")); + + batchArgs.add( + new Object[] { + modelInfo.getModelId(), + epoch, + aAcc, + mFscore, + mPrecision, + mRecall, + mIoU, + mAcc, + changed_fscore, + changed_precision, + changed_recall, + unchanged_fscore, + unchanged_precision, + unchanged_recall + }); + } + + modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid()); + Integer epoch = null; + boolean exists; + Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth"); + + try (Stream s = Files.list(responsePath)) { + epoch = + s.filter(Files::isRegularFile) + .map( + p -> { + Matcher matcher = pattern.matcher(p.getFileName().toString()); + if (matcher.matches()) { + return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출 + } + return null; + }) + .filter(Objects::nonNull) + .findFirst() + .orElse(null); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + // best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기 + modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch); + + modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step1"); + } + private Float parseFloatSafe(String value) { try { if (value == null) return null; diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java index 020f265..697bde9 100644 --- 
a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java @@ -1,13 +1,17 @@ package com.kamco.cd.training.train.service; import com.fasterxml.jackson.databind.ObjectMapper; +import com.kamco.cd.training.common.enums.JobStatusType; +import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import com.kamco.cd.training.train.dto.DockerInspectState; import com.kamco.cd.training.train.dto.ModelTrainJobDto; import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; import com.kamco.cd.training.train.dto.ModelTrainLinkDto; +import com.kamco.cd.training.train.dto.OutputResult; import com.kamco.cd.training.train.dto.TrainRunRequest; import java.io.IOException; import java.nio.file.Files; @@ -16,6 +20,7 @@ import java.nio.file.Paths; import java.time.ZonedDateTime; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.UUID; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; @@ -33,11 +38,14 @@ public class TrainJobService { private final ModelTrainJobCoreService modelTrainJobCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService; + private final ModelTrainMetricsJobService modelTrainMetricsJobService; + private final ModelTestMetricsJobService modelTestMetricsJobService; private final DockerTrainService dockerTrainService; private final ObjectMapper objectMapper; private final ApplicationEventPublisher eventPublisher; private final TmpDatasetService tmpDatasetService; private final DataSetCountersService dataSetCounters; + private final TrainUtilService trainUtilService; // 학습 결과가 저장될 호스트 디렉토리 
@Value("${train.docker.response_dir}") @@ -309,4 +317,133 @@ public class TrainJobService { } return modelUuid; } + + /** + * 동작되고잇는 pid을 확인하고 pid 목록에 있으면 진행중, 없으면 종료(결과 csv, zip 파일 확인하여 완료 여부 체크) + * + * @param uuid + * @param modelId + */ + public void status(UUID uuid, Long modelId) { + + ModelTrainJobDto job = + modelTrainJobCoreService + .findLatestByModelId(modelId) + .orElseThrow(() -> new NoSuchElementException("job not found")); + + ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(modelId); + + // TODO 실행중 상태인것만 변경해야하면 주석 해제 + // if(job.getStatusCd().equals(JobStatusType.RUNNING.getId())) { + // return; + // } + + String containerName = job.getContainerName(); + + try { + // docker inspect로 컨테이너 상태 조회 + DockerInspectState state = trainUtilService.inspectContainer(containerName); + + // 컨테이너가 "없음" + // - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음 + // - 즉 "컨테이너 없음"이 무조건 실패는 아님 + if (!state.exists()) { + log.warn("container missing. try file-based reconcile. container={}", containerName); + + // 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정" + OutputResult out = trainUtilService.probeOutputs(job); + + // 산출물이 충분하면 성공 처리 + if (out.completed()) { + + // 테스트 완료인지 zip파일로 확인 + if (trainUtilService.existsZipFile(uuid)) { + // 테스트 완료일때 + log.info("outputs look completed. mark SUCCESS. 
jobId={}", job.getId()); + + // job 완료처리 + modelTrainJobCoreService.markSuccess(job.getId(), 0); + + // 학습 완료가 아니면 완료 업데이트 + if (!model.getStep1Status().equals(TrainStatusType.COMPLETED.getId())) { + + // model 상태 변경 (학습) + modelTrainMngCoreService.markStep1Success(job.getModelId()); + // 학습 결과 csv 파일 정보 등록 + modelTrainMetricsJobService.trainValidMetricCsvFile(modelId); + } + + // 테스트 완료가 아니면 완료 업데이트 + if (!model.getStep2Status().equals(TrainStatusType.COMPLETED.getId())) { + // model 상태 변경 (테스트) + modelTrainMngCoreService.markStep2Success(job.getModelId()); + // 테스트 결과 csv 파일 정보 등록 + modelTestMetricsJobService.testValidMetricCsvFiles(modelId); + } + + } else { + // 학습 완료일때 + log.info("outputs look completed. mark SUCCESS. jobId={}", job.getId()); + modelTrainJobCoreService.markSuccess(job.getId(), 0); + // 학습 완료가 아니면 완료 업데이트 + if (!model.getStep1Status().equals(TrainStatusType.COMPLETED.getId())) { + + // model 상태 변경 (학습) + modelTrainMngCoreService.markStep1Success(job.getModelId()); + // 학습 결과 csv 파일 정보 등록 + modelTrainMetricsJobService.trainValidMetricCsvFile(modelId); + } + } + + } else { + + // 산출물이 부족하면 중단처리 + // 산출물이 부족하면 "중단/보류"로 처리 + // 운영자가 재시작 할 수 있게 한다. + log.warn( + "outputs incomplete. mark PAUSED/STOP for restart. 
jobId={} reason={}", + job.getId(), + out.reason()); + + // PAUSED/STOP + modelTrainJobCoreService.markPaused( + job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE"); + + // STOP으로 변경 + markStepStopByJobType( + job, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE: " + out.reason()); + } + } else { + // 컨테이너가 있으면 진행중처리 + Map params = job.getParamsJson(); + boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType"))); + if (isEval) { + // 테스트 진행중 상태로 업데이트 + modelTrainMngCoreService.markStep2InProgress(job.getModelId(), job.getId()); + } else { + // 학습 진행중 상태로 업데이트 + modelTrainMngCoreService.markStep1InProgress(job.getModelId(), job.getId()); + } + // job 테이블 진행중으로 업데이트 + modelTrainJobCoreService.updateJobStatus(job.getId(), JobStatusType.RUNNING.getId()); + } + } catch (Exception e) { + log.error("container inspect failed. container={}", containerName, e); + } + } + + /** + * jobType에 따라 학습 관리 테이블의 "에러 단계"를 업데이트. + * + *

예: - jobType == "EVAL" → step2(평가 단계) 에러 - 그 외 → step1 혹은 전체 에러 + */ + private void markStepStopByJobType(ModelTrainJobDto job, String msg) { + Map params = job.getParamsJson(); + boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType"))); + if (isEval) { + modelTrainMngCoreService.markStep2Stop(job.getModelId(), msg); + } else { + modelTrainMngCoreService.markStep1Stop(job.getModelId(), msg); + } + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainUtilService.java b/src/main/java/com/kamco/cd/training/train/service/TrainUtilService.java new file mode 100644 index 0000000..2be678c --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/service/TrainUtilService.java @@ -0,0 +1,228 @@ +package com.kamco.cd.training.train.service; + +import com.kamco.cd.training.model.dto.ModelTrainMngDto; +import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import com.kamco.cd.training.train.dto.DockerInspectState; +import com.kamco.cd.training.train.dto.ModelTrainJobDto; +import com.kamco.cd.training.train.dto.OutputResult; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.nio.file.*; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; +import lombok.RequiredArgsConstructor; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +@Service +@RequiredArgsConstructor +public class TrainUtilService { + + private final ModelTrainMngCoreService modelTrainMngCoreService; + + /** + * Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response + * + *

컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다. + */ + @Value("${train.docker.response_dir}") + private String responseDir; + + /** + * docker inspect를 사용해서 컨테이너 상태를 조회합니다. + * + *

사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}} + * + *

예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137" + * + *

주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초 + * 타임아웃을 둠 + */ + public DockerInspectState inspectContainer(String containerName) + throws IOException, InterruptedException { + + ProcessBuilder pb = + new ProcessBuilder( + "docker", + "inspect", + "-f", + "{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}", + containerName); + + pb.redirectErrorStream(true); + + Process p = pb.start(); + + String output; + try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) { + output = br.readLine(); + } + + boolean finished = p.waitFor(5, TimeUnit.SECONDS); + if (!finished) { + p.destroyForcibly(); + throw new IOException("docker inspect timeout"); + } + + int code = p.exitValue(); + + if (code != 0 || output == null || output.isBlank()) { + return DockerInspectState.missing(); + } + + String[] parts = output.trim().split("\\s+"); + + String status = parts.length > 0 ? parts[0] : "unknown"; + boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]); + + Integer exitCode = null; + if (parts.length > 2) { + try { + exitCode = Integer.parseInt(parts[2]); + } catch (Exception ignore) { + } + } + + return new DockerInspectState(true, running, exitCode, status); + } + + /** + * 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다. + * + *

판정 규칙(보수적으로 설계): 1) total_epoch와 valInterval이 paramsJson에 양수로 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수 + * >= total_epoch / valInterval (정수 나눗셈) 이어야 함 3) *.pth 파일(epoch별 또는 best*.pth) 존재 검사는 현재 구현되어 있지 않음 + * + *

왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를 + * 실패로 오판할 수 있음. + */ + public OutputResult probeOutputs(ModelTrainJobDto job) { + + try { + Path outDir = resolveOutputDir(job); + if (outDir == null || !Files.isDirectory(outDir)) { + return new OutputResult(false, "output-dir-missing"); + } + + Integer totalEpoch = extractTotalEpoch(job).orElse(null); + if (totalEpoch == null || totalEpoch <= 0) { + return new OutputResult(false, "total-epoch-missing"); + } + + Integer valInterval = extractValInterval(job).orElse(null); + if (valInterval == null || valInterval <= 0) { + return new OutputResult(false, "val-interval-missing"); + } + + Path valCsv = outDir.resolve("val.csv"); + if (!Files.exists(valCsv)) { + return new OutputResult(false, "val.csv-missing"); + } + + long lines = countNonHeaderLines(valCsv); + int expectedLines = totalEpoch / valInterval; + + if (lines >= expectedLines) { + return new OutputResult(true, "ok"); + } + + return new OutputResult(false, "val.csv-lines-mismatch"); + + } catch (Exception e) { + return new OutputResult(false, "probe-error"); + } + } + + /** + * 테스트 완료후 zip 파일 있는지 확인 + * + * @param uuid + * @return + */ + public boolean existsZipFile(UUID uuid) { + Path path = Paths.get(responseDir, uuid.toString()); + + if (!Files.isDirectory(path)) { + return false; + } + + String pattern = "*" + uuid + "*.zip"; + + try (DirectoryStream stream = Files.newDirectoryStream(path, pattern)) { + return stream.iterator().hasNext(); + } catch (IOException e) { + return false; + } + } + + /** + * responseDir 아래에서 job 산출물 디렉토리를 찾습니다. + * + *

가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다. + * + *

현재 탐색 경로: {responseDir}/{모델 uuid}/metrics 한 곳만 확인합니다. + * 해당 디렉토리가 없으면 fallback 없이 null을 반환합니다. + * + *

추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId) + * 형태를 1순위로 두세요. + */ + private Path resolveOutputDir(ModelTrainJobDto job) { + ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId()); + + Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics"); + + return Files.isDirectory(base) ? base : null; + } + + /** + * paramsJson에서 total_epoch 값을 추출합니다. + * + *

키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase) + * + *

예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...} + */ + private Optional extractTotalEpoch(ModelTrainJobDto job) { + Map params = job.getParamsJson(); + if (params == null) return Optional.empty(); + + Object v = params.get("total_epoch"); + if (v == null) v = params.get("totalEpoch"); + + try { + return v == null ? Optional.empty() : Optional.of(Integer.parseInt(String.valueOf(v))); + } catch (Exception e) { + return Optional.empty(); + } + } + + /** paramsJson에서 valInterval 추출 */ + private Optional extractValInterval(ModelTrainJobDto job) { + Map params = job.getParamsJson(); + if (params == null) return Optional.empty(); + + Object v = params.get("valInterval"); + + try { + return v == null ? Optional.empty() : Optional.of(Integer.parseInt(String.valueOf(v))); + } catch (Exception e) { + return Optional.empty(); + } + } + + /** + * CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다. + * + *

가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록) + * + *

주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외 + */ + private long countNonHeaderLines(Path csv) throws IOException { + try (Stream lines = Files.lines(csv, StandardCharsets.UTF_8)) { + return lines.skip(1).filter(s -> s != null && !s.isBlank()).count(); + } + } +}