전이학습 추가

This commit is contained in:
2026-02-05 18:23:07 +09:00
parent db6844f0e7
commit 0a7f01a2f5
13 changed files with 243 additions and 47 deletions

View File

@@ -1,10 +1,16 @@
package com.kamco.cd.training.model.service;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@@ -19,6 +25,7 @@ import org.springframework.transaction.annotation.Transactional;
public class ModelTrainDetailService {
private final ModelTrainDetailCoreService modelTrainDetailCoreService;
private final ModelTrainMngCoreService mngCoreService;
/**
* 모델 상세정보 요약
@@ -47,4 +54,38 @@ public class ModelTrainDetailService {
public Basic findByModelByUUID(UUID uuid) {
return modelTrainDetailCoreService.findByModelByUUID(uuid);
}
public TransferDetailDto getTransferDetail(UUID uuid) {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
// config 정보 조회
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
// 하이파라미터 정보 조회
HyperSummary hyperSummary = modelTrainDetailCoreService.getByModelHyperParamSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();
List<Long> datasetIds = new ArrayList<>();
List<MappingDataset> mappingDatasets =
modelTrainDetailCoreService.getByModelMappingDataset(uuid);
for (MappingDataset mappingDataset : mappingDatasets) {
datasetIds.add(mappingDataset.getDatasetId());
}
datasetReq.setIds(datasetIds);
if (modelInfo.getModelNo().equals("G1")) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
} else {
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq);
}
TransferDetailDto transferDetailDto = new TransferDetailDto();
transferDetailDto.setEtcConfig(configInfo);
transferDetailDto.setModelTrainHyper(hyperSummary);
transferDetailDto.setModelTrainDataset(dataSets);
return transferDetailDto;
}
}

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.model.dto.ModelConfigDto;
@@ -91,15 +92,14 @@ public class ModelTrainMngService {
/**
* 모델별 데이터셋 목록 조회
*
* @param modelType
* @param selectType
* @param req
* @return
*/
public List<SelectDataSet> getDatasetSelectList(String modelType, String selectType) {
if (modelType.equals("G1")) {
return modelTrainMngCoreService.getDatasetSelectM1List(modelType, selectType);
public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (req.getModelNo().equals("G1")) {
return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else {
return modelTrainMngCoreService.getDatasetSelectM2M3List(modelType, selectType);
return modelTrainMngCoreService.getDatasetSelectG2G3List(req);
}
}
}