전이학습 추가 #21

Merged
teddy merged 1 commits from feat/training_260202 into develop 2026-02-05 18:23:31 +09:00
13 changed files with 243 additions and 47 deletions
Showing only changes of commit 0a7f01a2f5 - Show all commits

View File

@@ -232,7 +232,7 @@ public class DatasetDto {
private Long datasetId; private Long datasetId;
private UUID uuid; private UUID uuid;
private String dataType; private String dataType;
private String yyyy; private String title;
private Long roundNo; private Long roundNo;
private Integer compareYyyy; private Integer compareYyyy;
private Integer targetYyyy; private Integer targetYyyy;
@@ -246,7 +246,7 @@ public class DatasetDto {
Long datasetId, Long datasetId,
UUID uuid, UUID uuid,
String dataType, String dataType,
String yyyy, String title,
Long roundNo, Long roundNo,
Integer compareYyyy, Integer compareYyyy,
Integer targetYyyy, Integer targetYyyy,
@@ -255,7 +255,7 @@ public class DatasetDto {
this.datasetId = datasetId; this.datasetId = datasetId;
this.uuid = uuid; this.uuid = uuid;
this.dataType = dataType; this.dataType = dataType;
this.yyyy = yyyy; this.title = title;
this.roundNo = roundNo; this.roundNo = roundNo;
this.compareYyyy = compareYyyy; this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy; this.targetYyyy = targetYyyy;
@@ -267,7 +267,7 @@ public class DatasetDto {
Long datasetId, Long datasetId,
UUID uuid, UUID uuid,
String dataType, String dataType,
String yyyy, String title,
Long roundNo, Long roundNo,
Integer compareYyyy, Integer compareYyyy,
Integer targetYyyy, Integer targetYyyy,
@@ -278,7 +278,7 @@ public class DatasetDto {
this.uuid = uuid; this.uuid = uuid;
this.dataType = dataType; this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType); this.dataTypeName = getDataTypeName(dataType);
this.yyyy = yyyy; this.title = title;
this.roundNo = roundNo; this.roundNo = roundNo;
this.compareYyyy = compareYyyy; this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy; this.targetYyyy = targetYyyy;
@@ -296,4 +296,16 @@ public class DatasetDto {
return this.compareYyyy + "-" + this.targetYyyy; return this.compareYyyy + "-" + this.targetYyyy;
} }
} }
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class DatasetReq {
String modelNo;
String dataType;
UUID uuid;
Long id;
List<Long> ids;
}
} }

View File

@@ -3,9 +3,16 @@ package com.kamco.cd.training.model;
import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.service.ModelTrainDetailService; import com.kamco.cd.training.model.service.ModelTrainDetailService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.ArraySchema;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
@@ -22,29 +29,107 @@ import org.springframework.web.bind.annotation.RestController;
public class ModelTrainDetailApiController { public class ModelTrainDetailApiController {
private final ModelTrainDetailService modelTrainDetailService; private final ModelTrainDetailService modelTrainDetailService;
@Operation(summary = "모델학습 상세 조회", description = "모델학습 상세 조회 API") @Operation(summary = "모델학습관리> 모델관리 > 상세정보탭 > 학습 진행정보", description = "학습 진행정보, 모델학습 정보 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Basic.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/detail/{uuid}") @GetMapping("/detail/{uuid}")
public ApiResponseDto<Basic> findByModelByUUID(@PathVariable UUID uuid) { public ApiResponseDto<Basic> findByModelByUUID(
@Parameter(description = "모델학습 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.findByModelByUUID(uuid)); return ApiResponseDto.ok(modelTrainDetailService.findByModelByUUID(uuid));
} }
@Operation(summary = "모델학습 상세 요약 정보", description = "모델학습 상세 요약 정보 API") @Operation(summary = "모델학습관리> 모델관리 > 상세정보탭 > 요약정보", description = "상세정보 탭 요약정보 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelTrainDetailDto.DetailSummary.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/summary/{uuid}") @GetMapping("/summary/{uuid}")
public ApiResponseDto<ModelTrainDetailDto.DetailSummary> getModelDetailSummary( public ApiResponseDto<ModelTrainDetailDto.DetailSummary> getModelDetailSummary(
@PathVariable UUID uuid) { @Parameter(description = "모델학습 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelDetailSummary(uuid)); return ApiResponseDto.ok(modelTrainDetailService.getModelDetailSummary(uuid));
} }
@Operation(summary = "모델학습 상세 > 하이퍼파라미터 요약 정보", description = "모델학습 상세 하이퍼파라미터 요약 정보 API") @Operation(summary = "모델학습관리> 모델관리 > 상세정보탭 > 하이퍼파라미터 요약 정보", description = "하이퍼파라미터 요약 정보 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelTrainDetailDto.HyperSummary.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/hyper-summary/{uuid}") @GetMapping("/hyper-summary/{uuid}")
public ApiResponseDto<ModelTrainDetailDto.HyperSummary> getByModelHyperParamSummary( public ApiResponseDto<ModelTrainDetailDto.HyperSummary> getByModelHyperParamSummary(
@PathVariable UUID uuid) { @Parameter(description = "모델학습 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getByModelHyperParamSummary(uuid)); return ApiResponseDto.ok(modelTrainDetailService.getByModelHyperParamSummary(uuid));
} }
@Operation(summary = "모델학습 상세 > 데이터셋 정보", description = "모델학습 상세 데이터셋 정보 API") @Operation(summary = "모델학습관리> 모델관리 > 상세정보탭 > 데이터셋 정보", description = "모델학습 상세 데이터셋 정보 API")
@ApiResponses({
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
array = @ArraySchema(schema = @Schema(implementation = MappingDataset.class)))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음"),
@ApiResponse(responseCode = "500", description = "서버 오류")
})
@GetMapping("/mapp-dataset/{uuid}") @GetMapping("/mapp-dataset/{uuid}")
public ApiResponseDto<List<MappingDataset>> getByModelMappingDataset(@PathVariable UUID uuid) { public ApiResponseDto<List<MappingDataset>> getByModelMappingDataset(
@Parameter(description = "모델학습 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid)); return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid));
} }
@Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/transfer/detail/{uuid}")
public ApiResponseDto<TransferDetailDto> getTransferDetail(
@Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
}
} }

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.model;
import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.dataset.dto.DatasetDto; import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
@@ -110,7 +111,10 @@ public class ModelTrainMngApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@GetMapping("/config/{uuid}") @GetMapping("/config/{uuid}")
public ApiResponseDto<ModelConfigDto.Basic> updateModelTrain(@PathVariable UUID uuid) { public ApiResponseDto<ModelConfigDto.Basic> updateModelTrain(
@Parameter(description = "모델학습 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainMngService.getModelConfigByModelId(uuid)); return ApiResponseDto.ok(modelTrainMngService.getModelConfigByModelId(uuid));
} }
@@ -141,6 +145,9 @@ public class ModelTrainMngApiController {
schema = @Schema(allowableValues = {"CURRENT", "DELIVER", "PRODUCTION"})) schema = @Schema(allowableValues = {"CURRENT", "DELIVER", "PRODUCTION"}))
@RequestParam @RequestParam
String selectType) { String selectType) {
return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(modelType, selectType)); DatasetReq req = new DatasetReq();
req.setModelNo(modelType);
req.setDataType(selectType);
return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req));
} }
} }

View File

@@ -6,9 +6,11 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.enums.TrainType; import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.utils.enums.Enums; import com.kamco.cd.training.common.utils.enums.Enums;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.time.Duration; import java.time.Duration;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.List;
import java.util.UUID; import java.util.UUID;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
@@ -87,6 +89,7 @@ public class ModelTrainDetailDto {
private String hyperVer; private String hyperVer;
private String backbone; private String backbone;
private String inputSize; private String inputSize;
private String cropSize;
private Integer batchSize; private Integer batchSize;
} }
@@ -96,6 +99,7 @@ public class ModelTrainDetailDto {
@NoArgsConstructor @NoArgsConstructor
public static class MappingDataset { public static class MappingDataset {
private Long modelId; private Long modelId;
private Long datasetId;
private String dataType; private String dataType;
private Integer compareYyyy; private Integer compareYyyy;
private Integer targetYyyy; private Integer targetYyyy;
@@ -116,6 +120,7 @@ public class ModelTrainDetailDto {
public MappingDataset( public MappingDataset(
Long modelId, Long modelId,
Long datasetId,
String dataType, String dataType,
Integer compareYyyy, Integer compareYyyy,
Integer targetYyyy, Integer targetYyyy,
@@ -125,6 +130,7 @@ public class ModelTrainDetailDto {
Long wasteCnt, Long wasteCnt,
Long landCoverCnt) { Long landCoverCnt) {
this.modelId = modelId; this.modelId = modelId;
this.datasetId = datasetId;
this.dataType = dataType; this.dataType = dataType;
this.compareYyyy = compareYyyy; this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy; this.targetYyyy = targetYyyy;
@@ -141,4 +147,14 @@ public class ModelTrainDetailDto {
return type == null ? null : type.getText(); return type == null ? null : type.getText();
} }
} }
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class TransferDetailDto {
private ModelConfigDto.Basic etcConfig;
private HyperSummary modelTrainHyper;
private List<SelectDataSet> modelTrainDataset;
}
} }

View File

@@ -39,6 +39,7 @@ public class ModelTrainMngDto {
private String step2Status; private String step2Status;
private String statusCd; private String statusCd;
private String trainType; private String trainType;
private String modelNo;
public String getStatusName() { public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null; if (this.statusCd == null || this.statusCd.isBlank()) return null;

View File

@@ -1,10 +1,16 @@
package com.kamco.cd.training.model.service; package com.kamco.cd.training.model.service;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService; import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -19,6 +25,7 @@ import org.springframework.transaction.annotation.Transactional;
public class ModelTrainDetailService { public class ModelTrainDetailService {
private final ModelTrainDetailCoreService modelTrainDetailCoreService; private final ModelTrainDetailCoreService modelTrainDetailCoreService;
private final ModelTrainMngCoreService mngCoreService;
/** /**
* 모델 상세정보 요약 * 모델 상세정보 요약
@@ -47,4 +54,38 @@ public class ModelTrainDetailService {
public Basic findByModelByUUID(UUID uuid) { public Basic findByModelByUUID(UUID uuid) {
return modelTrainDetailCoreService.findByModelByUUID(uuid); return modelTrainDetailCoreService.findByModelByUUID(uuid);
} }
public TransferDetailDto getTransferDetail(UUID uuid) {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
// config 정보 조회
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
// 하이파라미터 정보 조회
HyperSummary hyperSummary = modelTrainDetailCoreService.getByModelHyperParamSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();
List<Long> datasetIds = new ArrayList<>();
List<MappingDataset> mappingDatasets =
modelTrainDetailCoreService.getByModelMappingDataset(uuid);
for (MappingDataset mappingDataset : mappingDatasets) {
datasetIds.add(mappingDataset.getDatasetId());
}
datasetReq.setIds(datasetIds);
if (modelInfo.getModelNo().equals("G1")) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
} else {
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq);
}
TransferDetailDto transferDetailDto = new TransferDetailDto();
transferDetailDto.setEtcConfig(configInfo);
transferDetailDto.setModelTrainHyper(hyperSummary);
transferDetailDto.setModelTrainDataset(dataSets);
return transferDetailDto;
}
} }

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType; import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelConfigDto;
@@ -91,15 +92,14 @@ public class ModelTrainMngService {
/** /**
* 모델별 데이터셋 목록 조회 * 모델별 데이터셋 목록 조회
* *
* @param modelType * @param req
* @param selectType
* @return * @return
*/ */
public List<SelectDataSet> getDatasetSelectList(String modelType, String selectType) { public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (modelType.equals("G1")) { if (req.getModelNo().equals("G1")) {
return modelTrainMngCoreService.getDatasetSelectM1List(modelType, selectType); return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else { } else {
return modelTrainMngCoreService.getDatasetSelectM2M3List(modelType, selectType); return modelTrainMngCoreService.getDatasetSelectG2G3List(req);
} }
} }
} }

View File

@@ -3,11 +3,13 @@ package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.BadRequestException; import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.NotFoundException; import com.kamco.cd.training.common.exception.NotFoundException;
import com.kamco.cd.training.common.utils.UserUtil; import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDetailRepository; import com.kamco.cd.training.postgres.repository.model.ModelDetailRepository;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
@@ -18,6 +20,7 @@ import org.springframework.stereotype.Service;
@RequiredArgsConstructor @RequiredArgsConstructor
public class ModelTrainDetailCoreService { public class ModelTrainDetailCoreService {
private final ModelDetailRepository modelDetailRepository; private final ModelDetailRepository modelDetailRepository;
private final ModelConfigRepository modelConfigRepository;
private final UserUtil userUtil; private final UserUtil userUtil;
/** /**
@@ -59,4 +62,14 @@ public class ModelTrainDetailCoreService {
ModelMasterEntity entity = modelDetailRepository.findByModelByUUID(uuid); ModelMasterEntity entity = modelDetailRepository.findByModelByUUID(uuid);
return entity.toDto(); return entity.toDto();
} }
/**
* 모델 학습별 config 정보 조회
*
* @param modelId
* @return
*/
public ModelConfigDto.Basic findModelConfig(Long modelId) {
return modelConfigRepository.findModelConfigByModelId(modelId).orElse(null);
}
} }

View File

@@ -6,6 +6,7 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.exception.BadRequestException; import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil; import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
@@ -225,24 +226,22 @@ public class ModelTrainMngCoreService {
} }
/** /**
* 데이터셋 M1 목록 * 데이터셋 G1 목록
* *
* @param modelType * @param req
* @param selectType
* @return * @return
*/ */
public List<SelectDataSet> getDatasetSelectM1List(String modelType, String selectType) { public List<SelectDataSet> getDatasetSelectG1List(DatasetReq req) {
return datasetRepository.getDatasetSelectM1List(modelType, selectType); return datasetRepository.getDatasetSelectG1List(req);
} }
/** /**
* 데이터셋 M2, M3 목록 * 데이터셋 G2, G3 목록
* *
* @param modelType * @param req
* @param selectType
* @return * @return
*/ */
public List<SelectDataSet> getDatasetSelectM2M3List(String modelType, String selectType) { public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
return datasetRepository.getDatasetSelectM2M3List(modelType, selectType); return datasetRepository.getDatasetSelectG2G3List(req);
} }
} }

View File

@@ -101,6 +101,7 @@ public class ModelMasterEntity {
this.step1State, this.step1State,
this.step2State, this.step2State,
this.statusCd, this.statusCd,
this.trainType); this.trainType,
this.modelNo);
} }
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.repository.dataset; package com.kamco.cd.training.postgres.repository.dataset;
import com.kamco.cd.training.dataset.dto.DatasetDto; import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity; import com.kamco.cd.training.postgres.entity.DatasetEntity;
import java.util.List; import java.util.List;
@@ -13,7 +14,7 @@ public interface DatasetRepositoryCustom {
Optional<DatasetEntity> findByUuid(UUID id); Optional<DatasetEntity> findByUuid(UUID id);
List<SelectDataSet> getDatasetSelectM1List(String modelType, String selectType); List<SelectDataSet> getDatasetSelectG1List(DatasetReq req);
List<SelectDataSet> getDatasetSelectM2M3List(String modelType, String selectType); List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req);
} }

View File

@@ -2,7 +2,9 @@ package com.kamco.cd.training.postgres.repository.dataset;
import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity; import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity;
import com.kamco.cd.training.dataset.dto.DatasetDto; import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity; import com.kamco.cd.training.postgres.entity.DatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetEntity; import com.kamco.cd.training.postgres.entity.QDatasetEntity;
@@ -35,7 +37,7 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
* @return 페이징 처리된 데이터셋 Entity 목록 * @return 페이징 처리된 데이터셋 Entity 목록
*/ */
@Override @Override
public Page<DatasetEntity> findDatasetList(DatasetDto.SearchReq searchReq) { public Page<DatasetEntity> findDatasetList(SearchReq searchReq) {
Pageable pageable = searchReq.toPageable(); Pageable pageable = searchReq.toPageable();
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
@@ -80,12 +82,20 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
} }
@Override @Override
public List<SelectDataSet> getDatasetSelectM1List(String modelType, String selectType) { public List<SelectDataSet> getDatasetSelectG1List(DatasetReq req) {
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
if (StringUtils.isNotBlank(selectType) && !"CURRENT".equals(selectType)) { if (StringUtils.isNotBlank(req.getDataType()) && !"CURRENT".equals(req.getDataType())) {
builder.and(dataset.dataType.eq(selectType)); builder.and(dataset.dataType.eq(req.getDataType()));
}
if (StringUtils.isNotBlank(req.getDataType()) && !"CURRENT".equals(req.getDataType())) {
builder.and(dataset.dataType.eq(req.getDataType()));
}
if (req.getIds() != null) {
builder.and(dataset.id.in(req.getIds()));
} }
return queryFactory return queryFactory
@@ -126,11 +136,11 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
} }
@Override @Override
public List<SelectDataSet> getDatasetSelectM2M3List(String modelType, String selectType) { public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
NumberExpression<Long> selectedCnt; NumberExpression<Long> selectedCnt = null;
NumberExpression<Long> wasteCnt = NumberExpression<Long> wasteCnt =
datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum(); datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
@@ -141,14 +151,22 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.otherwise(0L) .otherwise(0L)
.sum(); .sum();
if (modelType.equals("M2")) { if (StringUtils.isNotBlank(req.getModelNo())) {
if (req.getModelNo().equals(ModelType.G2.getId())) {
selectedCnt = wasteCnt; selectedCnt = wasteCnt;
} else { } else {
selectedCnt = elseCnt; selectedCnt = elseCnt;
} }
}
if (StringUtils.isNotBlank(selectType) && !"CURRENT".equals(selectType)) { if (StringUtils.isNotBlank(req.getDataType())) {
builder.and(dataset.dataType.eq(selectType)); if (!"CURRENT".equals(req.getDataType())) {
builder.and(dataset.dataType.eq(req.getDataType()));
}
}
if (req.getIds() != null) {
builder.and(dataset.id.in(req.getIds()));
} }
return queryFactory return queryFactory

View File

@@ -71,6 +71,7 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
modelHyperParamEntity.hyperVer, modelHyperParamEntity.hyperVer,
modelHyperParamEntity.backbone, modelHyperParamEntity.backbone,
modelHyperParamEntity.inputSize, modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize,
modelHyperParamEntity.batchSize)) modelHyperParamEntity.batchSize))
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where( .where(
@@ -88,6 +89,7 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
Projections.constructor( Projections.constructor(
MappingDataset.class, MappingDataset.class,
modelMasterEntity.id, modelMasterEntity.id,
datasetEntity.id,
datasetEntity.dataType, datasetEntity.dataType,
datasetEntity.compareYyyy, datasetEntity.compareYyyy,
datasetEntity.targetYyyy, datasetEntity.targetYyyy,