diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java index 41ec17e..cba368d 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java @@ -163,11 +163,12 @@ public class ModelTrainDetailDto { this.compareYyyy = compareYyyy; this.targetYyyy = targetYyyy; this.roundNo = roundNo; - this.buildingCnt = buildingCnt; - this.containerCnt = containerCnt; - this.wasteCnt = wasteCnt; - this.landCoverCnt = landCoverCnt; - this.solarPanelCnt = solarPanelCnt; + this.buildingCnt = toNullIfZero(buildingCnt); + this.containerCnt = toNullIfZero(containerCnt); + this.wasteCnt = toNullIfZero(wasteCnt); + this.landCoverCnt = toNullIfZero(landCoverCnt); + this.solarPanelCnt = toNullIfZero(solarPanelCnt); + this.dataTypeName = getDataTypeName(this.dataType); } @@ -177,6 +178,10 @@ public class ModelTrainDetailDto { } } + private static Long toNullIfZero(Long value) { + return (value == null || value == 0L) ? null : value; + } + @Getter @Setter @NoArgsConstructor diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index aa66047..ce6e014 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -140,6 +140,7 @@ public class ModelTrainMngCoreService { * @param addReq 요청 파라미터 */ public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) { + TrainingDataset dataset = addReq.getTrainingDataset(); ModelMasterEntity modelMasterEntity = new ModelMasterEntity(); ModelDatasetEntity datasetEntity = new ModelDatasetEntity(); diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java index 34cef08..573148b 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.postgres.repository.model; import static com.kamco.cd.training.postgres.entity.QDatasetEntity.datasetEntity; +import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity; import static com.kamco.cd.training.postgres.entity.QModelDatasetEntity.modelDatasetEntity; import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity; import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity; @@ -9,6 +10,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.mode import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity; import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity; +import com.kamco.cd.training.common.enums.DetectionClassification; +import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; @@ -25,6 +28,7 @@ import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity; import com.kamco.cd.training.postgres.entity.QModelMasterEntity; import com.querydsl.core.types.Expression; import com.querydsl.core.types.Projections; +import com.querydsl.core.types.dsl.CaseBuilder; import com.querydsl.jpa.JPAExpressions; import com.querydsl.jpa.impl.JPAQueryFactory; import java.util.ArrayList; @@ -154,11 +158,78 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom { datasetEntity.compareYyyy, datasetEntity.targetYyyy, datasetEntity.roundNo, - modelDatasetEntity.buildingCnt, - modelDatasetEntity.containerCnt, - modelDatasetEntity.wasteCnt, - modelDatasetEntity.landCoverCnt, - modelDatasetEntity.solarCnt)) + + // G1 - building + new CaseBuilder() + .when( + modelMasterEntity + .modelNo + .eq(ModelType.G1.getId()) + .and( + datasetObjEntity.targetClassCd.eq( + DetectionClassification.BUILDING.getId()))) + .then(1L) + .otherwise(0L) + .sum(), + + // G1 - container + new CaseBuilder() + .when( + modelMasterEntity + .modelNo + .eq(ModelType.G1.getId()) + .and( + datasetObjEntity.targetClassCd.eq( + DetectionClassification.CONTAINER.getId()))) + .then(1L) + .otherwise(0L) + .sum(), + + // G2 - waste + new CaseBuilder() + .when( + modelMasterEntity + .modelNo + .eq(ModelType.G2.getId()) + .and( + datasetObjEntity.targetClassCd.eq( + DetectionClassification.WASTE.getId()))) + .then(1L) + .otherwise(0L) + .sum(), + + // G3 - 나머지 + new CaseBuilder() + .when( + modelMasterEntity + .modelNo + .eq(ModelType.G3.getId()) + .and( + datasetObjEntity + .targetClassCd + .isNotNull() + .and( + datasetObjEntity.targetClassCd.notIn( + DetectionClassification.BUILDING.getId(), + DetectionClassification.CONTAINER.getId(), + DetectionClassification.WASTE.getId(), + DetectionClassification.SOLAR.getId())))) + .then(1L) + .otherwise(0L) + .sum(), + + // G4 - solar + new CaseBuilder() + .when( + modelMasterEntity + .modelNo + .eq(ModelType.G4.getId()) + .and( + datasetObjEntity.targetClassCd.eq( + DetectionClassification.SOLAR.getId()))) + .then(1L) + .otherwise(0L) + .sum())) .from(modelMasterEntity) .innerJoin(modelDatasetEntity) .on(modelMasterEntity.id.eq(modelDatasetEntity.model.id)) @@ -166,7 +237,16 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom { .on(modelMasterEntity.id.eq(modelDatasetMappEntity.modelUid)) .innerJoin(datasetEntity) .on(modelDatasetMappEntity.datasetUid.eq(datasetEntity.id)) + .leftJoin(datasetObjEntity) + .on(datasetEntity.id.eq(datasetObjEntity.datasetUid)) .where(modelMasterEntity.uuid.eq(uuid)) + .groupBy( + modelMasterEntity.id, + datasetEntity.id, + datasetEntity.dataType, + datasetEntity.compareYyyy, + datasetEntity.targetYyyy, + datasetEntity.roundNo) .fetch(); }