Merge pull request 'feat/training_260202' (#30) from feat/training_260202 into develop

Reviewed-on: #30
This commit was merged in pull request #30.
This commit is contained in:
2026-02-11 14:08:58 +09:00
9 changed files with 96 additions and 2 deletions

View File

@@ -93,6 +93,29 @@ public class ModelTrainDetailDto {
private Integer batchSize;
}
@Schema(name = "모델학습관리 전이 하이파라미터", description = "모델학습관리 전이 하이파라미터")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class TransferHyperSummary {
private UUID uuid;
private Long hyperParamId;
private String hyperVer;
private String backbone;
private String inputSize;
private String cropSize;
private Integer batchSize;
private UUID beforeUuid;
private Long beforeHyperParamId;
private String beforeHyperVer;
private String beforeBackbone;
private String beforeInputSize;
private String beforeCropSize;
private Integer beforeBatchSize;
}
@Schema(name = "선택한 데이터셋 목록", description = "선택한 데이터셋 목록")
@Getter
@Setter
@@ -154,7 +177,7 @@ public class ModelTrainDetailDto {
@AllArgsConstructor
public static class TransferDetailDto {
private ModelConfigDto.Basic etcConfig;
private HyperSummary modelTrainHyper;
private TransferHyperSummary modelTrainHyper;
private List<SelectDataSet> modelTrainDataset;
}
}

View File

@@ -137,6 +137,9 @@ public class ModelTrainMngDto {
@Schema(description = "학습타입 GENERAL(일반), TRANSFER(전이)", example = "GENERAL")
private String trainType;
@Schema(description = "전이학습일때 선택한 모델 id")
private Long beforeModelId;
@NotNull
@Schema(
description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)",

View File

@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
@@ -55,6 +56,12 @@ public class ModelTrainDetailService {
return modelTrainDetailCoreService.findByModelByUUID(uuid);
}
/**
* 전이학습 모델선택 정보
*
* @param uuid
* @return
*/
public TransferDetailDto getTransferDetail(UUID uuid) {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
@@ -62,7 +69,7 @@ public class ModelTrainDetailService {
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
// 하이파라미터 정보 조회
HyperSummary hyperSummary = modelTrainDetailCoreService.getByModelHyperParamSummary(uuid);
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();
@@ -74,6 +81,7 @@ public class ModelTrainDetailService {
datasetIds.add(mappingDataset.getDatasetId());
}
datasetReq.setIds(datasetIds);
datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals("G1")) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);

View File

@@ -2,6 +2,8 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
@@ -15,6 +17,7 @@ import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -58,6 +61,13 @@ public class ModelTrainMngService {
HyperParam hyperParam = req.getHyperParam();
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
// 전이 학습은 모델 선택 필수
if (req.getTrainType().equals(TrainType.TRANSFER.getId())) {
if (req.getBeforeModelId() != null) {
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요.");
}
}
// 하이파라미터 신규저장
if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) {
// 하이퍼파라미터 등록

View File

@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
@@ -54,6 +55,10 @@ public class ModelTrainDetailCoreService {
return modelDetailRepository.getByModelHyperParamSummary(uuid);
}
public TransferHyperSummary getTransferHyperSummary(UUID uuid) {
return modelDetailRepository.getByModelTransferHyperParamSummary(uuid);
}
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
return modelDetailRepository.getByModelMappingDataset(uuid);
}

View File

@@ -98,6 +98,7 @@ public class ModelTrainMngCoreService {
entity.setHyperParamId(hyperParamEntity.getId());
entity.setModelNo(addReq.getModelNo());
entity.setTrainType(addReq.getTrainType()); // 일반, 전이
entity.setBeforeModelId(addReq.getBeforeModelId());
if (addReq.getIsStart()) {
entity.setModelStep((short) 1);

View File

@@ -88,6 +88,9 @@ public class ModelMasterEntity {
@Column(name = "train_type")
private String trainType;
@Column(name = "before_model_id")
private Long beforeModelId;
public ModelTrainMngDto.Basic toDto() {
return new ModelTrainMngDto.Basic(
this.id,

View File

@@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import java.util.List;
import java.util.Optional;
@@ -16,6 +17,8 @@ public interface ModelDetailRepositoryCustom {
HyperSummary getByModelHyperParamSummary(UUID uuid);
TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid);
List<MappingDataset> getByModelMappingDataset(UUID uuid);
ModelMasterEntity findByModelByUUID(UUID uuid);

View File

@@ -9,7 +9,10 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.JPAExpressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -82,6 +85,41 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
.fetchOne();
}
@Override
public TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid) {
QModelMasterEntity subMaster = new QModelMasterEntity("subMaster");
QModelHyperParamEntity subHyper = new QModelHyperParamEntity("subHyper");
return queryFactory
.select(
Projections.constructor(
TransferHyperSummary.class,
modelHyperParamEntity.uuid,
modelHyperParamEntity.id,
modelHyperParamEntity.hyperVer,
modelHyperParamEntity.backbone,
modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize,
modelHyperParamEntity.batchSize,
subHyper.uuid,
subHyper.id,
subHyper.hyperVer,
subHyper.backbone,
subHyper.inputSize,
subHyper.cropSize,
subHyper.batchSize))
.from(modelMasterEntity)
.innerJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
.leftJoin(subMaster)
.on(subMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(subHyper)
.on(subHyper.id.eq(subMaster.hyperParamId))
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
}
@Override
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
return queryFactory