diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index 4f61ab2..b1c029b 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -7,6 +7,8 @@ import com.kamco.cd.training.train.dto.TrainRunResult; import java.io.BufferedReader; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; @@ -339,7 +341,7 @@ public class DockerTrainService { } } - public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception { + public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception { List cmd = buildDockerEvalCommand(containerName, req); @@ -412,7 +414,12 @@ public class DockerTrainService { if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required"); if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0"); - String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth"; + String modelFile = + String.valueOf(findCheckpoint(Path.of(responseDir + req.getOutputFolder()), epoch)); + + if (modelFile == null || modelFile.isBlank()) { + throw new IllegalArgumentException("best model file is required"); + } List c = new ArrayList<>(); @@ -443,4 +450,26 @@ public class DockerTrainService { return c; } + + public Path findCheckpoint(Path dir, int epoch) { + + String bestFileName = String.format("best_changed_fscore_epoch_%d.pth", epoch); + String normalFileName = String.format("epoch_%d.pth", epoch); + + Path bestPath = dir.resolve(bestFileName); + Path normalPath = dir.resolve(normalFileName); + + // 1. best 파일이 존재하면 그거 사용 + if (Files.isRegularFile(bestPath)) { + return bestPath; + } + + // 2. 없으면 일반 epoch 파일 사용 + if (Files.isRegularFile(normalPath)) { + return normalPath; + } + + // 둘 다 없으면 null 또는 예외 + return null; + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java index 22386a2..4cd1c03 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java @@ -80,7 +80,8 @@ public class TrainJobWorker { evalReq.setDatasetFolder(datasetFolder); evalReq.setOutputFolder(outputFolder); log.info("[JOB] selected test epoch={}", epoch); - result = dockerTrainService.runEvalSync(evalReq, containerName); + + result = dockerTrainService.runEvalSync(containerName, evalReq); } else { modelTrainMngCoreService.markStep1InProgress(modelId, jobId);