From 3367d0e7be7abc76d11055dfaf0f0f849042bc5f Mon Sep 17 00:00:00 2001 From: teddy Date: Thu, 12 Feb 2026 10:14:32 +0900 Subject: [PATCH] =?UTF-8?q?=EC=8B=A4=ED=96=89=20=EC=98=A4=EB=A5=98=20?= =?UTF-8?q?=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/ModelTrainMngCoreService.java | 26 +++++++++++++++ .../model/ModelMngRepositoryImpl.java | 4 +-- .../cd/training/train/TrainApiController.java | 12 +++---- .../training/train/dto/TrainRunRequest.java | 13 ++++++-- .../train/service/DockerTrainService.java | 33 +++++++++++-------- .../train/service/TrainJobService.java | 2 -- .../train/service/TrainJobWorker.java | 2 ++ 7 files changed, 67 insertions(+), 25 deletions(-) diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index 1098bcb..b08c633 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -31,6 +31,7 @@ import lombok.RequiredArgsConstructor; import org.springframework.data.domain.Page; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Propagation; import org.springframework.transaction.annotation.Transactional; @Service @@ -378,6 +379,13 @@ public class ModelTrainMngCoreService { return modelMngRepository.findTrainRunRequest(modelId); } + /** + * step1 진행중 처리 + * + * @param modelId + * @param jobId + */ + @Transactional(propagation = Propagation.REQUIRES_NEW) public void markStep1InProgress(Long modelId, Long jobId) { ModelMasterEntity entity = modelMngRepository @@ -392,6 +400,12 @@ public class ModelTrainMngCoreService { entity.setUpdatedUid(userUtil.getId()); } + /** + * step2 진행중 처리 + * + * @param modelId + */ + @Transactional(propagation = Propagation.REQUIRES_NEW) public void markStep2InProgress(Long modelId, Long jobId) { ModelMasterEntity entity = modelMngRepository @@ -406,6 +420,12 @@ public class ModelTrainMngCoreService { entity.setUpdatedUid(userUtil.getId()); } + /** + * step1 완료처리 + * + * @param modelId + */ + @Transactional(propagation = Propagation.REQUIRES_NEW) public void markStep1Success(Long modelId) { ModelMasterEntity entity = modelMngRepository @@ -419,6 +439,12 @@ public class ModelTrainMngCoreService { entity.setUpdatedUid(userUtil.getId()); } + /** + * step2 완료처리 + * + * @param modelId + */ + @Transactional(propagation = Propagation.REQUIRES_NEW) public void markStep2Success(Long modelId) { ModelMasterEntity entity = modelMngRepository diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java index c540a70..2fe22be 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java @@ -90,7 +90,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom { @Override public TrainRunRequest findTrainRunRequest(Long modelId) { - queryFactory + return queryFactory .select( Projections.constructor( TrainRunRequest.class, @@ -141,7 +141,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom { .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId)) .leftJoin(modelConfigEntity) .on(modelConfigEntity.model.id.eq(modelMasterEntity.id)) + .where(modelMasterEntity.id.eq(modelId)) .fetchOne(); - return null; } } diff --git a/src/main/java/com/kamco/cd/training/train/TrainApiController.java b/src/main/java/com/kamco/cd/training/train/TrainApiController.java index 81c1e74..58b5341 100644 --- a/src/main/java/com/kamco/cd/training/train/TrainApiController.java +++ b/src/main/java/com/kamco/cd/training/train/TrainApiController.java @@ -41,7 +41,7 @@ public class TrainApiController { }) @PostMapping("/run/{uuid}") public ApiResponseDto run( - @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); @@ -64,7 +64,7 @@ public class TrainApiController { }) @PostMapping("/restart/{uuid}") public ApiResponseDto restart( - @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); @@ -87,7 +87,7 @@ public class TrainApiController { }) @PostMapping("/resume/{uuid}") public ApiResponseDto resume( - @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); @@ -110,7 +110,7 @@ public class TrainApiController { }) @PostMapping("/cancel/{uuid}") public ApiResponseDto cancel( - @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); @@ -134,7 +134,7 @@ public class TrainApiController { @PostMapping("/test/run/{epoch}/{uuid}") public ApiResponseDto run( @Parameter(description = "best 에폭", example = "1") @PathVariable int epoch, - @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); @@ -157,7 +157,7 @@ public class TrainApiController { }) @PostMapping("/test/cancel/{uuid}") public ApiResponseDto cancelTest( - @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") + @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); diff --git a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java index a3b63d8..1e1974c 100644 --- a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java +++ b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java @@ -1,5 +1,6 @@ package com.kamco.cd.training.train.dto; +import java.util.UUID; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; @@ -14,8 +15,8 @@ public class TrainRunRequest { // ======================== // 기본 // ======================== - private String datasetFolder; - private String outputFolder; + private UUID datasetFolder; + private UUID outputFolder; private String inputSize; private String cropSize; private Integer batchSize; @@ -80,4 +81,12 @@ public class TrainRunRequest { // ======================== private Integer timeoutSeconds; private String resumeFrom; + + public String getDatasetFolder() { + return String.valueOf(datasetFolder); + } + + public String getOutputFolder() { + return String.valueOf(outputFolder); + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index 02a1443..09ea43e 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -9,9 +9,11 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; +import lombok.extern.log4j.Log4j2; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; +@Log4j2 @Service public class DockerTrainService { @@ -44,6 +46,11 @@ public class DockerTrainService { List cmd = buildDockerRunCommand(containerName, req); + log.info("=== Docker Train Command ==="); + log.info("Container: {}", containerName); + log.info("Command: {}", String.join(" ", cmd)); + log.info("================================"); + ProcessBuilder pb = new ProcessBuilder(cmd); pb.redirectErrorStream(true); @@ -121,23 +128,11 @@ public class DockerTrainService { // 컨테이너 이름 지정 c.add("--name"); - c.add(containerName); + c.add(containerName + "-" + req.getOutputFolder().substring(0, 8)); // 실행 종료 시 자동 삭제 c.add("--rm"); - // 환경변수 설정 - c.add("-e"); - c.add("OPENCV_LOG_LEVEL=ERROR"); - c.add("-e"); - c.add("NCCL_DEBUG=INFO"); - c.add("-e"); - c.add("NCCL_IB_DISABLE=1"); - c.add("-e"); - c.add("NCCL_P2P_DISABLE=0"); - c.add("-e"); - c.add("NCCL_SOCKET_IFNAME=eth0"); - // GPU 전체 사용 c.add("--gpus"); c.add("all"); @@ -156,6 +151,18 @@ public class DockerTrainService { c.add("--ulimit"); c.add("stack=67108864"); + // 환경변수 설정 + c.add("-e"); + c.add("OPENCV_LOG_LEVEL=ERROR"); + c.add("-e"); + c.add("NCCL_DEBUG=INFO"); + c.add("-e"); + c.add("NCCL_IB_DISABLE=1"); + c.add("-e"); + c.add("NCCL_P2P_DISABLE=0"); + c.add("-e"); + c.add("NCCL_SOCKET_IFNAME=eth0"); + // 요청/결과 디렉토리 볼륨 마운트 c.add("-v"); c.add(requestDir + ":/data"); diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java index e6f57b6..a5816b5 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java @@ -67,8 +67,6 @@ public class TrainJobService { // 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함) eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); - - modelTrainMngCoreService.markStep1InProgress(modelId, jobId); return jobId; } diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java index 4eedc9a..2e4a094 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java @@ -55,6 +55,7 @@ public class TrainJobWorker { TrainRunResult result; if (isEval) { + modelTrainMngCoreService.markStep2InProgress(modelId, jobId); String uuid = String.valueOf(params.get("uuid")); int epoch = (int) params.get("epoch"); @@ -62,6 +63,7 @@ public class TrainJobWorker { result = dockerTrainService.runEvalSync(evalReq, containerName); } else { + modelTrainMngCoreService.markStep1InProgress(modelId, jobId); TrainRunRequest trainReq = toTrainRunRequest(params); result = dockerTrainService.runTrainSync(trainReq, containerName); } -- 2.49.1