diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java
index 55be88b..f595e97 100644
--- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java
+++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java
@@ -176,6 +176,8 @@ public class ModelTrainMngDto {
 
     private String requestPath;
     private String responsePath;
+    private String tmpFileStatus;
+    private ZonedDateTime tmpFileEndDttm;
   }
 
   @Getter
diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java
index 9cf0479..49b7864 100644
--- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java
+++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java
@@ -305,7 +305,19 @@ public class ModelTrainMngService {
     modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
 
     // 데이터셋 임시파일 생성
-    trainJobService.createTmpFile(modelUuid);
+    List datasetList = null;
+    if (req.getTrainingDataset() != null) {
+      datasetList = req.getTrainingDataset().getDatasetList();
+    }
+
+    boolean isSingleDataset = datasetList != null && datasetList.size() == 1;
+
+    // Single dataset: skip the symbolic link; train uses the dataset request path directly.
+    if (isSingleDataset) {
+      trainJobService.updateRequestPath(modelUuid, datasetList);
+    } else {
+      trainJobService.createTmpFile(modelUuid);
+    }
 
     return modelUuid;
   }
diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java
index ce6e014..6b54aa9 100644
--- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java
+++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java
@@ -180,6 +180,14 @@ public class ModelTrainMngCoreService {
     if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
       entity.setRequestPath(req.getRequestPath());
     }
+
+    // Partial update: leave the tmp-file fields untouched when the request omits them.
+    if (req.getTmpFileStatus() != null) {
+      entity.setTmpFileStatus(req.getTmpFileStatus());
+    }
+    if (req.getTmpFileEndDttm() != null) {
+      entity.setTmpFileEndDttm(req.getTmpFileEndDttm());
+    }
   }
 
   /**
@@ -673,4 +681,39 @@ public class ModelTrainMngCoreService {
   public List findDatasetTestPath(Long modelId) {
     return modelDatasetMapRepository.findDatasetTestPath(modelId);
   }
+
+  /** Stores the dataset UID as request path and clears the tmp-folder flag. */
+  public void updateTrainRequestPath(Long modelId, String datasetUid) {
+    ModelMasterEntity entity =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
+
+    // Point the request path at the dataset UID (no tmp folder is created)
+    entity.setRequestPath(datasetUid);
+    entity.setReqTmpYn(false); // false: train/test docker commands mount the request folder
+  }
+
+  /** Flags the model as tmp-based and records IN_PROGRESS with the start time. */
+  public void updateTmpFileStatusStart(Long modelId) {
+    ModelMasterEntity entity =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
+
+    entity.setReqTmpYn(true);
+    entity.setTmpFileStatus("IN_PROGRESS");
+    entity.setTmpFileStartDttm(ZonedDateTime.now());
+  }
+
+  /** Records a FAIL tmp-file status together with the error message. */
+  public void updateTmpFileStatusFail(Long modelId, String message) {
+    ModelMasterEntity entity =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
+
+    entity.setTmpFileStatus("FAIL");
+    entity.setTmpFileErrMessage(message);
+  }
 }
diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java
index 2609f3b..825ccec 100644
--- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java
+++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java
@@ -121,6 +121,21 @@ public class ModelMasterEntity {
   @Column(name = "packing_end_dttm")
   private ZonedDateTime packingEndDttm;
 
+  @Column(name = "req_tmp_yn")
+  private Boolean reqTmpYn;
+
+  @Column(name = "tmp_file_status")
+  private String tmpFileStatus;
+
+  @Column(name = "tmp_file_start_dttm")
+  private ZonedDateTime tmpFileStartDttm;
+
+  @Column(name = "tmp_file_end_dttm")
+  private ZonedDateTime tmpFileEndDttm;
+
+  @Column(name = "tmp_file_err_message", columnDefinition = "TEXT")
+  private String tmpFileErrMessage;
+
   public ModelTrainMngDto.Basic toDto() {
     return new ModelTrainMngDto.Basic(
         this.id,
diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java
index 54eff48..7b0f7ee 100644
--- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java
+++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java
@@ -189,7 +189,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
                 modelHyperParamEntity.hueDelta,
                 Expressions.nullExpression(Integer.class),
                 Expressions.nullExpression(String.class),
-                modelHyperParamEntity.uuid))
+                modelHyperParamEntity.uuid,
+                modelMasterEntity.reqTmpYn))
         .from(modelMasterEntity)
         .leftJoin(modelHyperParamEntity)
         .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
diff --git a/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java b/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java
index 794a1b8..3e35383 100644
--- a/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java
+++ b/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java
@@ -15,6 +15,7 @@ public class EvalRunRequest {
   private Integer timeoutSeconds;
   private String datasetFolder;
   private String outputFolder;
+  private Boolean reqTmpYn;
 
   public String getOutputFolder() {
     return this.outputFolder.toString();
diff --git a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java
index 5761131..45faef9 100644
--- a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java
+++ b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java
@@ -84,6 +84,8 @@ public class TrainRunRequest {
 
   private UUID uuid;
 
+  private Boolean reqTmpYn; // whether the tmp symbolic link is used
+
   public String getOutputFolder() {
     return String.valueOf(this.outputFolder);
   }
diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java
index e33e83d..93783c6 100644
--- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java
+++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java
@@ -271,7 +271,12 @@ public class DockerTrainService {
     c.add("-v");
     c.add(basePath + ":" + basePath); // 심볼릭 링크와 연결되는 실제 파일 경로도 마운트를 해줘야 함
     c.add("-v");
-    c.add(symbolicDir + ":/data"); // 요청할경로
+    // A null flag (e.g. rows saved before req_tmp_yn was added) keeps the previous symlink mount.
+    if (!Boolean.FALSE.equals(req.getReqTmpYn())) {
+      c.add(symbolicDir + ":/data"); // request path: tmp symlink in use, mount symbolicDir
+    } else {
+      c.add(requestDir + ":/data"); // request path: no tmp symlink, mount the request dir
+    }
     c.add("-v");
     c.add(responseDir + ":/checkpoints"); // 저장될경로
 
@@ -472,8 +477,14 @@ public class DockerTrainService {
 
     c.add("-v");
     c.add(basePath + ":" + basePath); // 심볼릭 링크와 연결되는 실제 파일 경로도 마운트를 해줘야 함
+
     c.add("-v");
-    c.add(basePath + "/tmp:/data");
+    // A null flag (e.g. job params without reqTmpYn) falls back to the tmp symlink mount.
+    if (!Boolean.FALSE.equals(req.getReqTmpYn())) {
+      c.add(symbolicDir + ":/data"); // models using tmp mount the symbolic link
+    } else {
+      c.add(requestDir + ":/data"); // models not using tmp mount the request path
+    }
     c.add("-v");
     c.add(responseDir + ":/checkpoints");
 
diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java
index 14675cc..2341f15 100644
--- a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java
+++ b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java
@@ -269,7 +269,7 @@ public class ModelTestMetricsJobService {
         zipFiles(zipFileList, individualZipPath);
 
         log.info(
-            "✅ 개별 ZIP 생성 완료: fileName={}, pthFile={}, size={} bytes",
+            "개별 ZIP 생성 완료: fileName={}, pthFile={}, size={} bytes",
             individualZipName,
             pthFileName,
             Files.size(individualZipPath));
diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java
index 697bde9..2fc5585 100644
--- a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java
+++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java
@@ -204,6 +204,22 @@ public class TrainJobService {
     return jobId;
   }
 
+  /**
+   * Stores the dataset path directly as the request path when exactly one dataset is selected.
+   *
+   * @param modelUuid UUID of the model whose request path is updated
+   * @param datasetList selected datasets (the caller passes a single-element list)
+   * @throws CustomApiException NOT_FOUND_DATA when no dataset UID can be resolved
+   */
+  public void updateRequestPath(UUID modelUuid, List datasetList) {
+    Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
+    List datasetUid = modelTrainMngCoreService.findDatasetUid(datasetList);
+    if (datasetUid == null || datasetUid.isEmpty()) {
+      throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND);
+    }
+    modelTrainMngCoreService.updateTrainRequestPath(modelId, datasetUid.getFirst());
+  }
+
   private enum ResumeMode {
     NONE, // 새로 시작
     REQUIRE // 이어하기
@@ -274,6 +290,9 @@ public class TrainJobService {
     List uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
 
     try {
+      // 1. Mark tmp-file creation as started
+      modelTrainMngCoreService.updateTmpFileStatusStart(modelId);
+
       // 데이터셋 심볼링크 생성
       //      String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
       // train path 모델 클래스별 조회
@@ -298,6 +317,8 @@ public class TrainJobService {
 
       ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
       updateReq.setRequestPath(raw);
+      updateReq.setTmpFileStatus("COMPLETE");
+      updateReq.setTmpFileEndDttm(ZonedDateTime.now());
       // 학습모델을 수정한다.
       modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
 
@@ -311,6 +332,10 @@ public class TrainJobService {
           (uids == null ? null : uids.size()),
           e);
 
+      // 3. Record the failure
+      // NOTE(review): verify this FAIL write is not rolled back by the exception thrown below.
+      modelTrainMngCoreService.updateTmpFileStatusFail(modelId, e.getMessage());
+
       // 런타임 예외로 래핑하되, 메시지에 핵심 정보 포함
       throw new CustomApiException(
           "INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "임시 데이터셋 생성에 실패했습니다.");
diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java
index 968e977..0b3bf17 100644
--- a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java
+++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java
@@ -87,6 +87,7 @@ public class TrainJobWorker {
       evalReq.setTimeoutSeconds(null);
       evalReq.setDatasetFolder(datasetFolder);
       evalReq.setOutputFolder(outputFolder);
+      evalReq.setReqTmpYn((Boolean) params.get("reqTmpYn"));
       log.info("[JOB] selected test epoch={}", epoch);
 
       // 도커 실행 후 로그 수집