학습실행 step1 할 때 best epoch 업데이트 #86

Merged
gina merged 1 commits from feat/training_260202 into develop 2026-02-13 10:18:27 +09:00
4 changed files with 48 additions and 0 deletions

View File

@@ -29,4 +29,9 @@ public class ModelTrainMetricsJobCoreService {
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs);
}
@Transactional
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
modelTrainMetricsJobRepository.updateModelSelectedBestEpoch(modelId, epoch);
}
}

View File

@@ -12,4 +12,6 @@ public interface ModelTrainMetricsJobRepositoryCustom {
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
void insertModelMetricsValidation(List<Object[]> batchArgs);
void updateModelSelectedBestEpoch(Long modelId, Integer epoch);
}

View File

@@ -82,4 +82,13 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
jdbcTemplate.batchUpdate(sql, batchArgs);
}
@Override
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
queryFactory
.update(modelMasterEntity)
.set(modelMasterEntity.bestEpoch, epoch)
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
}

View File

@@ -6,9 +6,14 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.csv.CSVFormat;
@@ -129,6 +134,33 @@ public class ModelTrainMetricsJobService {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
Integer epoch = null;
boolean exists;
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
try (Stream<Path> s = Files.list(responsePath)) {
epoch =
s.filter(Files::isRegularFile)
.map(
p -> {
Matcher matcher = pattern.matcher(p.getFileName().toString());
if (matcher.matches()) {
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
}
return null;
})
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
} catch (IOException e) {
throw new RuntimeException(e);
}
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch);
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step1");
}