Merge pull request 'feat/training_260202' (#30) from feat/training_260202 into develop
Reviewed-on: #30
This commit was merged in pull request #30.
This commit is contained in:
@@ -93,6 +93,29 @@ public class ModelTrainDetailDto {
|
||||
private Integer batchSize;
|
||||
}
|
||||
|
||||
@Schema(name = "모델학습관리 전이 하이파라미터", description = "모델학습관리 전이 하이파라미터")
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public static class TransferHyperSummary {
|
||||
private UUID uuid;
|
||||
private Long hyperParamId;
|
||||
private String hyperVer;
|
||||
private String backbone;
|
||||
private String inputSize;
|
||||
private String cropSize;
|
||||
private Integer batchSize;
|
||||
private UUID beforeUuid;
|
||||
private Long beforeHyperParamId;
|
||||
private String beforeHyperVer;
|
||||
private String beforeBackbone;
|
||||
private String beforeInputSize;
|
||||
private String beforeCropSize;
|
||||
private Integer beforeBatchSize;
|
||||
}
|
||||
|
||||
@Schema(name = "선택한 데이터셋 목록", description = "선택한 데이터셋 목록")
|
||||
@Getter
|
||||
@Setter
|
||||
@@ -154,7 +177,7 @@ public class ModelTrainDetailDto {
|
||||
@AllArgsConstructor
|
||||
public static class TransferDetailDto {
|
||||
private ModelConfigDto.Basic etcConfig;
|
||||
private HyperSummary modelTrainHyper;
|
||||
private TransferHyperSummary modelTrainHyper;
|
||||
private List<SelectDataSet> modelTrainDataset;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,6 +137,9 @@ public class ModelTrainMngDto {
|
||||
@Schema(description = "학습타입 GENERAL(일반), TRANSFER(전이)", example = "GENERAL")
|
||||
private String trainType;
|
||||
|
||||
@Schema(description = "전이학습일때 선택한 모델 id")
|
||||
private Long beforeModelId;
|
||||
|
||||
@NotNull
|
||||
@Schema(
|
||||
description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)",
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
@@ -55,6 +56,12 @@ public class ModelTrainDetailService {
|
||||
return modelTrainDetailCoreService.findByModelByUUID(uuid);
|
||||
}
|
||||
|
||||
/**
|
||||
* 전이학습 모델선택 정보
|
||||
*
|
||||
* @param uuid
|
||||
* @return
|
||||
*/
|
||||
public TransferDetailDto getTransferDetail(UUID uuid) {
|
||||
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
|
||||
|
||||
@@ -62,7 +69,7 @@ public class ModelTrainDetailService {
|
||||
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
|
||||
|
||||
// 하이파라미터 정보 조회
|
||||
HyperSummary hyperSummary = modelTrainDetailCoreService.getByModelHyperParamSummary(uuid);
|
||||
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
|
||||
List<SelectDataSet> dataSets = new ArrayList<>();
|
||||
|
||||
DatasetReq datasetReq = new DatasetReq();
|
||||
@@ -74,6 +81,7 @@ public class ModelTrainDetailService {
|
||||
datasetIds.add(mappingDataset.getDatasetId());
|
||||
}
|
||||
datasetReq.setIds(datasetIds);
|
||||
datasetReq.setModelNo(modelInfo.getModelNo());
|
||||
|
||||
if (modelInfo.getModelNo().equals("G1")) {
|
||||
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
|
||||
|
||||
@@ -2,6 +2,8 @@ package com.kamco.cd.training.model.service;
|
||||
|
||||
import com.kamco.cd.training.common.dto.HyperParam;
|
||||
import com.kamco.cd.training.common.enums.HyperParamSelectType;
|
||||
import com.kamco.cd.training.common.enums.TrainType;
|
||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
|
||||
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
@@ -15,6 +17,7 @@ import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@@ -58,6 +61,13 @@ public class ModelTrainMngService {
|
||||
HyperParam hyperParam = req.getHyperParam();
|
||||
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
|
||||
|
||||
// 전이 학습은 모델 선택 필수
|
||||
if (req.getTrainType().equals(TrainType.TRANSFER.getId())) {
|
||||
if (req.getBeforeModelId() != null) {
|
||||
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요.");
|
||||
}
|
||||
}
|
||||
|
||||
// 하이파라미터 신규저장
|
||||
if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) {
|
||||
// 하이퍼파라미터 등록
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
||||
@@ -54,6 +55,10 @@ public class ModelTrainDetailCoreService {
|
||||
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
||||
}
|
||||
|
||||
public TransferHyperSummary getTransferHyperSummary(UUID uuid) {
|
||||
return modelDetailRepository.getByModelTransferHyperParamSummary(uuid);
|
||||
}
|
||||
|
||||
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
|
||||
return modelDetailRepository.getByModelMappingDataset(uuid);
|
||||
}
|
||||
|
||||
@@ -98,6 +98,7 @@ public class ModelTrainMngCoreService {
|
||||
entity.setHyperParamId(hyperParamEntity.getId());
|
||||
entity.setModelNo(addReq.getModelNo());
|
||||
entity.setTrainType(addReq.getTrainType()); // 일반, 전이
|
||||
entity.setBeforeModelId(addReq.getBeforeModelId());
|
||||
|
||||
if (addReq.getIsStart()) {
|
||||
entity.setModelStep((short) 1);
|
||||
|
||||
@@ -88,6 +88,9 @@ public class ModelMasterEntity {
|
||||
@Column(name = "train_type")
|
||||
private String trainType;
|
||||
|
||||
@Column(name = "before_model_id")
|
||||
private Long beforeModelId;
|
||||
|
||||
public ModelTrainMngDto.Basic toDto() {
|
||||
return new ModelTrainMngDto.Basic(
|
||||
this.id,
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -16,6 +17,8 @@ public interface ModelDetailRepositoryCustom {
|
||||
|
||||
HyperSummary getByModelHyperParamSummary(UUID uuid);
|
||||
|
||||
TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid);
|
||||
|
||||
List<MappingDataset> getByModelMappingDataset(UUID uuid);
|
||||
|
||||
ModelMasterEntity findByModelByUUID(UUID uuid);
|
||||
|
||||
@@ -9,7 +9,10 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.jpa.JPAExpressions;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
@@ -82,6 +85,41 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid) {
|
||||
|
||||
QModelMasterEntity subMaster = new QModelMasterEntity("subMaster");
|
||||
QModelHyperParamEntity subHyper = new QModelHyperParamEntity("subHyper");
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
TransferHyperSummary.class,
|
||||
modelHyperParamEntity.uuid,
|
||||
modelHyperParamEntity.id,
|
||||
modelHyperParamEntity.hyperVer,
|
||||
modelHyperParamEntity.backbone,
|
||||
modelHyperParamEntity.inputSize,
|
||||
modelHyperParamEntity.cropSize,
|
||||
modelHyperParamEntity.batchSize,
|
||||
subHyper.uuid,
|
||||
subHyper.id,
|
||||
subHyper.hyperVer,
|
||||
subHyper.backbone,
|
||||
subHyper.inputSize,
|
||||
subHyper.cropSize,
|
||||
subHyper.batchSize))
|
||||
.from(modelMasterEntity)
|
||||
.innerJoin(modelHyperParamEntity)
|
||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||
.leftJoin(subMaster)
|
||||
.on(subMaster.id.eq(modelMasterEntity.beforeModelId))
|
||||
.leftJoin(subHyper)
|
||||
.on(subHyper.id.eq(subMaster.hyperParamId))
|
||||
.where(modelMasterEntity.uuid.eq(uuid))
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
|
||||
return queryFactory
|
||||
|
||||
Reference in New Issue
Block a user