From 8f75b16dc6df1a18bd5f2807f030eec745c831d5 Mon Sep 17 00:00:00 2001 From: teddy Date: Mon, 23 Feb 2026 12:30:54 +0900 Subject: [PATCH] =?UTF-8?q?=ED=95=99=EC=8A=B5=EC=8B=A4=ED=96=89=20?= =?UTF-8?q?=EC=A3=BC=EC=84=9D=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/ModelTrainDetailCoreService.java | 6 +++++ .../core/ModelTrainJobCoreService.java | 23 +++++++++++++--- .../core/ModelTrainMngCoreService.java | 14 ++++++++-- .../postgres/entity/ModelTrainJobEntity.java | 3 +++ .../train/service/DockerTrainService.java | 20 +++++++++++--- .../service/ModelTestMetricsJobService.java | 16 ++---------- .../service/ModelTrainMetricsJobService.java | 14 +--------- .../train/service/TestJobService.java | 13 ++++++++++ .../train/service/TrainJobService.java | 26 ++++++++++++++++++- .../train/service/TrainJobWorker.java | 20 +++++++++++++- 10 files changed, 118 insertions(+), 37 deletions(-) diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java index 1839991..5afcce4 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java @@ -57,6 +57,12 @@ public class ModelTrainDetailCoreService { return modelDetailRepository.getModelDetailSummary(uuid); } + /** + * 하이퍼 파라미터 요약정보 + * + * @param uuid 모델마스터 uuid + * @return + */ public HyperSummary getByModelHyperParamSummary(UUID uuid) { return modelDetailRepository.getByModelHyperParamSummary(uuid); } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java index 4319fc2..2479d80 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java +++ 
b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java @@ -52,7 +52,12 @@ public class ModelTrainJobCoreService { /** 실행 시작 처리 */ @Transactional public void markRunning( - Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) { + Long jobId, + String containerName, + String logPath, + String lockedBy, + Integer totalEpoch, + String jobType) { ModelTrainJobEntity job = modelTrainJobRepository .findById(jobId) @@ -64,13 +69,19 @@ public class ModelTrainJobCoreService { job.setStartedDttm(ZonedDateTime.now()); job.setLockedDttm(ZonedDateTime.now()); job.setLockedBy(lockedBy); + job.setJobType(jobType); if (totalEpoch != null) { job.setTotalEpoch(totalEpoch); } } - /** 성공 처리 */ + /** + * 성공 처리 + * + * @param jobId + * @param exitCode + */ @Transactional public void markSuccess(Long jobId, int exitCode) { ModelTrainJobEntity job = @@ -83,7 +94,13 @@ public class ModelTrainJobCoreService { job.setFinishedDttm(ZonedDateTime.now()); } - /** 실패 처리 */ + /** + * 실패 처리 + * + * @param jobId + * @param exitCode + * @param errorMessage + */ @Transactional public void markFailed(Long jobId, Integer exitCode, String errorMessage) { ModelTrainJobEntity job = diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index 048b403..b901d4c 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -384,7 +384,12 @@ public class ModelTrainMngCoreService { master.setStatusCd(TrainStatusType.COMPLETED.getId()); } - /** step 1오류 처리(옵션) - Worker가 실패 시 호출 */ + /** + * step 1오류 처리(옵션) - Worker가 실패 시 호출 + * + * @param modelId + * @param errorMessage + */ @Transactional public void markError(Long modelId, String errorMessage) { ModelMasterEntity master = @@ -399,7 +404,12 @@ public class 
ModelTrainMngCoreService { master.setUpdatedDttm(ZonedDateTime.now()); } - /** step 2오류 처리(옵션) - Worker가 실패 시 호출 */ + /** + * step 2오류 처리(옵션) - Worker가 실패 시 호출 + * + * @param modelId + * @param errorMessage + */ @Transactional public void markStep2Error(Long modelId, String errorMessage) { ModelMasterEntity master = diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java index ceb4889..dc352f6 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java @@ -83,6 +83,9 @@ public class ModelTrainJobEntity { @Column(name = "current_epoch") private Integer currentEpoch; + @Column(name = "job_type") + private String jobType; + public ModelTrainJobDto toDto() { return new ModelTrainJobDto( this.id, diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index 2e5d7a4..5624530 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -52,7 +52,14 @@ public class DockerTrainService { private final ModelTrainJobCoreService modelTrainJobCoreService; - /** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */ + /** + * Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 + * + * @param req + * @param containerName + * @return + * @throws Exception + */ public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception { List cmd = buildDockerRunCommand(containerName, req); @@ -267,8 +274,7 @@ public class DockerTrainService { addArg(c, "--input-size", req.getInputSize()); addArg(c, "--crop-size", req.getCropSize()); addArg(c, 
"--batch-size", req.getBatchSize()); - addArg(c, "--gpu-ids", req.getGpuIds()); - // addArg(c, "--gpus", req.getGpus()); + addArg(c, "--gpu-ids", req.getGpuIds()); // null addArg(c, "--lr", req.getLearningRate()); addArg(c, "--backbone", req.getBackbone()); addArg(c, "--epochs", req.getEpochs()); @@ -342,6 +348,14 @@ public class DockerTrainService { } } + /** + * Docker 평가(eval) 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 + * + * @param containerName + * @param req + * @return + * @throws Exception + */ public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception { List cmd = buildDockerEvalCommand(containerName, req); diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java index 3afd342..632d2be 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java @@ -48,20 +48,8 @@ public class ModelTestMetricsJobService { @Value("${file.pt-path}") private String ptPathDir; - /** - * 실행중인 profile - * - * @return - */ - private boolean isLocalProfile() { - return "local".equalsIgnoreCase(profile); - } - - // @Scheduled(cron = "0 * * * * *") - public void findTestValidMetricCsvFiles() throws IOException { - // if (isLocalProfile()) { - // return; - // } + /** 결과 csv 파일 정보 등록 */ + public void findTestValidMetricCsvFiles() { List modelIds = modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds(); diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java index 933b9f1..d1b8ea4 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java +++ 
b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java @@ -36,20 +36,8 @@ public class ModelTrainMetricsJobService { @Value("${train.docker.responseDir}") private String responseDir; - /** - * 실행중인 profile - * - * @return - */ - private boolean isLocalProfile() { - return "local".equalsIgnoreCase(profile); - } - - // @Scheduled(cron = "0 * * * * *") + /** 결과 csv 파일 정보 등록 */ public void findTrainValidMetricCsvFiles() { - // if (isLocalProfile()) { - // return; - // } List modelIds = modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds(); diff --git a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java index 3171677..a751b09 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java @@ -23,6 +23,14 @@ public class TestJobService { private final ApplicationEventPublisher eventPublisher; private final DataSetCountersService dataSetCounters; + /** + * 실행 예약 (QUEUE 등록) + * + * @param modelId + * @param uuid + * @param epoch + * @return + */ @Transactional public Long enqueue(Long modelId, UUID uuid, int epoch) { @@ -58,6 +66,11 @@ public class TestJobService { return jobId; } + /** + * 취소 + * + * @param modelId + */ @Transactional public void cancel(Long modelId) { diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java index 9f29980..68214a9 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java @@ -47,7 +47,12 @@ public class TrainJobService { return modelTrainMngCoreService.findModelIdByUuid(uuid); } - /** 실행 예약 (QUEUE 등록) */ + /** + * 실행 예약 (QUEUE 등록) + * + * @param modelId + * @return + */ @Transactional public Long enqueue(Long 
modelId) { @@ -139,6 +144,13 @@ public class TrainJobService { modelTrainMngCoreService.markStopped(modelId); } + /** + * 학습 이어하기 + * + * @param modelId 모델 id + * @param mode NONE 새로 시작, REQUIRE 이어하기 + * @return + */ private Long createNextAttempt(Long modelId, ResumeMode mode) { ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); @@ -189,6 +201,12 @@ public class TrainJobService { REQUIRE // 이어하기 } + /** + * 이어하기 체크포인트 탐지해서 resumeFrom 세팅 + * + * @param paramsJson + * @return + */ public String findResumeFromOrNull(Map paramsJson) { if (paramsJson == null) return null; @@ -230,6 +248,12 @@ public class TrainJobService { } } + /** + * 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기 + * + * @param modelUuid + * @return + */ @Transactional public UUID createTmpFile(UUID modelUuid) { UUID tmpUuid = UUID.randomUUID(); diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java index 4cd1c03..5cc0b31 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java @@ -17,6 +17,7 @@ import org.springframework.stereotype.Component; import org.springframework.transaction.event.TransactionPhase; import org.springframework.transaction.event.TransactionalEventListener; +/** job 실행 */ @Log4j2 @Component @RequiredArgsConstructor @@ -54,6 +55,8 @@ public class TrainJobWorker { String containerName = (isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8); + String type = isEval ? 
"TEST" : "TRAIN"; + Integer totalEpoch = null; if (params.containsKey("totalEpoch")) { if (params.get("totalEpoch") != null) { @@ -61,12 +64,15 @@ public class TrainJobWorker { } } log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName); - modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch); + // 실행 시작 처리 + modelTrainJobCoreService.markRunning( + jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type); log.info("[JOB] markRunning done jobId={}", jobId); try { TrainRunResult result; if (isEval) { + // step2 진행중 처리 modelTrainMngCoreService.markStep2InProgress(modelId, jobId); String uuid = String.valueOf(params.get("uuid")); int epoch = (int) params.get("epoch"); @@ -81,11 +87,14 @@ public class TrainJobWorker { evalReq.setOutputFolder(outputFolder); log.info("[JOB] selected test epoch={}", epoch); + // 도커 실행 후 로그 수집 result = dockerTrainService.runEvalSync(containerName, evalReq); } else { + // step1 진행중 처리 modelTrainMngCoreService.markStep1InProgress(modelId, jobId); TrainRunRequest trainReq = toTrainRunRequest(params); + // 도커 실행 후 로그 수집 result = dockerTrainService.runTrainSync(trainReq, containerName); } @@ -99,24 +108,31 @@ public class TrainJobWorker { } if (result.getExitCode() == 0) { + // 성공 처리 modelTrainJobCoreService.markSuccess(jobId, result.getExitCode()); if (isEval) { + // step2 완료처리 modelTrainMngCoreService.markStep2Success(modelId); + // 결과 csv 파일 정보 등록 modelTestMetricsJobService.findTestValidMetricCsvFiles(); } else { modelTrainMngCoreService.markStep1Success(modelId); + // 결과 csv 파일 정보 등록 modelTrainMetricsJobService.findTrainValidMetricCsvFiles(); } } else { String failMsg = result.getStatus() + "\n" + result.getLogs(); + // 실패 처리 modelTrainJobCoreService.markFailed( jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs()); if (isEval) { + // 오류 정보 등록 modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode()); } else { + // 오류 정보 등록 
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode()); } } @@ -125,8 +141,10 @@ public class TrainJobWorker { modelTrainJobCoreService.markFailed(jobId, null, e.getMessage()); if ("EVAL".equals(params.get("jobType"))) { + // 오류 정보 등록 modelTrainMngCoreService.markStep2Error(modelId, e.getMessage()); } else { + // 오류 정보 등록 modelTrainMngCoreService.markError(modelId, e.getMessage()); } }