Merge pull request 'best epoch 파일 선택 수정' (#132) from feat/training_260202 into develop
Reviewed-on: #132
This commit was merged in pull request #132.
This commit is contained in:
@@ -416,9 +416,8 @@ public class DockerTrainService {
|
||||
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
|
||||
|
||||
Path epochPath = Paths.get(responseDir, req.getOutputFolder());
|
||||
|
||||
Path checkpoint = findCheckpoint(epochPath, epoch);
|
||||
String modelFile = checkpoint.toString();
|
||||
// 결과 폴더에 파라미터로 받은 베스트 epoch이 best_changed_fscore_epoch_ 로 시작하는 파일이 있는지 확인 후 pth 파일명 반환
|
||||
String modelFile = findCheckpoint(epochPath, epoch);
|
||||
|
||||
List<String> c = new ArrayList<>();
|
||||
|
||||
@@ -450,7 +449,7 @@ public class DockerTrainService {
|
||||
return c;
|
||||
}
|
||||
|
||||
public Path findCheckpoint(Path dir, int epoch) {
|
||||
public String findCheckpoint(Path dir, int epoch) {
|
||||
|
||||
String bestFileName = String.format("best_changed_fscore_epoch_%d.pth", epoch);
|
||||
String normalFileName = String.format("epoch_%d.pth", epoch);
|
||||
@@ -460,12 +459,12 @@ public class DockerTrainService {
|
||||
|
||||
// 1. best 파일이 존재하면 그거 사용
|
||||
if (Files.isRegularFile(bestPath)) {
|
||||
return bestPath;
|
||||
return bestFileName;
|
||||
}
|
||||
|
||||
// 2. 없으면 일반 epoch 파일 사용
|
||||
if (Files.isRegularFile(normalPath)) {
|
||||
return normalPath;
|
||||
return normalFileName;
|
||||
}
|
||||
|
||||
throw new IllegalStateException("Checkpoint 파일이 없습니다. epoch=" + epoch);
|
||||
|
||||
Reference in New Issue
Block a user