From 224ddae68b6960f8e90ae2996caaf28a2c59228b Mon Sep 17 00:00:00 2001 From: teddy Date: Wed, 11 Feb 2026 14:05:15 +0900 Subject: [PATCH] =?UTF-8?q?=EC=A0=84=EC=9D=B4=ED=95=99=EC=8A=B5=20?= =?UTF-8?q?=EC=83=81=EC=84=B8=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../model/dto/ModelTrainDetailDto.java | 25 +++++++++++- .../training/model/dto/ModelTrainMngDto.java | 3 ++ .../service/ModelTrainDetailService.java | 9 ++++- .../model/service/ModelTrainMngService.java | 10 +++++ .../core/ModelTrainDetailCoreService.java | 5 +++ .../core/ModelTrainMngCoreService.java | 1 + .../postgres/entity/ModelMasterEntity.java | 3 ++ .../model/ModelDetailRepositoryCustom.java | 3 ++ .../model/ModelDetailRepositoryImpl.java | 38 +++++++++++++++++++ 9 files changed, 95 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java index a83e46d..8a854d8 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDetailDto.java @@ -93,6 +93,29 @@ public class ModelTrainDetailDto { private Integer batchSize; } + @Schema(name = "모델학습관리 전이 하이파라미터", description = "모델학습관리 전이 하이파라미터") + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + @Builder + public static class TransferHyperSummary { + private UUID uuid; + private Long hyperParamId; + private String hyperVer; + private String backbone; + private String inputSize; + private String cropSize; + private Integer batchSize; + private UUID beforeUuid; + private Long beforeHyperParamId; + private String beforeHyperVer; + private String beforeBackbone; + private String beforeInputSize; + private String beforeCropSize; + private Integer beforeBatchSize; + } + @Schema(name = "선택한 데이터셋 목록", description = "선택한 데이터셋 목록") @Getter @Setter @@ -154,7 +177,7 @@ public class ModelTrainDetailDto { @AllArgsConstructor public static class TransferDetailDto { private ModelConfigDto.Basic etcConfig; - private HyperSummary modelTrainHyper; + private TransferHyperSummary modelTrainHyper; private List modelTrainDataset; } } diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java index 87a789e..36b20f0 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java @@ -137,6 +137,9 @@ public class ModelTrainMngDto { @Schema(description = "학습타입 GENERAL(일반), TRANSFER(전이)", example = "GENERAL") private String trainType; + @Schema(description = "전이학습일때 선택한 모델 id") + private Long beforeModelId; + @NotNull @Schema( description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)", diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java index 7c66ee5..5d57b60 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java @@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; @@ -55,6 +56,12 @@ public class ModelTrainDetailService { return modelTrainDetailCoreService.findByModelByUUID(uuid); } + /** + * 전이학습 모델선택 정보 + * + * @param uuid + * @return + */ public TransferDetailDto getTransferDetail(UUID uuid) { Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid); @@ -62,7 +69,7 @@ public class ModelTrainDetailService { ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid); // 하이파라미터 정보 조회 - HyperSummary hyperSummary = modelTrainDetailCoreService.getByModelHyperParamSummary(uuid); + TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid); List dataSets = new ArrayList<>(); DatasetReq datasetReq = new DatasetReq(); diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java index 303b56b..5e22404 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java @@ -2,6 +2,8 @@ package com.kamco.cd.training.model.service; import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.enums.HyperParamSelectType; +import com.kamco.cd.training.common.enums.TrainType; +import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.hyperparam.dto.HyperParamDto; @@ -15,6 +17,7 @@ import java.util.UUID; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.data.domain.Page; +import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -58,6 +61,13 @@ public class ModelTrainMngService { HyperParam hyperParam = req.getHyperParam(); HyperParamDto.Basic hyper = new HyperParamDto.Basic(); + // 전이 학습은 모델 선택 필수 + if (req.getTrainType().equals(TrainType.TRANSFER.getId())) { + if (req.getBeforeModelId() != null) { + throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요."); + } + } + // 하이파라미터 신규저장 if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) { // 하이퍼파라미터 등록 diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java index 38c57bc..bb38f57 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainDetailCoreService.java @@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository; @@ -54,6 +55,10 @@ public class ModelTrainDetailCoreService { return modelDetailRepository.getByModelHyperParamSummary(uuid); } + public TransferHyperSummary getTransferHyperSummary(UUID uuid) { + return modelDetailRepository.getByModelTransferHyperParamSummary(uuid); + } + public List getByModelMappingDataset(UUID uuid) { return modelDetailRepository.getByModelMappingDataset(uuid); } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index 67de38b..4ccb9e4 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -98,6 +98,7 @@ public class ModelTrainMngCoreService { entity.setHyperParamId(hyperParamEntity.getId()); entity.setModelNo(addReq.getModelNo()); entity.setTrainType(addReq.getTrainType()); // 일반, 전이 + entity.setBeforeModelId(addReq.getBeforeModelId()); if (addReq.getIsStart()) { entity.setModelStep((short) 1); diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java index bb8a294..a06380e 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java @@ -88,6 +88,9 @@ public class ModelMasterEntity { @Column(name = "train_type") private String trainType; + @Column(name = "before_model_id") + private Long beforeModelId; + public ModelTrainMngDto.Basic toDto() { return new ModelTrainMngDto.Basic( this.id, diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryCustom.java index 775b23b..1af36a4 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryCustom.java @@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.model; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import java.util.List; import java.util.Optional; @@ -16,6 +17,8 @@ public interface ModelDetailRepositoryCustom { HyperSummary getByModelHyperParamSummary(UUID uuid); + TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid); + List getByModelMappingDataset(UUID uuid); ModelMasterEntity findByModelByUUID(UUID uuid); diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java index 02548a6..66b350e 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDetailRepositoryImpl.java @@ -9,7 +9,10 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; +import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; +import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity; +import com.kamco.cd.training.postgres.entity.QModelMasterEntity; import com.querydsl.core.types.Projections; import com.querydsl.jpa.JPAExpressions; import com.querydsl.jpa.impl.JPAQueryFactory; @@ -82,6 +85,41 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom { .fetchOne(); } + @Override + public TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid) { + + QModelMasterEntity subMaster = new QModelMasterEntity("subMaster"); + QModelHyperParamEntity subHyper = new QModelHyperParamEntity("subHyper"); + + return queryFactory + .select( + Projections.constructor( + TransferHyperSummary.class, + modelHyperParamEntity.uuid, + modelHyperParamEntity.id, + modelHyperParamEntity.hyperVer, + modelHyperParamEntity.backbone, + modelHyperParamEntity.inputSize, + modelHyperParamEntity.cropSize, + modelHyperParamEntity.batchSize, + subHyper.uuid, + subHyper.id, + subHyper.hyperVer, + subHyper.backbone, + subHyper.inputSize, + subHyper.cropSize, + subHyper.batchSize)) + .from(modelMasterEntity) + .innerJoin(modelHyperParamEntity) + .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId)) + .leftJoin(subMaster) + .on(subMaster.id.eq(modelMasterEntity.beforeModelId)) + .leftJoin(subHyper) + .on(subHyper.id.eq(subMaster.hyperParamId)) + .where(modelMasterEntity.uuid.eq(uuid)) + .fetchOne(); + } + @Override public List getByModelMappingDataset(UUID uuid) { return queryFactory