Compare commits
3 Commits
885b72a0c6
...
2f63b9ddcd
| Author | SHA1 | Date | |
|---|---|---|---|
| 2f63b9ddcd | |||
| 92de48b55e | |||
| 224ddae68b |
@@ -93,6 +93,29 @@ public class ModelTrainDetailDto {
|
|||||||
private Integer batchSize;
|
private Integer batchSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Schema(name = "모델학습관리 전이 하이파라미터", description = "모델학습관리 전이 하이파라미터")
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Builder
|
||||||
|
public static class TransferHyperSummary {
|
||||||
|
private UUID uuid;
|
||||||
|
private Long hyperParamId;
|
||||||
|
private String hyperVer;
|
||||||
|
private String backbone;
|
||||||
|
private String inputSize;
|
||||||
|
private String cropSize;
|
||||||
|
private Integer batchSize;
|
||||||
|
private UUID beforeUuid;
|
||||||
|
private Long beforeHyperParamId;
|
||||||
|
private String beforeHyperVer;
|
||||||
|
private String beforeBackbone;
|
||||||
|
private String beforeInputSize;
|
||||||
|
private String beforeCropSize;
|
||||||
|
private Integer beforeBatchSize;
|
||||||
|
}
|
||||||
|
|
||||||
@Schema(name = "선택한 데이터셋 목록", description = "선택한 데이터셋 목록")
|
@Schema(name = "선택한 데이터셋 목록", description = "선택한 데이터셋 목록")
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
@@ -154,7 +177,7 @@ public class ModelTrainDetailDto {
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public static class TransferDetailDto {
|
public static class TransferDetailDto {
|
||||||
private ModelConfigDto.Basic etcConfig;
|
private ModelConfigDto.Basic etcConfig;
|
||||||
private HyperSummary modelTrainHyper;
|
private TransferHyperSummary modelTrainHyper;
|
||||||
private List<SelectDataSet> modelTrainDataset;
|
private List<SelectDataSet> modelTrainDataset;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -137,6 +137,9 @@ public class ModelTrainMngDto {
|
|||||||
@Schema(description = "학습타입 GENERAL(일반), TRANSFER(전이)", example = "GENERAL")
|
@Schema(description = "학습타입 GENERAL(일반), TRANSFER(전이)", example = "GENERAL")
|
||||||
private String trainType;
|
private String trainType;
|
||||||
|
|
||||||
|
@Schema(description = "전이학습일때 선택한 모델 id")
|
||||||
|
private Long beforeModelId;
|
||||||
|
|
||||||
@NotNull
|
@NotNull
|
||||||
@Schema(
|
@Schema(
|
||||||
description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)",
|
description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)",
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
|||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
@@ -55,6 +56,12 @@ public class ModelTrainDetailService {
|
|||||||
return modelTrainDetailCoreService.findByModelByUUID(uuid);
|
return modelTrainDetailCoreService.findByModelByUUID(uuid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 전이학습 모델선택 정보
|
||||||
|
*
|
||||||
|
* @param uuid
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public TransferDetailDto getTransferDetail(UUID uuid) {
|
public TransferDetailDto getTransferDetail(UUID uuid) {
|
||||||
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
|
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
|
||||||
|
|
||||||
@@ -62,7 +69,7 @@ public class ModelTrainDetailService {
|
|||||||
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
|
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
|
||||||
|
|
||||||
// 하이파라미터 정보 조회
|
// 하이파라미터 정보 조회
|
||||||
HyperSummary hyperSummary = modelTrainDetailCoreService.getByModelHyperParamSummary(uuid);
|
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
|
||||||
List<SelectDataSet> dataSets = new ArrayList<>();
|
List<SelectDataSet> dataSets = new ArrayList<>();
|
||||||
|
|
||||||
DatasetReq datasetReq = new DatasetReq();
|
DatasetReq datasetReq = new DatasetReq();
|
||||||
@@ -74,6 +81,7 @@ public class ModelTrainDetailService {
|
|||||||
datasetIds.add(mappingDataset.getDatasetId());
|
datasetIds.add(mappingDataset.getDatasetId());
|
||||||
}
|
}
|
||||||
datasetReq.setIds(datasetIds);
|
datasetReq.setIds(datasetIds);
|
||||||
|
datasetReq.setModelNo(modelInfo.getModelNo());
|
||||||
|
|
||||||
if (modelInfo.getModelNo().equals("G1")) {
|
if (modelInfo.getModelNo().equals("G1")) {
|
||||||
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
|
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package com.kamco.cd.training.model.service;
|
|||||||
|
|
||||||
import com.kamco.cd.training.common.dto.HyperParam;
|
import com.kamco.cd.training.common.dto.HyperParam;
|
||||||
import com.kamco.cd.training.common.enums.HyperParamSelectType;
|
import com.kamco.cd.training.common.enums.HyperParamSelectType;
|
||||||
|
import com.kamco.cd.training.common.enums.TrainType;
|
||||||
|
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||||
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
|
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
|
||||||
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
|
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
|
||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||||
@@ -15,6 +17,7 @@ import java.util.UUID;
|
|||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.data.domain.Page;
|
import org.springframework.data.domain.Page;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
@@ -58,6 +61,13 @@ public class ModelTrainMngService {
|
|||||||
HyperParam hyperParam = req.getHyperParam();
|
HyperParam hyperParam = req.getHyperParam();
|
||||||
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
|
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
|
||||||
|
|
||||||
|
// 전이 학습은 모델 선택 필수
|
||||||
|
if (req.getTrainType().equals(TrainType.TRANSFER.getId())) {
|
||||||
|
if (req.getBeforeModelId() != null) {
|
||||||
|
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 하이파라미터 신규저장
|
// 하이파라미터 신규저장
|
||||||
if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) {
|
if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) {
|
||||||
// 하이퍼파라미터 등록
|
// 하이퍼파라미터 등록
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
|
|||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||||
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
||||||
@@ -54,6 +55,10 @@ public class ModelTrainDetailCoreService {
|
|||||||
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public TransferHyperSummary getTransferHyperSummary(UUID uuid) {
|
||||||
|
return modelDetailRepository.getByModelTransferHyperParamSummary(uuid);
|
||||||
|
}
|
||||||
|
|
||||||
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
|
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
|
||||||
return modelDetailRepository.getByModelMappingDataset(uuid);
|
return modelDetailRepository.getByModelMappingDataset(uuid);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ public class ModelTrainMngCoreService {
|
|||||||
entity.setHyperParamId(hyperParamEntity.getId());
|
entity.setHyperParamId(hyperParamEntity.getId());
|
||||||
entity.setModelNo(addReq.getModelNo());
|
entity.setModelNo(addReq.getModelNo());
|
||||||
entity.setTrainType(addReq.getTrainType()); // 일반, 전이
|
entity.setTrainType(addReq.getTrainType()); // 일반, 전이
|
||||||
|
entity.setBeforeModelId(addReq.getBeforeModelId());
|
||||||
|
|
||||||
if (addReq.getIsStart()) {
|
if (addReq.getIsStart()) {
|
||||||
entity.setModelStep((short) 1);
|
entity.setModelStep((short) 1);
|
||||||
|
|||||||
@@ -88,6 +88,9 @@ public class ModelMasterEntity {
|
|||||||
@Column(name = "train_type")
|
@Column(name = "train_type")
|
||||||
private String trainType;
|
private String trainType;
|
||||||
|
|
||||||
|
@Column(name = "before_model_id")
|
||||||
|
private Long beforeModelId;
|
||||||
|
|
||||||
public ModelTrainMngDto.Basic toDto() {
|
public ModelTrainMngDto.Basic toDto() {
|
||||||
return new ModelTrainMngDto.Basic(
|
return new ModelTrainMngDto.Basic(
|
||||||
this.id,
|
this.id,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
|||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
@@ -16,6 +17,8 @@ public interface ModelDetailRepositoryCustom {
|
|||||||
|
|
||||||
HyperSummary getByModelHyperParamSummary(UUID uuid);
|
HyperSummary getByModelHyperParamSummary(UUID uuid);
|
||||||
|
|
||||||
|
TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid);
|
||||||
|
|
||||||
List<MappingDataset> getByModelMappingDataset(UUID uuid);
|
List<MappingDataset> getByModelMappingDataset(UUID uuid);
|
||||||
|
|
||||||
ModelMasterEntity findByModelByUUID(UUID uuid);
|
ModelMasterEntity findByModelByUUID(UUID uuid);
|
||||||
|
|||||||
@@ -9,7 +9,10 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
|
|||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||||
|
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
|
||||||
|
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
|
||||||
import com.querydsl.core.types.Projections;
|
import com.querydsl.core.types.Projections;
|
||||||
import com.querydsl.jpa.JPAExpressions;
|
import com.querydsl.jpa.JPAExpressions;
|
||||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||||
@@ -82,6 +85,41 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
|||||||
.fetchOne();
|
.fetchOne();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid) {
|
||||||
|
|
||||||
|
QModelMasterEntity subMaster = new QModelMasterEntity("subMaster");
|
||||||
|
QModelHyperParamEntity subHyper = new QModelHyperParamEntity("subHyper");
|
||||||
|
|
||||||
|
return queryFactory
|
||||||
|
.select(
|
||||||
|
Projections.constructor(
|
||||||
|
TransferHyperSummary.class,
|
||||||
|
modelHyperParamEntity.uuid,
|
||||||
|
modelHyperParamEntity.id,
|
||||||
|
modelHyperParamEntity.hyperVer,
|
||||||
|
modelHyperParamEntity.backbone,
|
||||||
|
modelHyperParamEntity.inputSize,
|
||||||
|
modelHyperParamEntity.cropSize,
|
||||||
|
modelHyperParamEntity.batchSize,
|
||||||
|
subHyper.uuid,
|
||||||
|
subHyper.id,
|
||||||
|
subHyper.hyperVer,
|
||||||
|
subHyper.backbone,
|
||||||
|
subHyper.inputSize,
|
||||||
|
subHyper.cropSize,
|
||||||
|
subHyper.batchSize))
|
||||||
|
.from(modelMasterEntity)
|
||||||
|
.innerJoin(modelHyperParamEntity)
|
||||||
|
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||||
|
.leftJoin(subMaster)
|
||||||
|
.on(subMaster.id.eq(modelMasterEntity.beforeModelId))
|
||||||
|
.leftJoin(subHyper)
|
||||||
|
.on(subHyper.id.eq(subMaster.hyperParamId))
|
||||||
|
.where(modelMasterEntity.uuid.eq(uuid))
|
||||||
|
.fetchOne();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
|
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
|
||||||
return queryFactory
|
return queryFactory
|
||||||
|
|||||||
Reference in New Issue
Block a user