전이학습 추가

This commit is contained in:
2026-02-05 18:23:07 +09:00
parent db6844f0e7
commit 0a7f01a2f5
13 changed files with 243 additions and 47 deletions

View File

@@ -3,11 +3,13 @@ package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.NotFoundException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDetailRepository;
import java.util.List;
import java.util.UUID;
@@ -18,6 +20,7 @@ import org.springframework.stereotype.Service;
@RequiredArgsConstructor
public class ModelTrainDetailCoreService {
private final ModelDetailRepository modelDetailRepository;
private final ModelConfigRepository modelConfigRepository;
private final UserUtil userUtil;
/**
@@ -59,4 +62,14 @@ public class ModelTrainDetailCoreService {
ModelMasterEntity entity = modelDetailRepository.findByModelByUUID(uuid);
return entity.toDto();
}
/**
* 모델 학습별 config 정보 조회
*
* @param modelId
* @return
*/
public ModelConfigDto.Basic findModelConfig(Long modelId) {
return modelConfigRepository.findModelConfigByModelId(modelId).orElse(null);
}
}

View File

@@ -6,6 +6,7 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
@@ -225,24 +226,22 @@ public class ModelTrainMngCoreService {
}
/**
* 데이터셋 M1 목록
* 데이터셋 G1 목록
*
* @param modelType
* @param selectType
* @param req
* @return
*/
public List<SelectDataSet> getDatasetSelectM1List(String modelType, String selectType) {
return datasetRepository.getDatasetSelectM1List(modelType, selectType);
public List<SelectDataSet> getDatasetSelectG1List(DatasetReq req) {
return datasetRepository.getDatasetSelectG1List(req);
}
/**
* 데이터셋 M2, M3 목록
* 데이터셋 G2, G3 목록
*
* @param modelType
* @param selectType
* @param req
* @return
*/
public List<SelectDataSet> getDatasetSelectM2M3List(String modelType, String selectType) {
return datasetRepository.getDatasetSelectM2M3List(modelType, selectType);
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
return datasetRepository.getDatasetSelectG2G3List(req);
}
}

View File

@@ -101,6 +101,7 @@ public class ModelMasterEntity {
this.step1State,
this.step2State,
this.statusCd,
this.trainType);
this.trainType,
this.modelNo);
}
}

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.repository.dataset;
import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity;
import java.util.List;
@@ -13,7 +14,7 @@ public interface DatasetRepositoryCustom {
Optional<DatasetEntity> findByUuid(UUID id);
List<SelectDataSet> getDatasetSelectM1List(String modelType, String selectType);
List<SelectDataSet> getDatasetSelectG1List(DatasetReq req);
List<SelectDataSet> getDatasetSelectM2M3List(String modelType, String selectType);
List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req);
}

View File

@@ -2,7 +2,9 @@ package com.kamco.cd.training.postgres.repository.dataset;
import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity;
import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetEntity;
@@ -35,7 +37,7 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
* @return 페이징 처리된 데이터셋 Entity 목록
*/
@Override
public Page<DatasetEntity> findDatasetList(DatasetDto.SearchReq searchReq) {
public Page<DatasetEntity> findDatasetList(SearchReq searchReq) {
Pageable pageable = searchReq.toPageable();
BooleanBuilder builder = new BooleanBuilder();
@@ -80,12 +82,20 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
}
@Override
public List<SelectDataSet> getDatasetSelectM1List(String modelType, String selectType) {
public List<SelectDataSet> getDatasetSelectG1List(DatasetReq req) {
BooleanBuilder builder = new BooleanBuilder();
if (StringUtils.isNotBlank(selectType) && !"CURRENT".equals(selectType)) {
builder.and(dataset.dataType.eq(selectType));
if (StringUtils.isNotBlank(req.getDataType()) && !"CURRENT".equals(req.getDataType())) {
builder.and(dataset.dataType.eq(req.getDataType()));
}
if (StringUtils.isNotBlank(req.getDataType()) && !"CURRENT".equals(req.getDataType())) {
builder.and(dataset.dataType.eq(req.getDataType()));
}
if (req.getIds() != null) {
builder.and(dataset.id.in(req.getIds()));
}
return queryFactory
@@ -126,11 +136,11 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
}
@Override
public List<SelectDataSet> getDatasetSelectM2M3List(String modelType, String selectType) {
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
BooleanBuilder builder = new BooleanBuilder();
NumberExpression<Long> selectedCnt;
NumberExpression<Long> selectedCnt = null;
NumberExpression<Long> wasteCnt =
datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
@@ -141,14 +151,22 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.otherwise(0L)
.sum();
if (modelType.equals("M2")) {
selectedCnt = wasteCnt;
} else {
selectedCnt = elseCnt;
if (StringUtils.isNotBlank(req.getModelNo())) {
if (req.getModelNo().equals(ModelType.G2.getId())) {
selectedCnt = wasteCnt;
} else {
selectedCnt = elseCnt;
}
}
if (StringUtils.isNotBlank(selectType) && !"CURRENT".equals(selectType)) {
builder.and(dataset.dataType.eq(selectType));
if (StringUtils.isNotBlank(req.getDataType())) {
if (!"CURRENT".equals(req.getDataType())) {
builder.and(dataset.dataType.eq(req.getDataType()));
}
}
if (req.getIds() != null) {
builder.and(dataset.id.in(req.getIds()));
}
return queryFactory

View File

@@ -71,6 +71,7 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
modelHyperParamEntity.hyperVer,
modelHyperParamEntity.backbone,
modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize,
modelHyperParamEntity.batchSize))
.from(modelHyperParamEntity)
.where(
@@ -88,6 +89,7 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
Projections.constructor(
MappingDataset.class,
modelMasterEntity.id,
datasetEntity.id,
datasetEntity.dataType,
datasetEntity.compareYyyy,
datasetEntity.targetYyyy,