diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index 96be3f9..2e5d7a4 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -416,9 +416,8 @@ public class DockerTrainService { if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0"); Path epochPath = Paths.get(responseDir, req.getOutputFolder()); - - Path checkpoint = findCheckpoint(epochPath, epoch); - String modelFile = checkpoint.toString(); + // 결과 폴더에 파라미터로 받은 베스트 epoch이 best_changed_fscore_epoch_ 로 시작하는 파일이 있는지 확인 후 pth 파일명 반환 + String modelFile = findCheckpoint(epochPath, epoch); List c = new ArrayList<>(); @@ -450,7 +449,7 @@ public class DockerTrainService { return c; } - public Path findCheckpoint(Path dir, int epoch) { + public String findCheckpoint(Path dir, int epoch) { String bestFileName = String.format("best_changed_fscore_epoch_%d.pth", epoch); String normalFileName = String.format("epoch_%d.pth", epoch); @@ -460,12 +459,12 @@ public class DockerTrainService { // 1. best 파일이 존재하면 그거 사용 if (Files.isRegularFile(bestPath)) { - return bestPath; + return bestFileName; } // 2. 없으면 일반 epoch 파일 사용 if (Files.isRegularFile(normalPath)) { - return normalPath; + return normalFileName; } throw new IllegalStateException("Checkpoint 파일이 없습니다. epoch=" + epoch);