전이학습 상세 수정

This commit is contained in:
2026-02-11 14:05:15 +09:00
parent 9ac00d37c5
commit 224ddae68b
9 changed files with 95 additions and 2 deletions

View File

@@ -93,6 +93,29 @@ public class ModelTrainDetailDto {
private Integer batchSize;
}
@Schema(name = "모델학습관리 전이 하이파라미터", description = "모델학습관리 전이 하이파라미터")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class TransferHyperSummary {
private UUID uuid;
private Long hyperParamId;
private String hyperVer;
private String backbone;
private String inputSize;
private String cropSize;
private Integer batchSize;
private UUID beforeUuid;
private Long beforeHyperParamId;
private String beforeHyperVer;
private String beforeBackbone;
private String beforeInputSize;
private String beforeCropSize;
private Integer beforeBatchSize;
}
@Schema(name = "선택한 데이터셋 목록", description = "선택한 데이터셋 목록")
@Getter
@Setter
@@ -154,7 +177,7 @@ public class ModelTrainDetailDto {
@AllArgsConstructor
public static class TransferDetailDto {
private ModelConfigDto.Basic etcConfig;
private HyperSummary modelTrainHyper;
private TransferHyperSummary modelTrainHyper;
private List<SelectDataSet> modelTrainDataset;
}
}

View File

@@ -137,6 +137,9 @@ public class ModelTrainMngDto {
@Schema(description = "학습타입 GENERAL(일반), TRANSFER(전이)", example = "GENERAL")
private String trainType;
@Schema(description = "전이학습일때 선택한 모델 id")
private Long beforeModelId;
@NotNull
@Schema(
description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)",

View File

@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
@@ -55,6 +56,12 @@ public class ModelTrainDetailService {
return modelTrainDetailCoreService.findByModelByUUID(uuid);
}
/**
* 전이학습 모델선택 정보
*
* @param uuid
* @return
*/
public TransferDetailDto getTransferDetail(UUID uuid) {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
@@ -62,7 +69,7 @@ public class ModelTrainDetailService {
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
// 하이파라미터 정보 조회
HyperSummary hyperSummary = modelTrainDetailCoreService.getByModelHyperParamSummary(uuid);
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();

View File

@@ -2,6 +2,8 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
@@ -15,6 +17,7 @@ import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -58,6 +61,13 @@ public class ModelTrainMngService {
HyperParam hyperParam = req.getHyperParam();
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
// 전이 학습은 모델 선택 필수
if (req.getTrainType().equals(TrainType.TRANSFER.getId())) {
if (req.getBeforeModelId() != null) {
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요.");
}
}
// 하이파라미터 신규저장
if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) {
// 하이퍼파라미터 등록