feat/training_260202 #133

Merged
teddy merged 2 commits from feat/training_260202 into develop 2026-02-20 18:22:45 +09:00
13 changed files with 559 additions and 56 deletions
Showing only changes of commit 07429dbe8e - Show all commits

View File

@@ -1,7 +1,6 @@
package com.kamco.cd.training.dataset.dto; package com.kamco.cd.training.dataset.dto;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.kamco.cd.training.common.enums.LearnDataRegister; import com.kamco.cd.training.common.enums.LearnDataRegister;
import com.kamco.cd.training.common.enums.LearnDataType; import com.kamco.cd.training.common.enums.LearnDataType;
import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.enums.ModelType;
@@ -234,7 +233,6 @@ public class DatasetDto {
@Getter @Getter
@Setter @Setter
@NoArgsConstructor @NoArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class SelectDataSet { public static class SelectDataSet {
private String modelNo; // G1, G2, G3 모델 타입 private String modelNo; // G1, G2, G3 모델 타입
@@ -317,6 +315,183 @@ public class DatasetDto {
} }
} }
@Schema(name = "SelectTransferDataSet", description = "전이학습 데이터셋 선택 리스트")
@Getter
@Setter
@NoArgsConstructor
public static class SelectTransferDataSet {
private String modelNo; // G1, G2, G3 모델 타입
private Long datasetId;
private UUID uuid;
private String dataType;
private String title;
private Long roundNo;
private Integer compareYyyy;
private Integer targetYyyy;
private String memo;
@JsonIgnore private Long classCount;
private Integer buildingCnt;
private Integer containerCnt;
private String dataTypeName;
private Long wasteCnt;
private Long landCoverCnt;
private String beforeModelNo; // G1, G2, G3 모델 타입
private Long beforeDatasetId;
private UUID beforeUuid;
private String beforeDataType;
private String beforeTitle;
private Long beforeRoundNo;
private Integer beforeCompareYyyy;
private Integer beforeTargetYyyy;
private String beforeMemo;
@JsonIgnore private Long beforeClassCount;
private Integer beforeBuildingCnt;
private Integer beforeContainerCnt;
private String beforeDataTypeName;
private Long beforeWasteCnt;
private Long beforeLandCoverCnt;
public SelectTransferDataSet(
// 현재
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Long classCount,
// 이전(before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Long beforeClassCount) {
// 현재
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.classCount = classCount;
if (modelNo != null && modelNo.equals(ModelType.G2.getId())) {
this.wasteCnt = classCount;
} else if (modelNo != null && modelNo.equals(ModelType.G3.getId())) {
this.landCoverCnt = classCount;
}
// 이전(before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeClassCount = beforeClassCount;
if (beforeModelNo != null && beforeModelNo.equals(ModelType.G2.getId())) {
this.beforeWasteCnt = beforeClassCount;
} else if (beforeModelNo != null && beforeModelNo.equals(ModelType.G3.getId())) {
this.beforeLandCoverCnt = beforeClassCount;
}
}
public SelectTransferDataSet(
// 현재
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Integer buildingCnt,
Integer containerCnt,
// 이전(before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Integer beforeBuildingCnt,
Integer beforeContainerCnt) {
// 현재
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.buildingCnt = buildingCnt;
this.containerCnt = containerCnt;
// 이전(before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeBuildingCnt = beforeBuildingCnt;
this.beforeContainerCnt = beforeContainerCnt;
}
public String getDataTypeName(String groupTitleCd) {
LearnDataType type = Enums.fromId(LearnDataType.class, groupTitleCd);
return type == null ? null : type.getText();
}
public String getYear() {
return this.compareYyyy + "-" + this.targetYyyy;
}
public String getBeforeYear() {
if (this.beforeCompareYyyy == null || this.beforeTargetYyyy == null) {
return null;
}
return this.beforeCompareYyyy + "-" + this.beforeTargetYyyy;
}
}
@Getter @Getter
@Setter @Setter
@NoArgsConstructor @NoArgsConstructor

View File

@@ -134,26 +134,26 @@ public class ModelTrainDetailApiController {
return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid)); return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid));
} }
@Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API") // @Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
@ApiResponses( // @ApiResponses(
value = { // value = {
@ApiResponse( // @ApiResponse(
responseCode = "200", // responseCode = "200",
description = "조회 성공", // description = "조회 성공",
content = // content =
@Content( // @Content(
mediaType = "application/json", // mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))), // schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content), // @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) // @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) // })
@GetMapping("/transfer/detail/{uuid}") // @GetMapping("/transfer/detail/{uuid}")
public ApiResponseDto<TransferDetailDto> getTransferDetail( // public ApiResponseDto<TransferDetailDto> getTransferDetail(
@Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e") // @Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable // @PathVariable
UUID uuid) { // UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid)); // return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
} // }
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API") @Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API")
@ApiResponses( @ApiResponses(

View File

@@ -20,4 +20,25 @@ public class ModelConfigDto {
private Float testPercent; private Float testPercent;
private String memo; private String memo;
} }
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class TransferBasic {
private Long configId;
private Long modelId;
private Integer epochCount;
private Float trainPercent;
private Float validationPercent;
private Float testPercent;
private String memo;
private Long beforeConfigId;
private Long beforeModelId;
private Integer beforeEpochCount;
private Float beforeTrainPercent;
private Float beforeValidationPercent;
private Float beforeTestPercent;
private String beforeMemo;
}
} }

View File

@@ -6,7 +6,7 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.enums.TrainType; import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.utils.enums.Enums; import com.kamco.cd.training.common.utils.enums.Enums;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.time.Duration; import java.time.Duration;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
@@ -176,10 +176,10 @@ public class ModelTrainDetailDto {
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public static class TransferDetailDto { public static class TransferDetailDto {
private ModelConfigDto.Basic etcConfig; private ModelConfigDto.TransferBasic etcConfig;
private TransferHyperSummary modelTrainHyper; private TransferHyperSummary modelTrainHyper;
private List<SelectDataSet> modelTrainDataset; private List<SelectTransferDataSet> modelTrainDataset;
private List<SelectDataSet> beforeTrainDataset; // private List<SelectDataSet> beforeTrainDataset;
} }
@Getter @Getter

View File

@@ -251,6 +251,7 @@ public class ModelTrainMngDto {
private String memo; private String memo;
private String userNm; private String userNm;
private UUID beforeUuid;
public String getStatusName() { public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null; if (this.statusCd == null || this.statusCd.isBlank()) return null;

View File

@@ -2,7 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
@@ -73,11 +73,11 @@ public class ModelTrainDetailService {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid); Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
// config 정보 조회 // config 정보 조회
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid); ModelConfigDto.TransferBasic configInfo = mngCoreService.findModelTransferConfigByModelId(uuid);
// 하이파라미터 정보 조회 // 하이파라미터 정보 조회
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid); TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>(); List<SelectTransferDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq(); DatasetReq datasetReq = new DatasetReq();
List<Long> datasetIds = new ArrayList<>(); List<Long> datasetIds = new ArrayList<>();
@@ -91,39 +91,40 @@ public class ModelTrainDetailService {
datasetReq.setModelNo(modelInfo.getModelNo()); datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals(ModelType.G1.getId())) { if (modelInfo.getModelNo().equals(ModelType.G1.getId())) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq); dataSets = mngCoreService.getDatasetTransferSelectG1List(modelInfo.getId());
} else { } else {
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq); dataSets =
mngCoreService.getDatasetTransferSelectG2G3List(
modelInfo.getId(), modelInfo.getModelNo());
} }
DatasetReq beforeDatasetReq = new DatasetReq(); // DatasetReq beforeDatasetReq = new DatasetReq();
List<Long> beforeDatasetIds = new ArrayList<>(); // List<Long> beforeDatasetIds = new ArrayList<>();
List<SelectDataSet> beforeDataSets = new ArrayList<>(); // List<SelectDataSet> beforeDataSets = new ArrayList<>();
//
Long beforeModelId = modelInfo.getBeforeModelId(); // Long beforeModelId = modelInfo.getBeforeModelId();
if (beforeModelId != null) { // if (beforeModelId != null) {
Basic beforeInfo = modelTrainDetailCoreService.findByModelBeforeId(beforeModelId); // Basic beforeInfo = modelTrainDetailCoreService.findByModelBeforeId(beforeModelId);
List<MappingDataset> beforeDatasets = // List<MappingDataset> beforeDatasets =
modelTrainDetailCoreService.getByModelMappingDataset(beforeInfo.getUuid()); // modelTrainDetailCoreService.getByModelMappingDataset(beforeInfo.getUuid());
//
for (MappingDataset before : beforeDatasets) { // for (MappingDataset before : beforeDatasets) {
beforeDatasetIds.add(before.getDatasetId()); // beforeDatasetIds.add(before.getDatasetId());
} // }
beforeDatasetReq.setIds(beforeDatasetIds); // beforeDatasetReq.setIds(beforeDatasetIds);
beforeDatasetReq.setModelNo(modelInfo.getModelNo()); // beforeDatasetReq.setModelNo(modelInfo.getModelNo());
//
if (beforeInfo.getModelNo().equals(ModelType.G1.getId())) { // if (beforeInfo.getModelNo().equals(ModelType.G1.getId())) {
beforeDataSets = mngCoreService.getDatasetSelectG1List(beforeDatasetReq); // beforeDataSets = mngCoreService.getDatasetSelectG1List(beforeDatasetReq);
} else { // } else {
beforeDataSets = mngCoreService.getDatasetSelectG2G3List(beforeDatasetReq); // beforeDataSets = mngCoreService.getDatasetSelectG2G3List(beforeDatasetReq);
} // }
} // }
TransferDetailDto transferDetailDto = new TransferDetailDto(); TransferDetailDto transferDetailDto = new TransferDetailDto();
transferDetailDto.setEtcConfig(configInfo); transferDetailDto.setEtcConfig(configInfo);
transferDetailDto.setModelTrainHyper(hyperSummary); transferDetailDto.setModelTrainHyper(hyperSummary);
transferDetailDto.setModelTrainDataset(dataSets); transferDetailDto.setModelTrainDataset(dataSets);
transferDetailDto.setBeforeTrainDataset(beforeDataSets);
return transferDetailDto; return transferDetailDto;
} }

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType; import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.enums.TrainType; import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
@@ -115,7 +116,7 @@ public class ModelTrainMngService {
* @return * @return
*/ */
public List<SelectDataSet> getDatasetSelectList(DatasetReq req) { public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (req.getModelNo().equals("G1")) { if (req.getModelNo().equals(ModelType.G1.getId())) {
return modelTrainMngCoreService.getDatasetSelectG1List(req); return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else { } else {
return modelTrainMngCoreService.getDatasetSelectG2G3List(req); return modelTrainMngCoreService.getDatasetSelectG2G3List(req);

View File

@@ -8,6 +8,7 @@ import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil; import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
@@ -271,6 +272,13 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} }
public ModelConfigDto.TransferBasic findModelTransferConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository
.findModelTransferConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
/** /**
* 데이터셋 G1 목록 * 데이터셋 G1 목록
* *
@@ -281,6 +289,16 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG1List(req); return datasetRepository.getDatasetSelectG1List(req);
} }
/**
* 전이학습 데이터셋 G1 목록
*
* @param modelId 모델 Id
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
return datasetRepository.getDatasetTransferSelectG1List(modelId);
}
/** /**
* 데이터셋 G2, G3 목록 * 데이터셋 G2, G3 목록
* *
@@ -291,6 +309,18 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG2G3List(req); return datasetRepository.getDatasetSelectG2G3List(req);
} }
/**
* 전이학습 데이터셋 G2, G3 목록
*
* @param modelId 모델 Id
* @param modelNo G2, G3
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
return datasetRepository.getDatasetTransferSelectG2G3List(modelId, modelNo);
}
/** /**
* 모델관리 조회 * 모델관리 조회
* *

View File

@@ -4,6 +4,7 @@ import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity; import com.kamco.cd.training.postgres.entity.DatasetEntity;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@@ -17,6 +18,10 @@ public interface DatasetRepositoryCustom {
List<SelectDataSet> getDatasetSelectG1List(DatasetReq req); List<SelectDataSet> getDatasetSelectG1List(DatasetReq req);
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId);
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(Long modelId, String modelNo);
List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req); List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req);
Long getDatasetMaxStage(int compareYyyy, int targetYyyy); Long getDatasetMaxStage(int compareYyyy, int targetYyyy);

View File

@@ -1,14 +1,20 @@
package com.kamco.cd.training.postgres.repository.dataset; package com.kamco.cd.training.postgres.repository.dataset;
import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity; import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq; import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity; import com.kamco.cd.training.postgres.entity.DatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetEntity; import com.kamco.cd.training.postgres.entity.QDatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.BooleanBuilder; import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.CaseBuilder; import com.querydsl.core.types.dsl.CaseBuilder;
@@ -142,6 +148,103 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch(); .fetch();
} }
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum(),
// ===== before (join으로) =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum()))
.from(modelMasterEntity)
// ===== 현재 dataset join =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before 모델 join =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId))
.groupBy(
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override @Override
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) { public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
@@ -205,6 +308,116 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch(); .fetch();
} }
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
// before join용
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
BooleanBuilder builder = new BooleanBuilder();
NumberExpression<Long> wasteCnt =
datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> elseCnt =
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> selectedCnt = ModelType.G2.getId().equals(modelNo) ? wasteCnt : elseCnt;
// before도 동일 로직으로 cnt 계산
NumberExpression<Long> beforeWasteCnt =
beforeObj.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> beforeElseCnt =
new CaseBuilder()
.when(beforeObj.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> beforeSelectedCnt =
ModelType.G2.getId().equals(modelNo) ? beforeWasteCnt : beforeElseCnt;
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo, // modelNo 파라미터 사용 (req.getModelNo() 제거)
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
selectedCnt, // classCount 자리에 들어가는 cnt (Long)
// ===== before =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
beforeSelectedCnt))
.from(modelMasterEntity)
// ===== 현재 dataset =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before dataset =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId).and(builder))
// sum() 때문에 groupBy 필요
.groupBy(
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override @Override
public Long getDatasetMaxStage(int compareYyyy, int targetYyyy) { public Long getDatasetMaxStage(int compareYyyy, int targetYyyy) {
return queryFactory return queryFactory

View File

@@ -5,4 +5,6 @@ import java.util.Optional;
public interface ModelConfigRepositoryCustom { public interface ModelConfigRepositoryCustom {
Optional<ModelConfigDto.Basic> findModelConfigByModelId(Long modelId); Optional<ModelConfigDto.Basic> findModelConfigByModelId(Long modelId);
Optional<ModelConfigDto.TransferBasic> findModelTransferConfigByModelId(Long modelId);
} }

View File

@@ -1,8 +1,12 @@
package com.kamco.cd.training.postgres.repository.model; package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity; import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.model.dto.ModelConfigDto.Basic; import com.kamco.cd.training.model.dto.ModelConfigDto.Basic;
import com.kamco.cd.training.model.dto.ModelConfigDto.TransferBasic;
import com.kamco.cd.training.postgres.entity.QModelConfigEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.Optional; import java.util.Optional;
@@ -34,4 +38,44 @@ public class ModelConfigRepositoryImpl implements ModelConfigRepositoryCustom {
.where(modelConfigEntity.model.id.eq(modelId)) .where(modelConfigEntity.model.id.eq(modelId))
.fetchOne()); .fetchOne());
} }
@Override
public Optional<TransferBasic> findModelTransferConfigByModelId(Long modelId) {
QModelConfigEntity beforeConfig = new QModelConfigEntity("beforeConfig");
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
return Optional.ofNullable(
queryFactory
.select(
Projections.constructor(
TransferBasic.class,
// ===== 현재 =====
modelConfigEntity.id,
modelConfigEntity.model.id,
modelConfigEntity.epochCount,
modelConfigEntity.trainPercent,
modelConfigEntity.validationPercent,
modelConfigEntity.testPercent,
modelConfigEntity.memo,
// ===== before =====
beforeConfig.id,
beforeConfig.model.id,
beforeConfig.epochCount,
beforeConfig.trainPercent,
beforeConfig.validationPercent,
beforeConfig.testPercent,
beforeConfig.memo))
.from(modelConfigEntity)
.innerJoin(modelConfigEntity.model, modelMasterEntity)
// before 모델 조인
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeConfig)
.on(beforeConfig.model.id.eq(beforeMaster.id))
.where(modelMasterEntity.id.eq(modelId))
.fetchOne());
}
} }

View File

@@ -9,8 +9,10 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder; import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.Expressions; import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -37,6 +39,13 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
*/ */
@Override @Override
public Page<ListDto> findByModels(ModelTrainMngDto.SearchReq req) { public Page<ListDto> findByModels(ModelTrainMngDto.SearchReq req) {
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
Expression<UUID> beforeModelUuid =
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
.from(beforeModel)
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
Pageable pageable = req.toPageable(); Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
@@ -78,7 +87,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelMasterEntity.packingStrtDttm, modelMasterEntity.packingStrtDttm,
modelMasterEntity.packingEndDttm, modelMasterEntity.packingEndDttm,
modelConfigEntity.memo, modelConfigEntity.memo,
memberEntity.name)) memberEntity.name,
beforeModelUuid))
.from(modelMasterEntity) .from(modelMasterEntity)
.innerJoin(modelConfigEntity) .innerJoin(modelConfigEntity)
.on(modelMasterEntity.id.eq(modelConfigEntity.model.id)) .on(modelMasterEntity.id.eq(modelConfigEntity.model.id))