베스트 에폭 API, 2단계 실행 시 best epoch 업데이트
This commit is contained in:
@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -92,4 +93,8 @@ public class ModelTrainDetailCoreService {
|
||||
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
|
||||
return modelDetailRepository.getModelTestMetricResult(uuid);
|
||||
}
|
||||
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
return modelDetailRepository.getModelTrainBestEpoch(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -457,4 +457,13 @@ public class ModelTrainMngCoreService {
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setBestEpoch(epoch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +103,9 @@ public class ModelMasterEntity {
|
||||
@Column(name = "last_error")
|
||||
private String lastError;
|
||||
|
||||
@Column(name = "best_epoch")
|
||||
private Integer bestEpoch;
|
||||
|
||||
public ModelTrainMngDto.Basic toDto() {
|
||||
return new ModelTrainMngDto.Basic(
|
||||
this.id,
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -31,4 +32,6 @@ public interface ModelDetailRepositoryCustom {
|
||||
List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid);
|
||||
|
||||
List<ModelTestMetrics> getModelTestMetricResult(UUID uuid);
|
||||
|
||||
ModelBestEpoch getModelTrainBestEpoch(UUID uuid);
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntit
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -26,8 +27,10 @@ import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Slf4j
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
||||
@@ -237,4 +240,33 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
||||
.where(modelMetricsTestEntity.model.id.eq(modelMasterEntity.getId()))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
|
||||
if (modelMasterEntity == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelBestEpoch.class,
|
||||
modelMetricsTrainEntity.epoch,
|
||||
modelMetricsTrainEntity.loss,
|
||||
modelMetricsValidationEntity.mFscore,
|
||||
modelMetricsValidationEntity.mPrecision,
|
||||
modelMetricsValidationEntity.mRecall,
|
||||
modelMetricsValidationEntity.mIou,
|
||||
modelMetricsValidationEntity.mAcc))
|
||||
.from(modelMetricsTrainEntity)
|
||||
.leftJoin(modelMetricsValidationEntity)
|
||||
.on(
|
||||
modelMetricsTrainEntity.model.eq(modelMetricsValidationEntity.model),
|
||||
modelMetricsTrainEntity.epoch.eq(modelMetricsValidationEntity.epoch))
|
||||
.where(
|
||||
modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId()),
|
||||
modelMetricsTrainEntity.epoch.eq(modelMasterEntity.getBestEpoch()))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user