Merge pull request '학습실행 step1 할 때 best epoch 업데이트' (#86) from feat/training_260202 into develop
Reviewed-on: #86
This commit was merged in pull request #86.
This commit is contained in:
@@ -29,4 +29,9 @@ public class ModelTrainMetricsJobCoreService {
|
|||||||
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
||||||
modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs);
|
modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
|
||||||
|
modelTrainMetricsJobRepository.updateModelSelectedBestEpoch(modelId, epoch);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,4 +12,6 @@ public interface ModelTrainMetricsJobRepositoryCustom {
|
|||||||
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
||||||
|
|
||||||
void insertModelMetricsValidation(List<Object[]> batchArgs);
|
void insertModelMetricsValidation(List<Object[]> batchArgs);
|
||||||
|
|
||||||
|
void updateModelSelectedBestEpoch(Long modelId, Integer epoch);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,4 +82,13 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
|||||||
|
|
||||||
jdbcTemplate.batchUpdate(sql, batchArgs);
|
jdbcTemplate.batchUpdate(sql, batchArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
|
||||||
|
queryFactory
|
||||||
|
.update(modelMasterEntity)
|
||||||
|
.set(modelMasterEntity.bestEpoch, epoch)
|
||||||
|
.where(modelMasterEntity.id.eq(modelId))
|
||||||
|
.execute();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,9 +6,14 @@ import java.io.BufferedReader;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
import java.util.stream.Stream;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.csv.CSVFormat;
|
import org.apache.commons.csv.CSVFormat;
|
||||||
@@ -129,6 +134,33 @@ public class ModelTrainMetricsJobService {
|
|||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
|
||||||
|
Integer epoch = null;
|
||||||
|
boolean exists;
|
||||||
|
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
|
||||||
|
|
||||||
|
try (Stream<Path> s = Files.list(responsePath)) {
|
||||||
|
epoch =
|
||||||
|
s.filter(Files::isRegularFile)
|
||||||
|
.map(
|
||||||
|
p -> {
|
||||||
|
Matcher matcher = pattern.matcher(p.getFileName().toString());
|
||||||
|
if (matcher.matches()) {
|
||||||
|
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
})
|
||||||
|
.filter(Objects::nonNull)
|
||||||
|
.findFirst()
|
||||||
|
.orElse(null);
|
||||||
|
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
|
||||||
|
modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch);
|
||||||
|
|
||||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
|
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
|
||||||
modelInfo.getModelId(), "step1");
|
modelInfo.getModelId(), "step1");
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user