하이퍼파라미터 기능 추가
This commit is contained in:
@@ -1,31 +1,24 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.common.exception.BadRequestException;
|
||||
import com.kamco.cd.training.common.exception.NotFoundException;
|
||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||
import com.kamco.cd.training.common.utils.UserUtil;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.model.dto.ModelMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelHyperParamRepository;
|
||||
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class HyperParamCoreService {
|
||||
private final ModelHyperParamRepository hyperParamRepository;
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 전체 조회 (삭제되지 않은 것만)
|
||||
*
|
||||
* @return 하이퍼파라미터 목록
|
||||
*/
|
||||
public List<ModelMngDto.HyperParamInfo> findAllActiveHyperParams() {
|
||||
List<ModelHyperParamEntity> entities =
|
||||
hyperParamRepository.findByDelYnOrderByCreatedDttmDesc("N");
|
||||
|
||||
return entities.stream().map(this::mapToHyperParamInfo).toList();
|
||||
}
|
||||
private final HyperParamRepository hyperParamRepository;
|
||||
private final UserUtil userUtil;
|
||||
|
||||
private ModelMngDto.HyperParamInfo mapToHyperParamInfo(ModelHyperParamEntity entity) {
|
||||
return ModelMngDto.HyperParamInfo.builder()
|
||||
@@ -76,8 +69,6 @@ public class HyperParamCoreService {
|
||||
.contrastRange(entity.getContrastRange())
|
||||
.saturationRange(entity.getSaturationRange())
|
||||
.hueDelta(entity.getHueDelta())
|
||||
// Legacy
|
||||
.cnnFilterCnt(entity.getCnnFilterCnt())
|
||||
// Common
|
||||
.memo(entity.getMemo())
|
||||
.createdDttm(entity.getCreatedDttm())
|
||||
@@ -90,168 +81,104 @@ public class HyperParamCoreService {
|
||||
* @param createReq 등록 요청
|
||||
* @return 등록된 버전명
|
||||
*/
|
||||
public String createHyperParam(ModelMngDto.HyperParamCreateReq createReq) {
|
||||
// 중복 체크
|
||||
if (hyperParamRepository.existsById(createReq.getNewHyperVer())) {
|
||||
throw new BadRequestException("이미 존재하는 버전입니다: " + createReq.getNewHyperVer());
|
||||
}
|
||||
public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) {
|
||||
String firstVersion = getFirstHyperParamVersion();
|
||||
|
||||
// 기준 버전 조회
|
||||
ModelHyperParamEntity baseEntity =
|
||||
hyperParamRepository
|
||||
.findById(createReq.getBaseHyperVer())
|
||||
.orElseThrow(
|
||||
() -> new NotFoundException("기준 버전을 찾을 수 없습니다: " + createReq.getBaseHyperVer()));
|
||||
|
||||
// 신규 엔티티 생성 (기준 값 복사 후 변경된 값만 적용)
|
||||
ModelHyperParamEntity entity = new ModelHyperParamEntity();
|
||||
entity.setHyperVer(createReq.getNewHyperVer());
|
||||
entity.setHyperVer(firstVersion);
|
||||
|
||||
applyHyperParam(entity, createReq);
|
||||
|
||||
// user
|
||||
entity.setCreatedUid(userUtil.getId());
|
||||
|
||||
ModelHyperParamEntity resultEntity = hyperParamRepository.save(entity);
|
||||
return resultEntity.getHyperVer();
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 수정
|
||||
*
|
||||
* @param uuid uuid
|
||||
* @param createReq 등록 요청
|
||||
* @return ver
|
||||
*/
|
||||
public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
applyHyperParam(entity, createReq);
|
||||
|
||||
// user
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
|
||||
return entity.getHyperVer();
|
||||
}
|
||||
|
||||
private void applyHyperParam(
|
||||
ModelHyperParamEntity entity, HyperParamDto.HyperParamCreateReq createReq) {
|
||||
// Important
|
||||
entity.setBackbone(
|
||||
createReq.getBackbone() != null ? createReq.getBackbone() : baseEntity.getBackbone());
|
||||
entity.setInputSize(
|
||||
createReq.getInputSize() != null ? createReq.getInputSize() : baseEntity.getInputSize());
|
||||
entity.setCropSize(
|
||||
createReq.getCropSize() != null ? createReq.getCropSize() : baseEntity.getCropSize());
|
||||
entity.setEpochCnt(
|
||||
createReq.getEpochCnt() != null ? createReq.getEpochCnt() : baseEntity.getEpochCnt());
|
||||
entity.setBatchSize(
|
||||
createReq.getBatchSize() != null ? createReq.getBatchSize() : baseEntity.getBatchSize());
|
||||
|
||||
// Architecture
|
||||
entity.setDropPathRate(
|
||||
createReq.getDropPathRate() != null
|
||||
? createReq.getDropPathRate()
|
||||
: baseEntity.getDropPathRate());
|
||||
entity.setFrozenStages(
|
||||
createReq.getFrozenStages() != null
|
||||
? createReq.getFrozenStages()
|
||||
: baseEntity.getFrozenStages());
|
||||
entity.setNeckPolicy(
|
||||
createReq.getNeckPolicy() != null ? createReq.getNeckPolicy() : baseEntity.getNeckPolicy());
|
||||
entity.setDecoderChannels(
|
||||
createReq.getDecoderChannels() != null
|
||||
? createReq.getDecoderChannels()
|
||||
: baseEntity.getDecoderChannels());
|
||||
entity.setClassWeight(
|
||||
createReq.getClassWeight() != null
|
||||
? createReq.getClassWeight()
|
||||
: baseEntity.getClassWeight());
|
||||
entity.setNumLayers(
|
||||
createReq.getNumLayers() != null ? createReq.getNumLayers() : baseEntity.getNumLayers());
|
||||
|
||||
// Optimization
|
||||
entity.setLearningRate(
|
||||
createReq.getLearningRate() != null
|
||||
? createReq.getLearningRate()
|
||||
: baseEntity.getLearningRate());
|
||||
entity.setWeightDecay(
|
||||
createReq.getWeightDecay() != null
|
||||
? createReq.getWeightDecay()
|
||||
: baseEntity.getWeightDecay());
|
||||
entity.setLayerDecayRate(
|
||||
createReq.getLayerDecayRate() != null
|
||||
? createReq.getLayerDecayRate()
|
||||
: baseEntity.getLayerDecayRate());
|
||||
entity.setDdpFindUnusedParams(
|
||||
createReq.getDdpFindUnusedParams() != null
|
||||
? createReq.getDdpFindUnusedParams()
|
||||
: baseEntity.getDdpFindUnusedParams());
|
||||
entity.setIgnoreIndex(
|
||||
createReq.getIgnoreIndex() != null
|
||||
? createReq.getIgnoreIndex()
|
||||
: baseEntity.getIgnoreIndex());
|
||||
entity.setBackbone(createReq.getBackbone());
|
||||
entity.setInputSize(createReq.getInputSize());
|
||||
entity.setCropSize(createReq.getCropSize());
|
||||
entity.setBatchSize(createReq.getBatchSize());
|
||||
|
||||
// Data
|
||||
entity.setTrainNumWorkers(
|
||||
createReq.getTrainNumWorkers() != null
|
||||
? createReq.getTrainNumWorkers()
|
||||
: baseEntity.getTrainNumWorkers());
|
||||
entity.setValNumWorkers(
|
||||
createReq.getValNumWorkers() != null
|
||||
? createReq.getValNumWorkers()
|
||||
: baseEntity.getValNumWorkers());
|
||||
entity.setTestNumWorkers(
|
||||
createReq.getTestNumWorkers() != null
|
||||
? createReq.getTestNumWorkers()
|
||||
: baseEntity.getTestNumWorkers());
|
||||
entity.setTrainShuffle(
|
||||
createReq.getTrainShuffle() != null
|
||||
? createReq.getTrainShuffle()
|
||||
: baseEntity.getTrainShuffle());
|
||||
entity.setTrainPersistent(
|
||||
createReq.getTrainPersistent() != null
|
||||
? createReq.getTrainPersistent()
|
||||
: baseEntity.getTrainPersistent());
|
||||
entity.setValPersistent(
|
||||
createReq.getValPersistent() != null
|
||||
? createReq.getValPersistent()
|
||||
: baseEntity.getValPersistent());
|
||||
entity.setTrainNumWorkers(createReq.getTrainNumWorkers());
|
||||
entity.setValNumWorkers(createReq.getValNumWorkers());
|
||||
entity.setTestNumWorkers(createReq.getTestNumWorkers());
|
||||
entity.setTrainShuffle(createReq.getTrainShuffle());
|
||||
entity.setTrainPersistent(createReq.getTrainPersistent());
|
||||
entity.setValPersistent(createReq.getValPersistent());
|
||||
|
||||
// Model Architecture
|
||||
entity.setDropPathRate(createReq.getDropPathRate());
|
||||
entity.setFrozenStages(createReq.getFrozenStages());
|
||||
entity.setNeckPolicy(createReq.getNeckPolicy());
|
||||
entity.setClassWeight(createReq.getClassWeight());
|
||||
entity.setDecoderChannels(createReq.getDecoderChannels());
|
||||
|
||||
// Loss & Optimization
|
||||
entity.setLearningRate(createReq.getLearningRate());
|
||||
entity.setWeightDecay(createReq.getWeightDecay());
|
||||
entity.setLayerDecayRate(createReq.getLayerDecayRate());
|
||||
entity.setDdpFindUnusedParams(createReq.getDdpFindUnusedParams());
|
||||
entity.setIgnoreIndex(createReq.getIgnoreIndex());
|
||||
entity.setNumLayers(createReq.getNumLayers());
|
||||
|
||||
// Evaluation
|
||||
entity.setMetrics(
|
||||
createReq.getMetrics() != null ? createReq.getMetrics() : baseEntity.getMetrics());
|
||||
entity.setSaveBest(
|
||||
createReq.getSaveBest() != null ? createReq.getSaveBest() : baseEntity.getSaveBest());
|
||||
entity.setSaveBestRule(
|
||||
createReq.getSaveBestRule() != null
|
||||
? createReq.getSaveBestRule()
|
||||
: baseEntity.getSaveBestRule());
|
||||
entity.setValInterval(
|
||||
createReq.getValInterval() != null
|
||||
? createReq.getValInterval()
|
||||
: baseEntity.getValInterval());
|
||||
entity.setLogInterval(
|
||||
createReq.getLogInterval() != null
|
||||
? createReq.getLogInterval()
|
||||
: baseEntity.getLogInterval());
|
||||
entity.setVisInterval(
|
||||
createReq.getVisInterval() != null
|
||||
? createReq.getVisInterval()
|
||||
: baseEntity.getVisInterval());
|
||||
entity.setMetrics(createReq.getMetrics());
|
||||
entity.setSaveBest(createReq.getSaveBest());
|
||||
entity.setSaveBestRule(createReq.getSaveBestRule());
|
||||
entity.setValInterval(createReq.getValInterval());
|
||||
entity.setLogInterval(createReq.getLogInterval());
|
||||
entity.setVisInterval(createReq.getVisInterval());
|
||||
|
||||
// Hardware
|
||||
entity.setGpuCnt(
|
||||
createReq.getGpuCnt() != null ? createReq.getGpuCnt() : baseEntity.getGpuCnt());
|
||||
entity.setGpuIds(
|
||||
createReq.getGpuIds() != null ? createReq.getGpuIds() : baseEntity.getGpuIds());
|
||||
entity.setMasterPort(
|
||||
createReq.getMasterPort() != null ? createReq.getMasterPort() : baseEntity.getMasterPort());
|
||||
|
||||
// Augmentation
|
||||
entity.setRotProb(
|
||||
createReq.getRotProb() != null ? createReq.getRotProb() : baseEntity.getRotProb());
|
||||
entity.setFlipProb(
|
||||
createReq.getFlipProb() != null ? createReq.getFlipProb() : baseEntity.getFlipProb());
|
||||
entity.setRotDegree(
|
||||
createReq.getRotDegree() != null ? createReq.getRotDegree() : baseEntity.getRotDegree());
|
||||
entity.setExchangeProb(
|
||||
createReq.getExchangeProb() != null
|
||||
? createReq.getExchangeProb()
|
||||
: baseEntity.getExchangeProb());
|
||||
entity.setBrightnessDelta(
|
||||
createReq.getBrightnessDelta() != null
|
||||
? createReq.getBrightnessDelta()
|
||||
: baseEntity.getBrightnessDelta());
|
||||
entity.setContrastRange(
|
||||
createReq.getContrastRange() != null
|
||||
? createReq.getContrastRange()
|
||||
: baseEntity.getContrastRange());
|
||||
entity.setSaturationRange(
|
||||
createReq.getSaturationRange() != null
|
||||
? createReq.getSaturationRange()
|
||||
: baseEntity.getSaturationRange());
|
||||
entity.setHueDelta(
|
||||
createReq.getHueDelta() != null ? createReq.getHueDelta() : baseEntity.getHueDelta());
|
||||
|
||||
// Common
|
||||
// memo
|
||||
entity.setMemo(createReq.getMemo());
|
||||
entity.setDelYn("N");
|
||||
entity.setCreatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
ModelHyperParamEntity saved = hyperParamRepository.save(entity);
|
||||
return saved.getHyperVer();
|
||||
/**
|
||||
* 하이퍼파라미터 삭제
|
||||
*
|
||||
* @param uuid
|
||||
*/
|
||||
public void deleteHyperParam(UUID uuid) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
if (entity.getHyperVer().equals("HPs_0001")) {
|
||||
throw new CustomApiException("CONFLICT", HttpStatus.CONFLICT, "HPs_0001 버전은 삭제할수 없습니다.");
|
||||
}
|
||||
|
||||
entity.setDelYn("Y");
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -261,27 +188,16 @@ public class HyperParamCoreService {
|
||||
* @return 하이퍼파라미터 정보
|
||||
*/
|
||||
public ModelMngDto.HyperParamInfo findByHyperVer(String hyperVer) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findById(hyperVer)
|
||||
.orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
|
||||
// ModelHyperParamEntity entity =
|
||||
// hyperParamRepository
|
||||
// .findById(hyperVer)
|
||||
// .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
|
||||
//
|
||||
// if ("Y".equals(entity.getDelYn())) {
|
||||
// throw new NotFoundException("삭제된 하이퍼파라미터입니다: " + hyperVer);
|
||||
// }
|
||||
|
||||
if ("Y".equals(entity.getDelYn())) {
|
||||
throw new NotFoundException("삭제된 하이퍼파라미터입니다: " + hyperVer);
|
||||
}
|
||||
|
||||
return mapToHyperParamInfo(entity);
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 수정 (기존 버전은 수정 불가)
|
||||
*
|
||||
* @param hyperVer 하이퍼파라미터 버전
|
||||
* @param updateReq 수정 요청
|
||||
*/
|
||||
public void updateHyperParam(String hyperVer, ModelMngDto.HyperParamCreateReq updateReq) {
|
||||
// 기존 버전은 수정 불가
|
||||
throw new BadRequestException("기존 버전은 수정할 수 없습니다. 신규 버전을 생성해주세요.");
|
||||
return mapToHyperParamInfo(null);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -294,33 +210,47 @@ public class HyperParamCoreService {
|
||||
if ("H1".equals(hyperVer)) {
|
||||
throw new BadRequestException("H1은 디폴트 하이퍼파라미터 버전이므로 삭제할 수 없습니다.");
|
||||
}
|
||||
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findById(hyperVer)
|
||||
.orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
|
||||
|
||||
if ("Y".equals(entity.getDelYn())) {
|
||||
throw new BadRequestException("이미 삭제된 하이퍼파라미터입니다: " + hyperVer);
|
||||
}
|
||||
|
||||
// 논리 삭제 처리
|
||||
entity.setDelYn("Y");
|
||||
hyperParamRepository.save(entity);
|
||||
//
|
||||
// ModelHyperParamEntity entity =
|
||||
// hyperParamRepository
|
||||
// .findById(hyperVer)
|
||||
// .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
|
||||
//
|
||||
// if ("Y".equals(entity.getDelYn())) {
|
||||
// throw new BadRequestException("이미 삭제된 하이퍼파라미터입니다: " + hyperVer);
|
||||
// }
|
||||
//
|
||||
// // 논리 삭제 처리
|
||||
// entity.setDelYn("Y");
|
||||
// hyperParamRepository.save(entity);
|
||||
}
|
||||
|
||||
/**
|
||||
* 첫 번째 하이퍼파라미터 버전 조회 (H1 확인용)
|
||||
* 하이퍼파라미터 목록 조회
|
||||
*
|
||||
* @return 첫 번째 하이퍼파라미터 버전
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
|
||||
return hyperParamRepository.findByHyperVerList(req);
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 버전 조회
|
||||
*
|
||||
* @return ver
|
||||
*/
|
||||
public String getFirstHyperParamVersion() {
|
||||
List<ModelHyperParamEntity> entities =
|
||||
hyperParamRepository.findByDelYnOrderByCreatedDttmDesc("N");
|
||||
if (entities.isEmpty()) {
|
||||
throw new NotFoundException("하이퍼파라미터가 존재하지 않습니다.");
|
||||
}
|
||||
// 가장 오래된 것이 H1이므로 리스트의 마지막 요소 반환
|
||||
return entities.get(entities.size() - 1).getHyperVer();
|
||||
return hyperParamRepository
|
||||
.findHyperParamVer()
|
||||
.map(ModelHyperParamEntity::getHyperVer)
|
||||
.map(this::increase)
|
||||
.orElse("HPs_0001");
|
||||
}
|
||||
|
||||
private String increase(String hyperVer) {
|
||||
String prefix = "HPs_";
|
||||
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
|
||||
return prefix + String.format("%04d", num + 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import jakarta.persistence.*;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
@@ -28,7 +29,7 @@ public class ModelHyperParamEntity {
|
||||
@NotNull
|
||||
@UuidGenerator
|
||||
@Column(name = "uuid", nullable = false, updatable = false)
|
||||
private UUID uuid;
|
||||
private UUID uuid = UUID.randomUUID();
|
||||
|
||||
@Size(max = 50)
|
||||
@NotNull
|
||||
@@ -259,8 +260,7 @@ public class ModelHyperParamEntity {
|
||||
// -------------------------
|
||||
|
||||
/** Default: 4 */
|
||||
@NotNull
|
||||
@Column(name = "gpu_cnt", nullable = false)
|
||||
@Column(name = "gpu_cnt")
|
||||
private Integer gpuCnt;
|
||||
|
||||
/** Default: 0,1,2,3 */
|
||||
@@ -289,9 +289,86 @@ public class ModelHyperParamEntity {
|
||||
@Column(name = "created_dttm", nullable = false)
|
||||
private ZonedDateTime createdDttm = ZonedDateTime.now();
|
||||
|
||||
@Column(name = "cnn_filter_cnt")
|
||||
private Integer cnnFilterCnt;
|
||||
@NotNull
|
||||
@Column(name = "created_uid", nullable = false)
|
||||
private Long createdUid;
|
||||
|
||||
@ColumnDefault("CURRENT_TIMESTAMP")
|
||||
@Column(name = "updated_dttm")
|
||||
private ZonedDateTime updatedDttm;
|
||||
|
||||
@Column(name = "updated_uid")
|
||||
private Long updatedUid;
|
||||
|
||||
@ColumnDefault("CURRENT_TIMESTAMP")
|
||||
@Column(name = "last_used_dttm")
|
||||
private ZonedDateTime lastUsedDttm;
|
||||
|
||||
@Column(name = "m1_use_cnt")
|
||||
private Long m1UseCnt = 0L;
|
||||
|
||||
@Column(name = "m2_use_cnt")
|
||||
private Long m2UseCnt = 0L;
|
||||
|
||||
@Column(name = "m3_use_cnt")
|
||||
private Long m3UseCnt = 0L;
|
||||
|
||||
@OneToMany(mappedBy = "hyperParams", fetch = FetchType.LAZY)
|
||||
private Set<ModelTrainMasterEntity> trainMasters = new LinkedHashSet<>();
|
||||
|
||||
public HyperParamDto.Basic toDto() {
|
||||
return new HyperParamDto.Basic(
|
||||
this.id,
|
||||
this.uuid,
|
||||
this.hyperVer,
|
||||
this.backbone,
|
||||
this.inputSize,
|
||||
this.cropSize,
|
||||
this.epochCnt,
|
||||
this.batchSize,
|
||||
this.dropPathRate,
|
||||
this.frozenStages,
|
||||
this.neckPolicy,
|
||||
this.decoderChannels,
|
||||
this.classWeight,
|
||||
this.numLayers,
|
||||
this.learningRate,
|
||||
this.weightDecay,
|
||||
this.layerDecayRate,
|
||||
this.ddpFindUnusedParams,
|
||||
this.ignoreIndex,
|
||||
this.trainNumWorkers,
|
||||
this.valNumWorkers,
|
||||
this.testNumWorkers,
|
||||
this.trainShuffle,
|
||||
this.trainPersistent,
|
||||
this.valPersistent,
|
||||
this.metrics,
|
||||
this.saveBest,
|
||||
this.saveBestRule,
|
||||
this.valInterval,
|
||||
this.logInterval,
|
||||
this.visInterval,
|
||||
this.rotProb,
|
||||
this.flipProb,
|
||||
this.rotDegree,
|
||||
this.exchangeProb,
|
||||
this.brightnessDelta,
|
||||
this.contrastRange,
|
||||
this.saturationRange,
|
||||
this.hueDelta,
|
||||
this.gpuCnt,
|
||||
this.gpuIds,
|
||||
this.masterPort,
|
||||
this.memo,
|
||||
this.delYn,
|
||||
this.createdDttm,
|
||||
this.createdUid,
|
||||
this.updatedDttm,
|
||||
this.updatedUid,
|
||||
this.lastUsedDttm,
|
||||
this.m1UseCnt,
|
||||
this.m2UseCnt,
|
||||
this.m3UseCnt);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import jakarta.persistence.ManyToOne;
|
||||
import jakarta.persistence.Table;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import java.time.Instant;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.UUID;
|
||||
import lombok.Getter;
|
||||
@@ -112,7 +111,7 @@ public class ModelTrainMasterEntity {
|
||||
private Integer progressRate;
|
||||
|
||||
@Column(name = "stop_dttm")
|
||||
private Instant stopDttm;
|
||||
private ZonedDateTime stopDttm;
|
||||
|
||||
@Column(name = "confirmed_best_epoch")
|
||||
private Integer confirmedBestEpoch;
|
||||
@@ -125,7 +124,7 @@ public class ModelTrainMasterEntity {
|
||||
private String errorMsg;
|
||||
|
||||
@Column(name = "step2_start_dttm")
|
||||
private Instant step2StartDttm;
|
||||
private ZonedDateTime step2StartDttm;
|
||||
|
||||
@Size(max = 1000)
|
||||
@Column(name = "train_log_path", length = 1000)
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.kamco.cd.training.postgres.repository.hyperparam;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
public interface HyperParamRepository
|
||||
extends JpaRepository<ModelHyperParamEntity, Long>, HyperParamRepositoryCustom {}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.kamco.cd.training.postgres.repository.hyperparam;
|
||||
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.springframework.data.domain.Page;
|
||||
|
||||
public interface HyperParamRepositoryCustom {
|
||||
|
||||
/**
|
||||
* 마지막 버전 조회
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
Optional<ModelHyperParamEntity> findHyperParamVer();
|
||||
|
||||
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
|
||||
|
||||
Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req);
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package com.kamco.cd.training.postgres.repository.hyperparam;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
|
||||
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import com.querydsl.core.BooleanBuilder;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.core.types.dsl.NumberExpression;
|
||||
import com.querydsl.jpa.impl.JPAQuery;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.time.ZoneId;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.data.domain.PageImpl;
|
||||
import org.springframework.data.domain.Pageable;
|
||||
import org.springframework.data.domain.Sort;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
|
||||
private final JPAQueryFactory queryFactory;
|
||||
|
||||
@Override
|
||||
public Optional<ModelHyperParamEntity> findHyperParamVer() {
|
||||
|
||||
return Optional.ofNullable(
|
||||
queryFactory
|
||||
.select(modelHyperParamEntity)
|
||||
.from(modelHyperParamEntity)
|
||||
.where(modelHyperParamEntity.delYn.eq("N"))
|
||||
.orderBy(modelHyperParamEntity.hyperVer.desc())
|
||||
.limit(1)
|
||||
.fetchOne());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid) {
|
||||
return Optional.ofNullable(
|
||||
queryFactory
|
||||
.select(modelHyperParamEntity)
|
||||
.from(modelHyperParamEntity)
|
||||
.where(modelHyperParamEntity.delYn.eq("N").and(modelHyperParamEntity.uuid.eq(uuid)))
|
||||
.fetchOne());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
|
||||
Pageable pageable = req.toPageable();
|
||||
|
||||
BooleanBuilder builder = new BooleanBuilder();
|
||||
builder.and(modelHyperParamEntity.delYn.eq("N"));
|
||||
|
||||
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
||||
// 버전
|
||||
builder.and(modelHyperParamEntity.hyperVer.contains(req.getHyperVer()));
|
||||
}
|
||||
|
||||
if (req.getStartDate() != null && req.getEndDate() != null) {
|
||||
|
||||
ZoneId zoneId = ZoneId.systemDefault();
|
||||
|
||||
ZonedDateTime start = req.getStartDate().atStartOfDay(zoneId);
|
||||
|
||||
ZonedDateTime end = req.getEndDate().atTime(23, 59, 59).atZone(zoneId);
|
||||
|
||||
if (HyperType.CREATE_DATE.getId().equals(req.getType())) {
|
||||
// 생성일
|
||||
builder.and(modelHyperParamEntity.createdDttm.between(start, end));
|
||||
} else if (HyperType.LAST_USED_DATE.getId().equals(req.getType())) {
|
||||
// 최종 사용일
|
||||
builder.and(modelHyperParamEntity.lastUsedDttm.between(start, end));
|
||||
}
|
||||
}
|
||||
|
||||
NumberExpression<Long> totalUseCnt =
|
||||
modelHyperParamEntity
|
||||
.m1UseCnt
|
||||
.coalesce(0L)
|
||||
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
|
||||
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
|
||||
|
||||
JPAQuery<HyperParamDto.List> query =
|
||||
queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
HyperParamDto.List.class,
|
||||
modelHyperParamEntity.uuid,
|
||||
modelHyperParamEntity.hyperVer,
|
||||
modelHyperParamEntity.createdDttm,
|
||||
modelHyperParamEntity.lastUsedDttm,
|
||||
modelHyperParamEntity.m1UseCnt,
|
||||
modelHyperParamEntity.m2UseCnt,
|
||||
modelHyperParamEntity.m3UseCnt,
|
||||
totalUseCnt.as("totalUseCnt")))
|
||||
.from(modelHyperParamEntity)
|
||||
.where(builder);
|
||||
|
||||
Sort.Order sortOrder = pageable.getSort().stream().findFirst().orElse(null);
|
||||
|
||||
if (sortOrder == null) {
|
||||
// 기본값
|
||||
query.orderBy(modelHyperParamEntity.createdDttm.desc());
|
||||
} else {
|
||||
String property = sortOrder.getProperty();
|
||||
boolean asc = sortOrder.isAscending();
|
||||
|
||||
switch (property) {
|
||||
case "createdDttm" ->
|
||||
query.orderBy(
|
||||
asc
|
||||
? modelHyperParamEntity.createdDttm.asc()
|
||||
: modelHyperParamEntity.createdDttm.desc());
|
||||
|
||||
case "lastUsedDttm" ->
|
||||
query.orderBy(
|
||||
asc
|
||||
? modelHyperParamEntity.lastUsedDttm.asc()
|
||||
: modelHyperParamEntity.lastUsedDttm.desc());
|
||||
|
||||
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
|
||||
|
||||
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
|
||||
}
|
||||
}
|
||||
|
||||
List<HyperParamDto.List> content =
|
||||
query.offset(pageable.getOffset()).limit(pageable.getPageSize()).fetch();
|
||||
|
||||
Long total =
|
||||
queryFactory
|
||||
.select(modelHyperParamEntity.count())
|
||||
.from(modelHyperParamEntity)
|
||||
.where(builder)
|
||||
.fetchOne();
|
||||
|
||||
long totalCount = (total != null) ? total : 0L;
|
||||
|
||||
return new PageImpl<>(content, pageable, totalCount);
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import java.util.List;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
public interface ModelHyperParamRepository extends JpaRepository<ModelHyperParamEntity, String> {
|
||||
|
||||
List<ModelHyperParamEntity> findByDelYnOrderByCreatedDttmDesc(String delYn);
|
||||
|
||||
List<ModelHyperParamEntity> findByDelYnOrderByCreatedDttmAsc(String delYn);
|
||||
}
|
||||
Reference in New Issue
Block a user