From 79e8259f28cfa955c138f36a0b588d2ff9696c21 Mon Sep 17 00:00:00 2001 From: teddy Date: Thu, 12 Feb 2026 21:30:03 +0900 Subject: [PATCH] =?UTF-8?q?=ED=8C=8C=EB=9D=BC=EB=AF=B8=ED=84=B0=20?= =?UTF-8?q?=EB=B3=80=EA=B2=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cd/training/train/dto/EvalRunRequest.java | 6 ++++++ .../train/service/DockerTrainService.java | 17 ++++++++--------- .../training/train/service/TestJobService.java | 6 ++++++ .../training/train/service/TrainJobWorker.java | 10 +++++++++- 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java b/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java index fe621eb..794a1b8 100644 --- a/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java +++ b/src/main/java/com/kamco/cd/training/train/dto/EvalRunRequest.java @@ -13,4 +13,10 @@ public class EvalRunRequest { private String uuid; private int epoch; // best_changed_fscore_epoch_1.pth private Integer timeoutSeconds; + private String datasetFolder; + private String outputFolder; + + public String getOutputFolder() { + return this.outputFolder.toString(); + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index ba9b76b..31b32cc 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -311,6 +311,7 @@ public class DockerTrainService { addArg(c, "--hue-delta", req.getHueDelta()); addArg(c, "--resume-from", req.getResumeFrom()); + addArg(c, "--save-interval", 1); return c; } @@ -414,30 +415,28 @@ public class DockerTrainService { c.add("docker"); c.add("run"); - c.add("--name"); - c.add(containerName); c.add("--rm"); - c.add("--gpus"); c.add("all"); - if (ipcHost) c.add("--ipc=host"); + c.add("--ipc=host"); c.add("--shm-size=" + shmSize); c.add("-v"); c.add("/home/kcomu/data" + "/tmp:/data"); + c.add("-v"); c.add(responseDir + ":/checkpoints"); - c.add(image); + c.add("kamco-cd-train:latest"); c.add("python"); c.add("/workspace/change-detection-code/run_evaluation_pipeline.py"); - c.add("--dataset_dir"); - c.add("/data/" + uuid); + addArg(c, "--dataset-folder", req.getDatasetFolder()); + addArg(c, "--output-folder", req.getOutputFolder()); - c.add("--model"); - c.add("/checkpoints/" + uuid + "/" + modelFile); + c.add("--epoch"); + c.add(modelFile); return c; } diff --git a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java index a7cf035..c34b188 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java @@ -5,6 +5,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; +import com.kamco.cd.training.train.dto.TrainRunRequest; import java.time.ZonedDateTime; import java.util.Map; import java.util.UUID; @@ -32,10 +33,15 @@ public class TestJobService { // best epoch 업데이트 modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch); + // 파라미터 조회 + TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId); + Map params = new java.util.LinkedHashMap<>(); params.put("jobType", "EVAL"); params.put("uuid", String.valueOf(uuid)); params.put("epoch", epoch); + params.put("datasetFolder", trainRunRequest.getDatasetFolder()); + params.put("outputFolder", trainRunRequest.getOutputFolder()); int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java index bb30a3f..44dbe1d 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java @@ -68,8 +68,16 @@ public class TrainJobWorker { modelTrainMngCoreService.markStep2InProgress(modelId, jobId); String uuid = String.valueOf(params.get("uuid")); int epoch = (int) params.get("epoch"); + String datasetFolder = String.valueOf(params.get("datasetFolder")); + String outputFolder = String.valueOf(params.get("outputFolder")); + + EvalRunRequest evalReq = new EvalRunRequest(); + evalReq.setUuid(uuid); + evalReq.setEpoch(epoch); + evalReq.setTimeoutSeconds(null); + evalReq.setDatasetFolder(datasetFolder); + evalReq.setOutputFolder(outputFolder); - EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null); result = dockerTrainService.runEvalSync(evalReq, containerName); } else {