feat/training_260202 #133

Merged
teddy merged 2 commits from feat/training_260202 into develop 2026-02-20 18:22:45 +09:00
17 changed files with 579 additions and 33 deletions

View File

@@ -1,7 +1,6 @@
package com.kamco.cd.training.dataset.dto;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.kamco.cd.training.common.enums.LearnDataRegister;
import com.kamco.cd.training.common.enums.LearnDataType;
import com.kamco.cd.training.common.enums.ModelType;
@@ -234,7 +233,6 @@ public class DatasetDto {
@Getter
@Setter
@NoArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class SelectDataSet {
private String modelNo; // G1, G2, G3 모델 타입
@@ -317,6 +315,183 @@ public class DatasetDto {
}
}
@Schema(name = "SelectTransferDataSet", description = "전이학습 데이터셋 선택 리스트")
@Getter
@Setter
@NoArgsConstructor
public static class SelectTransferDataSet {
private String modelNo; // G1, G2, G3 모델 타입
private Long datasetId;
private UUID uuid;
private String dataType;
private String title;
private Long roundNo;
private Integer compareYyyy;
private Integer targetYyyy;
private String memo;
@JsonIgnore private Long classCount;
private Integer buildingCnt;
private Integer containerCnt;
private String dataTypeName;
private Long wasteCnt;
private Long landCoverCnt;
private String beforeModelNo; // G1, G2, G3 모델 타입
private Long beforeDatasetId;
private UUID beforeUuid;
private String beforeDataType;
private String beforeTitle;
private Long beforeRoundNo;
private Integer beforeCompareYyyy;
private Integer beforeTargetYyyy;
private String beforeMemo;
@JsonIgnore private Long beforeClassCount;
private Integer beforeBuildingCnt;
private Integer beforeContainerCnt;
private String beforeDataTypeName;
private Long beforeWasteCnt;
private Long beforeLandCoverCnt;
public SelectTransferDataSet(
// 현재
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Long classCount,
// 이전(before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Long beforeClassCount) {
// 현재
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.classCount = classCount;
if (modelNo != null && modelNo.equals(ModelType.G2.getId())) {
this.wasteCnt = classCount;
} else if (modelNo != null && modelNo.equals(ModelType.G3.getId())) {
this.landCoverCnt = classCount;
}
// 이전(before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeClassCount = beforeClassCount;
if (beforeModelNo != null && beforeModelNo.equals(ModelType.G2.getId())) {
this.beforeWasteCnt = beforeClassCount;
} else if (beforeModelNo != null && beforeModelNo.equals(ModelType.G3.getId())) {
this.beforeLandCoverCnt = beforeClassCount;
}
}
public SelectTransferDataSet(
// 현재
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Integer buildingCnt,
Integer containerCnt,
// 이전(before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Integer beforeBuildingCnt,
Integer beforeContainerCnt) {
// 현재
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.buildingCnt = buildingCnt;
this.containerCnt = containerCnt;
// 이전(before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeBuildingCnt = beforeBuildingCnt;
this.beforeContainerCnt = beforeContainerCnt;
}
public String getDataTypeName(String groupTitleCd) {
LearnDataType type = Enums.fromId(LearnDataType.class, groupTitleCd);
return type == null ? null : type.getText();
}
public String getYear() {
return this.compareYyyy + "-" + this.targetYyyy;
}
public String getBeforeYear() {
if (this.beforeCompareYyyy == null || this.beforeTargetYyyy == null) {
return null;
}
return this.beforeCompareYyyy + "-" + this.beforeTargetYyyy;
}
}
@Getter
@Setter
@NoArgsConstructor

View File

@@ -134,26 +134,26 @@ public class ModelTrainDetailApiController {
return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid));
}
@Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/transfer/detail/{uuid}")
public ApiResponseDto<TransferDetailDto> getTransferDetail(
@Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
}
// @Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
// @ApiResponses(
// value = {
// @ApiResponse(
// responseCode = "200",
// description = "조회 성공",
// content =
// @Content(
// mediaType = "application/json",
// schema = @Schema(implementation = TransferDetailDto.class))),
// @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
// @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
// })
// @GetMapping("/transfer/detail/{uuid}")
// public ApiResponseDto<TransferDetailDto> getTransferDetail(
// @Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
// @PathVariable
// UUID uuid) {
// return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
// }
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API")
@ApiResponses(

View File

@@ -20,4 +20,25 @@ public class ModelConfigDto {
private Float testPercent;
private String memo;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class TransferBasic {
private Long configId;
private Long modelId;
private Integer epochCount;
private Float trainPercent;
private Float validationPercent;
private Float testPercent;
private String memo;
private Long beforeConfigId;
private Long beforeModelId;
private Integer beforeEpochCount;
private Float beforeTrainPercent;
private Float beforeValidationPercent;
private Float beforeTestPercent;
private String beforeMemo;
}
}

View File

@@ -6,7 +6,7 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.utils.enums.Enums;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.Duration;
import java.time.ZonedDateTime;
@@ -176,9 +176,10 @@ public class ModelTrainDetailDto {
@NoArgsConstructor
@AllArgsConstructor
public static class TransferDetailDto {
private ModelConfigDto.Basic etcConfig;
private ModelConfigDto.TransferBasic etcConfig;
private TransferHyperSummary modelTrainHyper;
private List<SelectDataSet> modelTrainDataset;
private List<SelectTransferDataSet> modelTrainDataset;
// private List<SelectDataSet> beforeTrainDataset;
}
@Getter

View File

@@ -47,6 +47,8 @@ public class ModelTrainMngDto {
private ZonedDateTime packingStrtDttm;
private ZonedDateTime packingEndDttm;
private Long beforeModelId;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
try {
@@ -249,6 +251,7 @@ public class ModelTrainMngDto {
private String memo;
private String userNm;
private UUID beforeUuid;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;

View File

@@ -2,7 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
@@ -73,11 +73,11 @@ public class ModelTrainDetailService {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
// config 정보 조회
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
ModelConfigDto.TransferBasic configInfo = mngCoreService.findModelTransferConfigByModelId(uuid);
// 하이파라미터 정보 조회
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>();
List<SelectTransferDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();
List<Long> datasetIds = new ArrayList<>();
@@ -91,11 +91,36 @@ public class ModelTrainDetailService {
datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals(ModelType.G1.getId())) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
dataSets = mngCoreService.getDatasetTransferSelectG1List(modelInfo.getId());
} else {
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq);
dataSets =
mngCoreService.getDatasetTransferSelectG2G3List(
modelInfo.getId(), modelInfo.getModelNo());
}
// DatasetReq beforeDatasetReq = new DatasetReq();
// List<Long> beforeDatasetIds = new ArrayList<>();
// List<SelectDataSet> beforeDataSets = new ArrayList<>();
//
// Long beforeModelId = modelInfo.getBeforeModelId();
// if (beforeModelId != null) {
// Basic beforeInfo = modelTrainDetailCoreService.findByModelBeforeId(beforeModelId);
// List<MappingDataset> beforeDatasets =
// modelTrainDetailCoreService.getByModelMappingDataset(beforeInfo.getUuid());
//
// for (MappingDataset before : beforeDatasets) {
// beforeDatasetIds.add(before.getDatasetId());
// }
// beforeDatasetReq.setIds(beforeDatasetIds);
// beforeDatasetReq.setModelNo(modelInfo.getModelNo());
//
// if (beforeInfo.getModelNo().equals(ModelType.G1.getId())) {
// beforeDataSets = mngCoreService.getDatasetSelectG1List(beforeDatasetReq);
// } else {
// beforeDataSets = mngCoreService.getDatasetSelectG2G3List(beforeDatasetReq);
// }
// }
TransferDetailDto transferDetailDto = new TransferDetailDto();
transferDetailDto.setEtcConfig(configInfo);
transferDetailDto.setModelTrainHyper(hyperSummary);

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
@@ -115,7 +116,7 @@ public class ModelTrainMngService {
* @return
*/
public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (req.getModelNo().equals("G1")) {
if (req.getModelNo().equals(ModelType.G1.getId())) {
return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else {
return modelTrainMngCoreService.getDatasetSelectG2G3List(req);

View File

@@ -107,4 +107,9 @@ public class ModelTrainDetailCoreService {
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
return modelDetailRepository.findModelTrainProgressInfo(uuid);
}
public Basic findByModelBeforeId(Long beforeModelId) {
ModelMasterEntity entity = modelDetailRepository.findByModelBeforeId(beforeModelId);
return entity.toDto();
}
}

View File

@@ -8,6 +8,7 @@ import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
@@ -271,6 +272,13 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
public ModelConfigDto.TransferBasic findModelTransferConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository
.findModelTransferConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
/**
* 데이터셋 G1 목록
*
@@ -281,6 +289,16 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG1List(req);
}
/**
* 전이학습 데이터셋 G1 목록
*
* @param modelId 모델 Id
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
return datasetRepository.getDatasetTransferSelectG1List(modelId);
}
/**
* 데이터셋 G2, G3 목록
*
@@ -291,6 +309,18 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG2G3List(req);
}
/**
* 전이학습 데이터셋 G2, G3 목록
*
* @param modelId 모델 Id
* @param modelNo G2, G3
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
return datasetRepository.getDatasetTransferSelectG2G3List(modelId, modelNo);
}
/**
* 모델관리 조회
*

View File

@@ -140,6 +140,7 @@ public class ModelMasterEntity {
this.requestPath,
this.packingState,
this.packingStrtDttm,
this.packingEndDttm);
this.packingEndDttm,
this.beforeModelId);
}
}

View File

@@ -4,6 +4,7 @@ import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity;
import java.util.List;
import java.util.Optional;
@@ -17,6 +18,10 @@ public interface DatasetRepositoryCustom {
List<SelectDataSet> getDatasetSelectG1List(DatasetReq req);
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId);
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(Long modelId, String modelNo);
List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req);
Long getDatasetMaxStage(int compareYyyy, int targetYyyy);

View File

@@ -1,14 +1,20 @@
package com.kamco.cd.training.postgres.repository.dataset;
import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.CaseBuilder;
@@ -142,6 +148,103 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch();
}
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum(),
// ===== before (join으로) =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum()))
.from(modelMasterEntity)
// ===== 현재 dataset join =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before 모델 join =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId))
.groupBy(
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
@@ -205,6 +308,116 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch();
}
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
// before join용
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
BooleanBuilder builder = new BooleanBuilder();
NumberExpression<Long> wasteCnt =
datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> elseCnt =
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> selectedCnt = ModelType.G2.getId().equals(modelNo) ? wasteCnt : elseCnt;
// before도 동일 로직으로 cnt 계산
NumberExpression<Long> beforeWasteCnt =
beforeObj.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> beforeElseCnt =
new CaseBuilder()
.when(beforeObj.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> beforeSelectedCnt =
ModelType.G2.getId().equals(modelNo) ? beforeWasteCnt : beforeElseCnt;
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo, // modelNo 파라미터 사용 (req.getModelNo() 제거)
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
selectedCnt, // classCount 자리에 들어가는 cnt (Long)
// ===== before =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
beforeSelectedCnt))
.from(modelMasterEntity)
// ===== 현재 dataset =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before dataset =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId).and(builder))
// sum() 때문에 groupBy 필요
.groupBy(
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override
public Long getDatasetMaxStage(int compareYyyy, int targetYyyy) {
return queryFactory

View File

@@ -5,4 +5,6 @@ import java.util.Optional;
public interface ModelConfigRepositoryCustom {
Optional<ModelConfigDto.Basic> findModelConfigByModelId(Long modelId);
Optional<ModelConfigDto.TransferBasic> findModelTransferConfigByModelId(Long modelId);
}

View File

@@ -1,8 +1,12 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.model.dto.ModelConfigDto.Basic;
import com.kamco.cd.training.model.dto.ModelConfigDto.TransferBasic;
import com.kamco.cd.training.postgres.entity.QModelConfigEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.Optional;
@@ -34,4 +38,44 @@ public class ModelConfigRepositoryImpl implements ModelConfigRepositoryCustom {
.where(modelConfigEntity.model.id.eq(modelId))
.fetchOne());
}
@Override
public Optional<TransferBasic> findModelTransferConfigByModelId(Long modelId) {
QModelConfigEntity beforeConfig = new QModelConfigEntity("beforeConfig");
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
return Optional.ofNullable(
queryFactory
.select(
Projections.constructor(
TransferBasic.class,
// ===== 현재 =====
modelConfigEntity.id,
modelConfigEntity.model.id,
modelConfigEntity.epochCount,
modelConfigEntity.trainPercent,
modelConfigEntity.validationPercent,
modelConfigEntity.testPercent,
modelConfigEntity.memo,
// ===== before =====
beforeConfig.id,
beforeConfig.model.id,
beforeConfig.epochCount,
beforeConfig.trainPercent,
beforeConfig.validationPercent,
beforeConfig.testPercent,
beforeConfig.memo))
.from(modelConfigEntity)
.innerJoin(modelConfigEntity.model, modelMasterEntity)
// before 모델 조인
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeConfig)
.on(beforeConfig.model.id.eq(beforeMaster.id))
.where(modelMasterEntity.id.eq(modelId))
.fetchOne());
}
}

View File

@@ -40,4 +40,6 @@ public interface ModelDetailRepositoryCustom {
ModelFileInfo getModelTrainFileInfo(UUID uuid);
List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid);
ModelMasterEntity findByModelBeforeId(Long beforeModelId);
}

View File

@@ -355,4 +355,12 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
return steps;
}
@Override
public ModelMasterEntity findByModelBeforeId(Long beforeModelId) {
return queryFactory
.selectFrom(modelMasterEntity)
.where(modelMasterEntity.id.eq(beforeModelId))
.fetchOne();
}
}

View File

@@ -9,8 +9,10 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -37,6 +39,13 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
*/
@Override
public Page<ListDto> findByModels(ModelTrainMngDto.SearchReq req) {
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
Expression<UUID> beforeModelUuid =
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
.from(beforeModel)
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
@@ -78,7 +87,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelMasterEntity.packingStrtDttm,
modelMasterEntity.packingEndDttm,
modelConfigEntity.memo,
memberEntity.name))
memberEntity.name,
beforeModelUuid))
.from(modelMasterEntity)
.innerJoin(modelConfigEntity)
.on(modelMasterEntity.id.eq(modelConfigEntity.model.id))