Merge pull request 'feat/training_260202' (#46) from feat/training_260202 into develop
Reviewed-on: #46
This commit was merged in pull request #46.
This commit is contained in:
@@ -3,6 +3,7 @@ package com.kamco.cd.training.model;
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -200,4 +201,25 @@ public class ModelTrainDetailApiController {
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getModelTestMetricResult(uuid));
|
||||
}
|
||||
|
||||
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Test)", description = "모델 상세 > 성능 정보 (Test) API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "조회 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = TransferDetailDto.class))),
|
||||
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/best-epoch/{uuid}")
|
||||
public ApiResponseDto<ModelBestEpoch> getModelTrainBestEpoch(
|
||||
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainBestEpoch(uuid));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,4 +231,18 @@ public class ModelTrainDetailDto {
|
||||
private Long detectionCount;
|
||||
private Long gtCount;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class ModelBestEpoch {
|
||||
private Integer epoch;
|
||||
private Double loss;
|
||||
private Float f1Score;
|
||||
private Float precision;
|
||||
private Float recall;
|
||||
private Float iou;
|
||||
private Float accuracy;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -111,4 +112,8 @@ public class ModelTrainDetailService {
|
||||
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
|
||||
return modelTrainDetailCoreService.getModelTestMetricResult(uuid);
|
||||
}
|
||||
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
return modelTrainDetailCoreService.getModelTrainBestEpoch(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -92,4 +93,8 @@ public class ModelTrainDetailCoreService {
|
||||
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
|
||||
return modelDetailRepository.getModelTestMetricResult(uuid);
|
||||
}
|
||||
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
return modelDetailRepository.getModelTrainBestEpoch(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -457,4 +457,13 @@ public class ModelTrainMngCoreService {
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setBestEpoch(epoch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +103,9 @@ public class ModelMasterEntity {
|
||||
@Column(name = "last_error")
|
||||
private String lastError;
|
||||
|
||||
@Column(name = "best_epoch")
|
||||
private Integer bestEpoch;
|
||||
|
||||
public ModelTrainMngDto.Basic toDto() {
|
||||
return new ModelTrainMngDto.Basic(
|
||||
this.id,
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -31,4 +32,6 @@ public interface ModelDetailRepositoryCustom {
|
||||
List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid);
|
||||
|
||||
List<ModelTestMetrics> getModelTestMetricResult(UUID uuid);
|
||||
|
||||
ModelBestEpoch getModelTrainBestEpoch(UUID uuid);
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntit
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -26,8 +27,10 @@ import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Slf4j
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
||||
@@ -237,4 +240,33 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
||||
.where(modelMetricsTestEntity.model.id.eq(modelMasterEntity.getId()))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
|
||||
if (modelMasterEntity == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelBestEpoch.class,
|
||||
modelMetricsTrainEntity.epoch,
|
||||
modelMetricsTrainEntity.loss,
|
||||
modelMetricsValidationEntity.mFscore,
|
||||
modelMetricsValidationEntity.mPrecision,
|
||||
modelMetricsValidationEntity.mRecall,
|
||||
modelMetricsValidationEntity.mIou,
|
||||
modelMetricsValidationEntity.mAcc))
|
||||
.from(modelMetricsTrainEntity)
|
||||
.leftJoin(modelMetricsValidationEntity)
|
||||
.on(
|
||||
modelMetricsTrainEntity.model.eq(modelMetricsValidationEntity.model),
|
||||
modelMetricsTrainEntity.epoch.eq(modelMetricsValidationEntity.epoch))
|
||||
.where(
|
||||
modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId()),
|
||||
modelMetricsTrainEntity.epoch.eq(modelMasterEntity.getBestEpoch()))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,9 @@ public class TestJobService {
|
||||
// 마스터 확인
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
// best epoch 업데이트
|
||||
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
|
||||
|
||||
Map<String, Object> params = new java.util.LinkedHashMap<>();
|
||||
params.put("jobType", "EVAL");
|
||||
params.put("uuid", String.valueOf(uuid));
|
||||
|
||||
Reference in New Issue
Block a user