전이학습 상세 수정

This commit is contained in:
2026-02-20 18:22:19 +09:00
parent 83859bb9fe
commit 07429dbe8e
13 changed files with 559 additions and 56 deletions

View File

@@ -2,7 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
@@ -73,11 +73,11 @@ public class ModelTrainDetailService {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
// config 정보 조회
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
ModelConfigDto.TransferBasic configInfo = mngCoreService.findModelTransferConfigByModelId(uuid);
// 하이파라미터 정보 조회
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>();
List<SelectTransferDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();
List<Long> datasetIds = new ArrayList<>();
@@ -91,39 +91,40 @@ public class ModelTrainDetailService {
datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals(ModelType.G1.getId())) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
dataSets = mngCoreService.getDatasetTransferSelectG1List(modelInfo.getId());
} else {
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq);
dataSets =
mngCoreService.getDatasetTransferSelectG2G3List(
modelInfo.getId(), modelInfo.getModelNo());
}
DatasetReq beforeDatasetReq = new DatasetReq();
List<Long> beforeDatasetIds = new ArrayList<>();
List<SelectDataSet> beforeDataSets = new ArrayList<>();
Long beforeModelId = modelInfo.getBeforeModelId();
if (beforeModelId != null) {
Basic beforeInfo = modelTrainDetailCoreService.findByModelBeforeId(beforeModelId);
List<MappingDataset> beforeDatasets =
modelTrainDetailCoreService.getByModelMappingDataset(beforeInfo.getUuid());
for (MappingDataset before : beforeDatasets) {
beforeDatasetIds.add(before.getDatasetId());
}
beforeDatasetReq.setIds(beforeDatasetIds);
beforeDatasetReq.setModelNo(modelInfo.getModelNo());
if (beforeInfo.getModelNo().equals(ModelType.G1.getId())) {
beforeDataSets = mngCoreService.getDatasetSelectG1List(beforeDatasetReq);
} else {
beforeDataSets = mngCoreService.getDatasetSelectG2G3List(beforeDatasetReq);
}
}
// DatasetReq beforeDatasetReq = new DatasetReq();
// List<Long> beforeDatasetIds = new ArrayList<>();
// List<SelectDataSet> beforeDataSets = new ArrayList<>();
//
// Long beforeModelId = modelInfo.getBeforeModelId();
// if (beforeModelId != null) {
// Basic beforeInfo = modelTrainDetailCoreService.findByModelBeforeId(beforeModelId);
// List<MappingDataset> beforeDatasets =
// modelTrainDetailCoreService.getByModelMappingDataset(beforeInfo.getUuid());
//
// for (MappingDataset before : beforeDatasets) {
// beforeDatasetIds.add(before.getDatasetId());
// }
// beforeDatasetReq.setIds(beforeDatasetIds);
// beforeDatasetReq.setModelNo(modelInfo.getModelNo());
//
// if (beforeInfo.getModelNo().equals(ModelType.G1.getId())) {
// beforeDataSets = mngCoreService.getDatasetSelectG1List(beforeDatasetReq);
// } else {
// beforeDataSets = mngCoreService.getDatasetSelectG2G3List(beforeDatasetReq);
// }
// }
TransferDetailDto transferDetailDto = new TransferDetailDto();
transferDetailDto.setEtcConfig(configInfo);
transferDetailDto.setModelTrainHyper(hyperSummary);
transferDetailDto.setModelTrainDataset(dataSets);
transferDetailDto.setBeforeTrainDataset(beforeDataSets);
return transferDetailDto;
}

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
@@ -115,7 +116,7 @@ public class ModelTrainMngService {
* @return
*/
public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (req.getModelNo().equals("G1")) {
if (req.getModelNo().equals(ModelType.G1.getId())) {
return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else {
return modelTrainMngCoreService.getDatasetSelectG2G3List(req);