best epoch 파일 선택 수정
This commit is contained in:
@@ -7,6 +7,8 @@ import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
@@ -339,7 +341,7 @@ public class DockerTrainService {
|
||||
}
|
||||
}
|
||||
|
||||
public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception {
|
||||
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
|
||||
|
||||
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
||||
|
||||
@@ -412,7 +414,12 @@ public class DockerTrainService {
|
||||
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
|
||||
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
|
||||
|
||||
String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth";
|
||||
String modelFile =
|
||||
String.valueOf(findCheckpoint(Path.of(responseDir + req.getOutputFolder()), epoch));
|
||||
|
||||
if (modelFile == null || modelFile.isBlank()) {
|
||||
throw new IllegalArgumentException("best model file is required");
|
||||
}
|
||||
|
||||
List<String> c = new ArrayList<>();
|
||||
|
||||
@@ -443,4 +450,26 @@ public class DockerTrainService {
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
public Path findCheckpoint(Path dir, int epoch) {
|
||||
|
||||
String bestFileName = String.format("best_changed_fscore_epoch_%d.pth", epoch);
|
||||
String normalFileName = String.format("epoch_%d.pth", epoch);
|
||||
|
||||
Path bestPath = dir.resolve(bestFileName);
|
||||
Path normalPath = dir.resolve(normalFileName);
|
||||
|
||||
// 1. best 파일이 존재하면 그거 사용
|
||||
if (Files.isRegularFile(bestPath)) {
|
||||
return bestPath;
|
||||
}
|
||||
|
||||
// 2. 없으면 일반 epoch 파일 사용
|
||||
if (Files.isRegularFile(normalPath)) {
|
||||
return normalPath;
|
||||
}
|
||||
|
||||
// 둘 다 없으면 null 또는 예외
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,7 +80,8 @@ public class TrainJobWorker {
|
||||
evalReq.setDatasetFolder(datasetFolder);
|
||||
evalReq.setOutputFolder(outputFolder);
|
||||
log.info("[JOB] selected test epoch={}", epoch);
|
||||
result = dockerTrainService.runEvalSync(evalReq, containerName);
|
||||
|
||||
result = dockerTrainService.runEvalSync(containerName, evalReq);
|
||||
|
||||
} else {
|
||||
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||
|
||||
Reference in New Issue
Block a user