diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java index d7823b1..3f592d8 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java @@ -29,4 +29,9 @@ public class ModelTrainMetricsJobCoreService { public void insertModelMetricsValidation(List batchArgs) { modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs); } + + @Transactional + public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) { + modelTrainMetricsJobRepository.updateModelSelectedBestEpoch(modelId, epoch); + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java index 67517fe..f4031bf 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java @@ -12,4 +12,6 @@ public interface ModelTrainMetricsJobRepositoryCustom { void updateModelMetricsTrainSaveYn(Long modelId, String stepNo); void insertModelMetricsValidation(List batchArgs); + + void updateModelSelectedBestEpoch(Long modelId, Integer epoch); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java index fb5916c..1323d40 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java @@ -82,4 +82,13 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor jdbcTemplate.batchUpdate(sql, batchArgs); } + + @Override + public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) { + queryFactory + .update(modelMasterEntity) + .set(modelMasterEntity.bestEpoch, epoch) + .where(modelMasterEntity.id.eq(modelId)) + .execute(); + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java index b8a05b0..4487973 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java @@ -6,9 +6,14 @@ import java.io.BufferedReader; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Stream; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.csv.CSVFormat; @@ -129,6 +134,33 @@ public class ModelTrainMetricsJobService { throw new RuntimeException(e); } + Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid()); + Integer epoch = null; + boolean exists; + Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth"); + + try (Stream s = Files.list(responsePath)) { + epoch = + s.filter(Files::isRegularFile) + .map( + p -> { + Matcher matcher = pattern.matcher(p.getFileName().toString()); + if (matcher.matches()) { + return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출 + } + return null; + }) + .filter(Objects::nonNull) + .findFirst() + .orElse(null); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + // best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기 + modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch); + modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn( modelInfo.getModelId(), "step1"); }