데이터셋 등록 추가
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.common.dto.HyperParam;
|
||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||
import com.kamco.cd.training.common.utils.UserUtil;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.Basic;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
|
||||
import java.time.ZonedDateTime;
|
||||
@@ -21,10 +23,10 @@ public class HyperParamCoreService {
|
||||
/**
|
||||
* 하이퍼파라미터 등록
|
||||
*
|
||||
* @param createReq 등록 요청
|
||||
* @param createReq ModelTrainMngDto.HyperParamDto
|
||||
* @return 등록된 버전명
|
||||
*/
|
||||
public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) {
|
||||
public Basic createHyperParam(HyperParam createReq) {
|
||||
String firstVersion = getFirstHyperParamVersion();
|
||||
|
||||
ModelHyperParamEntity entity = new ModelHyperParamEntity();
|
||||
@@ -36,7 +38,10 @@ public class HyperParamCoreService {
|
||||
entity.setCreatedUid(userUtil.getId());
|
||||
|
||||
ModelHyperParamEntity resultEntity = hyperParamRepository.save(entity);
|
||||
return resultEntity.getHyperVer();
|
||||
Basic basic = new Basic();
|
||||
basic.setUuid(resultEntity.getUuid());
|
||||
basic.setHyperVer(resultEntity.getHyperVer());
|
||||
return basic;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -46,7 +51,7 @@ public class HyperParamCoreService {
|
||||
* @param createReq 등록 요청
|
||||
* @return ver
|
||||
*/
|
||||
public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) {
|
||||
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
@@ -61,47 +66,46 @@ public class HyperParamCoreService {
|
||||
return entity.getHyperVer();
|
||||
}
|
||||
|
||||
private void applyHyperParam(
|
||||
ModelHyperParamEntity entity, HyperParamDto.HyperParamCreateReq createReq) {
|
||||
private void applyHyperParam(ModelHyperParamEntity entity, HyperParam src) {
|
||||
// Important
|
||||
entity.setBackbone(createReq.getBackbone());
|
||||
entity.setInputSize(createReq.getInputSize());
|
||||
entity.setCropSize(createReq.getCropSize());
|
||||
entity.setBatchSize(createReq.getBatchSize());
|
||||
entity.setBackbone(src.getBackbone());
|
||||
entity.setInputSize(src.getInputSize());
|
||||
entity.setCropSize(src.getCropSize());
|
||||
entity.setBatchSize(src.getBatchSize());
|
||||
|
||||
// Data
|
||||
entity.setTrainNumWorkers(createReq.getTrainNumWorkers());
|
||||
entity.setValNumWorkers(createReq.getValNumWorkers());
|
||||
entity.setTestNumWorkers(createReq.getTestNumWorkers());
|
||||
entity.setTrainShuffle(createReq.getTrainShuffle());
|
||||
entity.setTrainPersistent(createReq.getTrainPersistent());
|
||||
entity.setValPersistent(createReq.getValPersistent());
|
||||
entity.setTrainNumWorkers(src.getTrainNumWorkers());
|
||||
entity.setValNumWorkers(src.getValNumWorkers());
|
||||
entity.setTestNumWorkers(src.getTestNumWorkers());
|
||||
entity.setTrainShuffle(src.getTrainShuffle());
|
||||
entity.setTrainPersistent(src.getTrainPersistent());
|
||||
entity.setValPersistent(src.getValPersistent());
|
||||
|
||||
// Model Architecture
|
||||
entity.setDropPathRate(createReq.getDropPathRate());
|
||||
entity.setFrozenStages(createReq.getFrozenStages());
|
||||
entity.setNeckPolicy(createReq.getNeckPolicy());
|
||||
entity.setClassWeight(createReq.getClassWeight());
|
||||
entity.setDecoderChannels(createReq.getDecoderChannels());
|
||||
entity.setDropPathRate(src.getDropPathRate());
|
||||
entity.setFrozenStages(src.getFrozenStages());
|
||||
entity.setNeckPolicy(src.getNeckPolicy());
|
||||
entity.setClassWeight(src.getClassWeight());
|
||||
entity.setDecoderChannels(src.getDecoderChannels());
|
||||
|
||||
// Loss & Optimization
|
||||
entity.setLearningRate(createReq.getLearningRate());
|
||||
entity.setWeightDecay(createReq.getWeightDecay());
|
||||
entity.setLayerDecayRate(createReq.getLayerDecayRate());
|
||||
entity.setDdpFindUnusedParams(createReq.getDdpFindUnusedParams());
|
||||
entity.setIgnoreIndex(createReq.getIgnoreIndex());
|
||||
entity.setNumLayers(createReq.getNumLayers());
|
||||
entity.setLearningRate(src.getLearningRate());
|
||||
entity.setWeightDecay(src.getWeightDecay());
|
||||
entity.setLayerDecayRate(src.getLayerDecayRate());
|
||||
entity.setDdpFindUnusedParams(src.getDdpFindUnusedParams());
|
||||
entity.setIgnoreIndex(src.getIgnoreIndex());
|
||||
entity.setNumLayers(src.getNumLayers());
|
||||
|
||||
// Evaluation
|
||||
entity.setMetrics(createReq.getMetrics());
|
||||
entity.setSaveBest(createReq.getSaveBest());
|
||||
entity.setSaveBestRule(createReq.getSaveBestRule());
|
||||
entity.setValInterval(createReq.getValInterval());
|
||||
entity.setLogInterval(createReq.getLogInterval());
|
||||
entity.setVisInterval(createReq.getVisInterval());
|
||||
entity.setMetrics(src.getMetrics());
|
||||
entity.setSaveBest(src.getSaveBest());
|
||||
entity.setSaveBestRule(src.getSaveBestRule());
|
||||
entity.setValInterval(src.getValInterval());
|
||||
entity.setLogInterval(src.getLogInterval());
|
||||
entity.setVisInterval(src.getVisInterval());
|
||||
|
||||
// memo
|
||||
entity.setMemo(createReq.getMemo());
|
||||
entity.setMemo(src.getMemo());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.common.exception.BadRequestException;
|
||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||
import com.kamco.cd.training.common.exception.NotFoundException;
|
||||
import com.kamco.cd.training.common.utils.UserUtil;
|
||||
import com.kamco.cd.training.model.dto.ModelMngDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDto.Basic;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.TrainingDataset;
|
||||
import com.kamco.cd.training.postgres.entity.ModelConfigEntity;
|
||||
import com.kamco.cd.training.postgres.entity.ModelDatasetEntity;
|
||||
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.List;
|
||||
@@ -21,9 +30,12 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ModelMngCoreService {
|
||||
public class ModelTrainMngCoreService {
|
||||
private final ModelMngRepository modelMngRepository;
|
||||
private final ModelDatasetMappRepository modelDatasetMappRepository;
|
||||
private final ModelDatasetRepository modelDatasetRepository;
|
||||
private final ModelDatasetMappRepository modelDatasetMapRepository;
|
||||
private final ModelConfigRepository modelConfigRepository;
|
||||
private final HyperParamRepository hyperParamRepository;
|
||||
private final UserUtil userUtil;
|
||||
|
||||
/**
|
||||
@@ -32,7 +44,7 @@ public class ModelMngCoreService {
|
||||
* @param searchReq 검색 조건
|
||||
* @return 페이징 처리된 모델 목록
|
||||
*/
|
||||
public Page<Basic> findByModelList(ModelTrainDto.SearchReq searchReq) {
|
||||
public Page<Basic> findByModelList(ModelTrainMngDto.SearchReq searchReq) {
|
||||
Page<ModelMasterEntity> entityPage = modelMngRepository.findByModels(searchReq);
|
||||
return entityPage.map(ModelMasterEntity::toDto);
|
||||
}
|
||||
@@ -52,6 +64,103 @@ public class ModelMngCoreService {
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델학습 저장
|
||||
*
|
||||
* @param addReq
|
||||
* @return
|
||||
*/
|
||||
public Long saveModel(ModelTrainMngDto.AddReq addReq) {
|
||||
ModelMasterEntity entity = new ModelMasterEntity();
|
||||
ModelHyperParamEntity hyperParamEntity =
|
||||
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
||||
|
||||
entity.setModelNo(addReq.getModelNo());
|
||||
entity.setTrainType(addReq.getTrainType());
|
||||
|
||||
if (hyperParamEntity != null) {
|
||||
entity.setHyperParamId(hyperParamEntity.getId());
|
||||
}
|
||||
|
||||
if (addReq.getIsStart()) {
|
||||
entity.setModelStep((short) 1);
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStrtDttm(ZonedDateTime.now());
|
||||
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
|
||||
}
|
||||
|
||||
entity.setCreatedUid(userUtil.getId());
|
||||
ModelMasterEntity resultEntity = modelMngRepository.save(entity);
|
||||
return resultEntity.getId();
|
||||
}
|
||||
|
||||
/**
|
||||
* data set 저장
|
||||
*
|
||||
* @param modelId 저장한 모델 학습 id
|
||||
* @param addReq 요청 파라미터
|
||||
*/
|
||||
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
|
||||
TrainingDataset dataset = addReq.getTrainingDataset();
|
||||
ModelMasterEntity modelMasterEntity = new ModelMasterEntity();
|
||||
ModelDatasetEntity datasetEntity = new ModelDatasetEntity();
|
||||
|
||||
modelMasterEntity.setId(modelId);
|
||||
datasetEntity.setModel(modelMasterEntity);
|
||||
|
||||
if (addReq.getModelNo().equals(ModelType.M1.getId())) {
|
||||
datasetEntity.setBuildingCnt(dataset.getSummary().getBuildingCnt());
|
||||
datasetEntity.setContainerCnt(dataset.getSummary().getContainerCnt());
|
||||
} else if (addReq.getModelNo().equals(ModelType.M2.getId())) {
|
||||
datasetEntity.setWasteCnt(dataset.getSummary().getWasteCnt());
|
||||
} else if (addReq.getModelNo().equals(ModelType.M3.getId())) {
|
||||
datasetEntity.setLandCoverCnt(dataset.getSummary().getLandCoverCnt());
|
||||
}
|
||||
|
||||
datasetEntity.setCreatedUid(userUtil.getId());
|
||||
|
||||
// data set 저장
|
||||
modelDatasetRepository.save(datasetEntity);
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델 데이터셋 mapping 테이블 저장
|
||||
*
|
||||
* @param modelId 모델학습 id
|
||||
* @param datasetList 선택한 data set
|
||||
*/
|
||||
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
|
||||
|
||||
for (Long datasetId : datasetList) {
|
||||
ModelDatasetMappEntity mapEntity = new ModelDatasetMappEntity();
|
||||
mapEntity.setModelUid(modelId);
|
||||
mapEntity.setDatasetUid(datasetId);
|
||||
modelDatasetMapRepository.save(mapEntity);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델학습 config 저장
|
||||
*
|
||||
* @param modelId 모델학습 id
|
||||
* @param req 요청 파라미터
|
||||
* @return
|
||||
*/
|
||||
public Long saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
|
||||
ModelMasterEntity modelMasterEntity = new ModelMasterEntity();
|
||||
ModelConfigEntity entity = new ModelConfigEntity();
|
||||
modelMasterEntity.setId(modelId);
|
||||
entity.setModel(modelMasterEntity);
|
||||
entity.setEpochCount(req.getEpochCnt());
|
||||
entity.setTrainPercent(req.getTrainingCnt());
|
||||
entity.setValidationPercent(req.getValidationCnt());
|
||||
entity.setTestPercent(req.getTestCnt());
|
||||
entity.setMemo(req.getMemo());
|
||||
|
||||
return modelConfigRepository.save(entity).getId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델 상세 조회
|
||||
*
|
||||
@@ -136,7 +245,7 @@ public class ModelMngCoreService {
|
||||
mapping.setModelUid(modelUid);
|
||||
mapping.setDatasetUid(datasetId);
|
||||
mapping.setDatasetType("TRAIN");
|
||||
modelDatasetMappRepository.save(mapping);
|
||||
modelDatasetMapRepository.save(mapping);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ public class ModelConfigEntity {
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@Column(name = "config_id", nullable = false)
|
||||
private Integer id;
|
||||
private Long id;
|
||||
|
||||
@NotNull
|
||||
@ManyToOne(fetch = FetchType.LAZY, optional = false)
|
||||
|
||||
@@ -3,6 +3,8 @@ package com.kamco.cd.training.postgres.entity;
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.FetchType;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
import jakarta.persistence.GenerationType;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.JoinColumn;
|
||||
import jakarta.persistence.ManyToOne;
|
||||
@@ -20,6 +22,7 @@ import org.hibernate.annotations.ColumnDefault;
|
||||
public class ModelDatasetEntity {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@Column(name = "id", nullable = false)
|
||||
private Long id;
|
||||
|
||||
@@ -28,10 +31,6 @@ public class ModelDatasetEntity {
|
||||
@JoinColumn(name = "model_id", nullable = false)
|
||||
private ModelMasterEntity model;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "data_id", nullable = false)
|
||||
private Long dataId;
|
||||
|
||||
@Column(name = "building_cnt")
|
||||
private Long buildingCnt;
|
||||
|
||||
@@ -46,7 +45,7 @@ public class ModelDatasetEntity {
|
||||
|
||||
@ColumnDefault("now()")
|
||||
@Column(name = "created_dttm")
|
||||
private ZonedDateTime createdDttm;
|
||||
private ZonedDateTime createdDttm = ZonedDateTime.now();
|
||||
|
||||
@Column(name = "created_uid")
|
||||
private Long createdUid;
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
import jakarta.persistence.GenerationType;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Table;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.UUID;
|
||||
@@ -26,8 +25,7 @@ public class ModelMasterEntity {
|
||||
@Column(name = "model_id", nullable = false)
|
||||
private Long id;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "hyper_param_id", nullable = false)
|
||||
@Column(name = "hyper_param_id")
|
||||
private Long hyperParamId;
|
||||
|
||||
@Size(max = 10)
|
||||
@@ -69,7 +67,7 @@ public class ModelMasterEntity {
|
||||
private String step2State;
|
||||
|
||||
@Column(name = "del_yn")
|
||||
private Boolean delYn;
|
||||
private Boolean delYn = false;
|
||||
|
||||
@ColumnDefault("now()")
|
||||
@Column(name = "created_dttm")
|
||||
@@ -90,8 +88,8 @@ public class ModelMasterEntity {
|
||||
@Column(name = "train_type")
|
||||
private String trainType;
|
||||
|
||||
public ModelTrainDto.Basic toDto() {
|
||||
return new ModelTrainDto.Basic(
|
||||
public ModelTrainMngDto.Basic toDto() {
|
||||
return new ModelTrainMngDto.Basic(
|
||||
this.id,
|
||||
this.uuid,
|
||||
this.modelVer,
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelConfigEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
public interface ModelConfigRepository extends JpaRepository<ModelConfigEntity, Long> {}
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelDatasetEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
public interface ModelDatasetRepository
|
||||
extends JpaRepository<ModelDatasetEntity, Long>, ModelDatasetRepositoryCustom {}
|
||||
@@ -0,0 +1,3 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
public interface ModelDatasetRepositoryCustom {}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
@@ -14,7 +14,7 @@ public interface ModelMngRepositoryCustom {
|
||||
* @param searchReq
|
||||
* @return
|
||||
*/
|
||||
Page<ModelMasterEntity> findByModels(ModelTrainDto.SearchReq searchReq);
|
||||
Page<ModelMasterEntity> findByModels(ModelTrainMngDto.SearchReq searchReq);
|
||||
|
||||
Optional<ModelMasterEntity> findByUuid(UUID uuid);
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.querydsl.core.BooleanBuilder;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
@@ -28,7 +28,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Page<ModelMasterEntity> findByModels(ModelTrainDto.SearchReq req) {
|
||||
public Page<ModelMasterEntity> findByModels(ModelTrainMngDto.SearchReq req) {
|
||||
Pageable pageable = req.toPageable();
|
||||
BooleanBuilder builder = new BooleanBuilder();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user