diff --git a/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java b/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java index 84792da..1dff752 100644 --- a/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java +++ b/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java @@ -3,6 +3,9 @@ package com.kamco.cd.training.model; import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.service.ModelTrainDetailService; @@ -132,4 +135,69 @@ public class ModelTrainDetailApiController { UUID uuid) { return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid)); } + + @Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "조회 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = TransferDetailDto.class))), + @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @GetMapping("/metrics/train/{uuid}") + public ApiResponseDto> getModelTrainMetricResult( + @Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf") + @PathVariable + UUID uuid) { + return ApiResponseDto.ok(modelTrainDetailService.getModelTrainMetricResult(uuid)); + } + + @Operation( + summary = "모델관리 > 모델 상세 > 성능 정보 (Validation)", + description = "모델 상세 > 성능 정보 (Validation) API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "조회 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = TransferDetailDto.class))), + @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @GetMapping("/metrics/validation/{uuid}") + public ApiResponseDto> getModelValidationMetricResult( + @Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf") + @PathVariable + UUID uuid) { + return ApiResponseDto.ok(modelTrainDetailService.getModelValidationMetricResult(uuid)); + } + + @Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Test)", description = "모델 상세 > 성능 정보 (Test) API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "조회 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = TransferDetailDto.class))), + @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @GetMapping("/metrics/test/{uuid}") + public ApiResponseDto> getModelTestMetricResult( + @Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf") + @PathVariable + UUID uuid) { + return ApiResponseDto.ok(modelTrainDetailService.getModelTestMetricResult(uuid)); + } } diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java index 8a854d8..6cdfe17 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java @@ -180,4 +180,55 @@ public class ModelTrainDetailDto { private TransferHyperSummary modelTrainHyper; private List modelTrainDataset; } + + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class ModelTrainMetrics { + private Integer epoch; + private Long iteration; + private Double loss; + private Double lr; + private Float durationTime; + } + + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class ModelValidationMetrics { + + private Integer epoch; + private Float aAcc; + private Float mFscore; + private Float mPrecision; + private Float mRecall; + private Float mIou; + private Float mAcc; + private Float changedFscore; + private Float changedPrecision; + private Float changedRecall; + private Float unchangedFscore; + private Float unchangedPrecision; + private Float unchangedRecall; + } + + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class ModelTestMetrics { + private String model; + private Long tp; + private Long fp; + private Long fn; + private Float precision; + private Float recall; + private Float f1Score; + private Float accuracy; + private Float iou; + private Long detectionCount; + private Long gtCount; + } } diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java index 6e53452..57c5b26 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java @@ -6,6 +6,9 @@ import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; @@ -96,4 +99,16 @@ public class ModelTrainDetailService { return transferDetailDto; } + + public List getModelTrainMetricResult(UUID uuid) { + return modelTrainDetailCoreService.getModelTrainMetricResult(uuid); + } + + public List getModelValidationMetricResult(UUID uuid) { + return modelTrainDetailCoreService.getModelValidationMetricResult(uuid); + } + + public List getModelTestMetricResult(UUID uuid) { + return modelTrainDetailCoreService.getModelTestMetricResult(uuid); + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java index bb38f57..6da6aba 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java @@ -7,6 +7,9 @@ import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; @@ -77,4 +80,16 @@ public class ModelTrainDetailCoreService { public ModelConfigDto.Basic findModelConfig(Long modelId) { return modelConfigRepository.findModelConfigByModelId(modelId).orElse(null); } + + public List getModelTrainMetricResult(UUID uuid) { + return modelDetailRepository.getModelTrainMetricResult(uuid); + } + + public List getModelValidationMetricResult(UUID uuid) { + return modelDetailRepository.getModelValidationMetricResult(uuid); + } + + public List getModelTestMetricResult(UUID uuid) { + return modelDetailRepository.getModelTestMetricResult(uuid); + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryCustom.java index 1af36a4..21ae644 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryCustom.java @@ -3,6 +3,9 @@ package com.kamco.cd.training.postgres.repository.model; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import java.util.List; @@ -22,4 +25,10 @@ public interface ModelDetailRepositoryCustom { List getByModelMappingDataset(UUID uuid); ModelMasterEntity findByModelByUUID(UUID uuid); + + List getModelTrainMetricResult(UUID uuid); + + List getModelValidationMetricResult(UUID uuid); + + List getModelTestMetricResult(UUID uuid); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java index 66b350e..edd556d 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java @@ -5,10 +5,16 @@ import static com.kamco.cd.training.postgres.entity.QModelDatasetEntity.modelDat import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity; import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; +import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity; +import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity; +import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity; @@ -154,4 +160,81 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom { .where(modelMasterEntity.uuid.eq(uuid)) .fetchOne(); } + + @Override + public List getModelTrainMetricResult(UUID uuid) { + ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid); + if (modelMasterEntity == null) { + return List.of(); + } + + return queryFactory + .select( + Projections.constructor( + ModelTrainMetrics.class, + modelMetricsTrainEntity.epoch, + modelMetricsTrainEntity.iteration, + modelMetricsTrainEntity.loss, + modelMetricsTrainEntity.lr, + modelMetricsTrainEntity.durationTime)) + .from(modelMetricsTrainEntity) + .where(modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId())) + .fetch(); + } + + @Override + public List getModelValidationMetricResult(UUID uuid) { + ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid); + if (modelMasterEntity == null) { + return List.of(); + } + + return queryFactory + .select( + Projections.constructor( + ModelValidationMetrics.class, + modelMetricsValidationEntity.epoch, + modelMetricsValidationEntity.aAcc, + modelMetricsValidationEntity.mFscore, + modelMetricsValidationEntity.mPrecision, + modelMetricsValidationEntity.mRecall, + modelMetricsValidationEntity.mIou, + modelMetricsValidationEntity.mAcc, + modelMetricsValidationEntity.changedFscore, + modelMetricsValidationEntity.changedPrecision, + modelMetricsValidationEntity.changedRecall, + modelMetricsValidationEntity.unchangedFscore, + modelMetricsValidationEntity.unchangedPrecision, + modelMetricsValidationEntity.unchangedRecall)) + .from(modelMetricsValidationEntity) + .where(modelMetricsValidationEntity.model.id.eq(modelMasterEntity.getId())) + .fetch(); + } + + @Override + public List getModelTestMetricResult(UUID uuid) { + ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid); + if (modelMasterEntity == null) { + return List.of(); + } + + return queryFactory + .select( + Projections.constructor( + ModelTestMetrics.class, + modelMetricsTestEntity.model1, + modelMetricsTestEntity.tp, + modelMetricsTestEntity.fp, + modelMetricsTestEntity.fn, + modelMetricsTestEntity.precisions, + modelMetricsTestEntity.recall, + modelMetricsTestEntity.f1Score, + modelMetricsTestEntity.accuracy, + modelMetricsTestEntity.iou, + modelMetricsTestEntity.detectionCount, + modelMetricsTestEntity.gtCount)) + .from(modelMetricsTestEntity) + .where(modelMetricsTestEntity.model.id.eq(modelMasterEntity.getId())) + .fetch(); + } }