하이퍼파라미터 기능 추가

This commit is contained in:
2026-02-03 14:31:53 +09:00
parent e2757d3ca0
commit 3a8d6e3ef0
18 changed files with 946 additions and 688 deletions

View File

@@ -1,31 +1,24 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.NotFoundException;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.model.dto.ModelMngDto;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.kamco.cd.training.postgres.repository.model.ModelHyperParamRepository;
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
@Service
@RequiredArgsConstructor
public class HyperParamCoreService {
private final ModelHyperParamRepository hyperParamRepository;
/**
* 하이퍼파라미터 전체 조회 (삭제되지 않은 것만)
*
* @return 하이퍼파라미터 목록
*/
public List<ModelMngDto.HyperParamInfo> findAllActiveHyperParams() {
List<ModelHyperParamEntity> entities =
hyperParamRepository.findByDelYnOrderByCreatedDttmDesc("N");
return entities.stream().map(this::mapToHyperParamInfo).toList();
}
private final HyperParamRepository hyperParamRepository;
private final UserUtil userUtil;
private ModelMngDto.HyperParamInfo mapToHyperParamInfo(ModelHyperParamEntity entity) {
return ModelMngDto.HyperParamInfo.builder()
@@ -76,8 +69,6 @@ public class HyperParamCoreService {
.contrastRange(entity.getContrastRange())
.saturationRange(entity.getSaturationRange())
.hueDelta(entity.getHueDelta())
// Legacy
.cnnFilterCnt(entity.getCnnFilterCnt())
// Common
.memo(entity.getMemo())
.createdDttm(entity.getCreatedDttm())
@@ -90,168 +81,104 @@ public class HyperParamCoreService {
* @param createReq 등록 요청
* @return 등록된 버전명
*/
public String createHyperParam(ModelMngDto.HyperParamCreateReq createReq) {
// 중복 체크
if (hyperParamRepository.existsById(createReq.getNewHyperVer())) {
throw new BadRequestException("이미 존재하는 버전입니다: " + createReq.getNewHyperVer());
}
public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) {
String firstVersion = getFirstHyperParamVersion();
// 기준 버전 조회
ModelHyperParamEntity baseEntity =
hyperParamRepository
.findById(createReq.getBaseHyperVer())
.orElseThrow(
() -> new NotFoundException("기준 버전을 찾을 수 없습니다: " + createReq.getBaseHyperVer()));
// 신규 엔티티 생성 (기준 값 복사 후 변경된 값만 적용)
ModelHyperParamEntity entity = new ModelHyperParamEntity();
entity.setHyperVer(createReq.getNewHyperVer());
entity.setHyperVer(firstVersion);
applyHyperParam(entity, createReq);
// user
entity.setCreatedUid(userUtil.getId());
ModelHyperParamEntity resultEntity = hyperParamRepository.save(entity);
return resultEntity.getHyperVer();
}
/**
* 하이퍼파라미터 수정
*
* @param uuid uuid
* @param createReq 등록 요청
* @return ver
*/
public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
applyHyperParam(entity, createReq);
// user
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
return entity.getHyperVer();
}
private void applyHyperParam(
ModelHyperParamEntity entity, HyperParamDto.HyperParamCreateReq createReq) {
// Important
entity.setBackbone(
createReq.getBackbone() != null ? createReq.getBackbone() : baseEntity.getBackbone());
entity.setInputSize(
createReq.getInputSize() != null ? createReq.getInputSize() : baseEntity.getInputSize());
entity.setCropSize(
createReq.getCropSize() != null ? createReq.getCropSize() : baseEntity.getCropSize());
entity.setEpochCnt(
createReq.getEpochCnt() != null ? createReq.getEpochCnt() : baseEntity.getEpochCnt());
entity.setBatchSize(
createReq.getBatchSize() != null ? createReq.getBatchSize() : baseEntity.getBatchSize());
// Architecture
entity.setDropPathRate(
createReq.getDropPathRate() != null
? createReq.getDropPathRate()
: baseEntity.getDropPathRate());
entity.setFrozenStages(
createReq.getFrozenStages() != null
? createReq.getFrozenStages()
: baseEntity.getFrozenStages());
entity.setNeckPolicy(
createReq.getNeckPolicy() != null ? createReq.getNeckPolicy() : baseEntity.getNeckPolicy());
entity.setDecoderChannels(
createReq.getDecoderChannels() != null
? createReq.getDecoderChannels()
: baseEntity.getDecoderChannels());
entity.setClassWeight(
createReq.getClassWeight() != null
? createReq.getClassWeight()
: baseEntity.getClassWeight());
entity.setNumLayers(
createReq.getNumLayers() != null ? createReq.getNumLayers() : baseEntity.getNumLayers());
// Optimization
entity.setLearningRate(
createReq.getLearningRate() != null
? createReq.getLearningRate()
: baseEntity.getLearningRate());
entity.setWeightDecay(
createReq.getWeightDecay() != null
? createReq.getWeightDecay()
: baseEntity.getWeightDecay());
entity.setLayerDecayRate(
createReq.getLayerDecayRate() != null
? createReq.getLayerDecayRate()
: baseEntity.getLayerDecayRate());
entity.setDdpFindUnusedParams(
createReq.getDdpFindUnusedParams() != null
? createReq.getDdpFindUnusedParams()
: baseEntity.getDdpFindUnusedParams());
entity.setIgnoreIndex(
createReq.getIgnoreIndex() != null
? createReq.getIgnoreIndex()
: baseEntity.getIgnoreIndex());
entity.setBackbone(createReq.getBackbone());
entity.setInputSize(createReq.getInputSize());
entity.setCropSize(createReq.getCropSize());
entity.setBatchSize(createReq.getBatchSize());
// Data
entity.setTrainNumWorkers(
createReq.getTrainNumWorkers() != null
? createReq.getTrainNumWorkers()
: baseEntity.getTrainNumWorkers());
entity.setValNumWorkers(
createReq.getValNumWorkers() != null
? createReq.getValNumWorkers()
: baseEntity.getValNumWorkers());
entity.setTestNumWorkers(
createReq.getTestNumWorkers() != null
? createReq.getTestNumWorkers()
: baseEntity.getTestNumWorkers());
entity.setTrainShuffle(
createReq.getTrainShuffle() != null
? createReq.getTrainShuffle()
: baseEntity.getTrainShuffle());
entity.setTrainPersistent(
createReq.getTrainPersistent() != null
? createReq.getTrainPersistent()
: baseEntity.getTrainPersistent());
entity.setValPersistent(
createReq.getValPersistent() != null
? createReq.getValPersistent()
: baseEntity.getValPersistent());
entity.setTrainNumWorkers(createReq.getTrainNumWorkers());
entity.setValNumWorkers(createReq.getValNumWorkers());
entity.setTestNumWorkers(createReq.getTestNumWorkers());
entity.setTrainShuffle(createReq.getTrainShuffle());
entity.setTrainPersistent(createReq.getTrainPersistent());
entity.setValPersistent(createReq.getValPersistent());
// Model Architecture
entity.setDropPathRate(createReq.getDropPathRate());
entity.setFrozenStages(createReq.getFrozenStages());
entity.setNeckPolicy(createReq.getNeckPolicy());
entity.setClassWeight(createReq.getClassWeight());
entity.setDecoderChannels(createReq.getDecoderChannels());
// Loss & Optimization
entity.setLearningRate(createReq.getLearningRate());
entity.setWeightDecay(createReq.getWeightDecay());
entity.setLayerDecayRate(createReq.getLayerDecayRate());
entity.setDdpFindUnusedParams(createReq.getDdpFindUnusedParams());
entity.setIgnoreIndex(createReq.getIgnoreIndex());
entity.setNumLayers(createReq.getNumLayers());
// Evaluation
entity.setMetrics(
createReq.getMetrics() != null ? createReq.getMetrics() : baseEntity.getMetrics());
entity.setSaveBest(
createReq.getSaveBest() != null ? createReq.getSaveBest() : baseEntity.getSaveBest());
entity.setSaveBestRule(
createReq.getSaveBestRule() != null
? createReq.getSaveBestRule()
: baseEntity.getSaveBestRule());
entity.setValInterval(
createReq.getValInterval() != null
? createReq.getValInterval()
: baseEntity.getValInterval());
entity.setLogInterval(
createReq.getLogInterval() != null
? createReq.getLogInterval()
: baseEntity.getLogInterval());
entity.setVisInterval(
createReq.getVisInterval() != null
? createReq.getVisInterval()
: baseEntity.getVisInterval());
entity.setMetrics(createReq.getMetrics());
entity.setSaveBest(createReq.getSaveBest());
entity.setSaveBestRule(createReq.getSaveBestRule());
entity.setValInterval(createReq.getValInterval());
entity.setLogInterval(createReq.getLogInterval());
entity.setVisInterval(createReq.getVisInterval());
// Hardware
entity.setGpuCnt(
createReq.getGpuCnt() != null ? createReq.getGpuCnt() : baseEntity.getGpuCnt());
entity.setGpuIds(
createReq.getGpuIds() != null ? createReq.getGpuIds() : baseEntity.getGpuIds());
entity.setMasterPort(
createReq.getMasterPort() != null ? createReq.getMasterPort() : baseEntity.getMasterPort());
// Augmentation
entity.setRotProb(
createReq.getRotProb() != null ? createReq.getRotProb() : baseEntity.getRotProb());
entity.setFlipProb(
createReq.getFlipProb() != null ? createReq.getFlipProb() : baseEntity.getFlipProb());
entity.setRotDegree(
createReq.getRotDegree() != null ? createReq.getRotDegree() : baseEntity.getRotDegree());
entity.setExchangeProb(
createReq.getExchangeProb() != null
? createReq.getExchangeProb()
: baseEntity.getExchangeProb());
entity.setBrightnessDelta(
createReq.getBrightnessDelta() != null
? createReq.getBrightnessDelta()
: baseEntity.getBrightnessDelta());
entity.setContrastRange(
createReq.getContrastRange() != null
? createReq.getContrastRange()
: baseEntity.getContrastRange());
entity.setSaturationRange(
createReq.getSaturationRange() != null
? createReq.getSaturationRange()
: baseEntity.getSaturationRange());
entity.setHueDelta(
createReq.getHueDelta() != null ? createReq.getHueDelta() : baseEntity.getHueDelta());
// Common
// memo
entity.setMemo(createReq.getMemo());
entity.setDelYn("N");
entity.setCreatedDttm(ZonedDateTime.now());
}
ModelHyperParamEntity saved = hyperParamRepository.save(entity);
return saved.getHyperVer();
/**
* 하이퍼파라미터 삭제
*
* @param uuid
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getHyperVer().equals("HPs_0001")) {
throw new CustomApiException("CONFLICT", HttpStatus.CONFLICT, "HPs_0001 버전은 삭제할수 없습니다.");
}
entity.setDelYn("Y");
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
}
/**
@@ -261,27 +188,16 @@ public class HyperParamCoreService {
* @return 하이퍼파라미터 정보
*/
public ModelMngDto.HyperParamInfo findByHyperVer(String hyperVer) {
ModelHyperParamEntity entity =
hyperParamRepository
.findById(hyperVer)
.orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
// ModelHyperParamEntity entity =
// hyperParamRepository
// .findById(hyperVer)
// .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
//
// if ("Y".equals(entity.getDelYn())) {
// throw new NotFoundException("삭제된 하이퍼파라미터입니다: " + hyperVer);
// }
if ("Y".equals(entity.getDelYn())) {
throw new NotFoundException("삭제된 하이퍼파라미터입니다: " + hyperVer);
}
return mapToHyperParamInfo(entity);
}
/**
* 하이퍼파라미터 수정 (기존 버전은 수정 불가)
*
* @param hyperVer 하이퍼파라미터 버전
* @param updateReq 수정 요청
*/
public void updateHyperParam(String hyperVer, ModelMngDto.HyperParamCreateReq updateReq) {
// 기존 버전은 수정 불가
throw new BadRequestException("기존 버전은 수정할 수 없습니다. 신규 버전을 생성해주세요.");
return mapToHyperParamInfo(null);
}
/**
@@ -294,33 +210,47 @@ public class HyperParamCoreService {
if ("H1".equals(hyperVer)) {
throw new BadRequestException("H1은 디폴트 하이퍼파라미터 버전이므로 삭제할 수 없습니다.");
}
ModelHyperParamEntity entity =
hyperParamRepository
.findById(hyperVer)
.orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
if ("Y".equals(entity.getDelYn())) {
throw new BadRequestException("이미 삭제된 하이퍼파라미터입니다: " + hyperVer);
}
// 논리 삭제 처리
entity.setDelYn("Y");
hyperParamRepository.save(entity);
//
// ModelHyperParamEntity entity =
// hyperParamRepository
// .findById(hyperVer)
// .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
//
// if ("Y".equals(entity.getDelYn())) {
// throw new BadRequestException("이미 삭제된 하이퍼파라미터입니다: " + hyperVer);
// }
//
// // 논리 삭제 처리
// entity.setDelYn("Y");
// hyperParamRepository.save(entity);
}
/**
* 첫 번째 하이퍼파라미터 버전 조회 (H1 확인용)
* 하이퍼파라미터 목록 조회
*
* @return 첫 번째 하이퍼파라미터 버전
* @param req
* @return
*/
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
return hyperParamRepository.findByHyperVerList(req);
}
/**
* 하이퍼파라미터 버전 조회
*
* @return ver
*/
public String getFirstHyperParamVersion() {
List<ModelHyperParamEntity> entities =
hyperParamRepository.findByDelYnOrderByCreatedDttmDesc("N");
if (entities.isEmpty()) {
throw new NotFoundException("하이퍼파라미터가 존재하지 않습니다.");
}
// 가장 오래된 것이 H1이므로 리스트의 마지막 요소 반환
return entities.get(entities.size() - 1).getHyperVer();
return hyperParamRepository
.findHyperParamVer()
.map(ModelHyperParamEntity::getHyperVer)
.map(this::increase)
.orElse("HPs_0001");
}
private String increase(String hyperVer) {
String prefix = "HPs_";
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
return prefix + String.format("%04d", num + 1);
}
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import jakarta.persistence.*;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
@@ -28,7 +29,7 @@ public class ModelHyperParamEntity {
@NotNull
@UuidGenerator
@Column(name = "uuid", nullable = false, updatable = false)
private UUID uuid;
private UUID uuid = UUID.randomUUID();
@Size(max = 50)
@NotNull
@@ -259,8 +260,7 @@ public class ModelHyperParamEntity {
// -------------------------
/** Default: 4 */
@NotNull
@Column(name = "gpu_cnt", nullable = false)
@Column(name = "gpu_cnt")
private Integer gpuCnt;
/** Default: 0,1,2,3 */
@@ -289,9 +289,86 @@ public class ModelHyperParamEntity {
@Column(name = "created_dttm", nullable = false)
private ZonedDateTime createdDttm = ZonedDateTime.now();
@Column(name = "cnn_filter_cnt")
private Integer cnnFilterCnt;
@NotNull
@Column(name = "created_uid", nullable = false)
private Long createdUid;
@ColumnDefault("CURRENT_TIMESTAMP")
@Column(name = "updated_dttm")
private ZonedDateTime updatedDttm;
@Column(name = "updated_uid")
private Long updatedUid;
@ColumnDefault("CURRENT_TIMESTAMP")
@Column(name = "last_used_dttm")
private ZonedDateTime lastUsedDttm;
@Column(name = "m1_use_cnt")
private Long m1UseCnt = 0L;
@Column(name = "m2_use_cnt")
private Long m2UseCnt = 0L;
@Column(name = "m3_use_cnt")
private Long m3UseCnt = 0L;
@OneToMany(mappedBy = "hyperParams", fetch = FetchType.LAZY)
private Set<ModelTrainMasterEntity> trainMasters = new LinkedHashSet<>();
public HyperParamDto.Basic toDto() {
return new HyperParamDto.Basic(
this.id,
this.uuid,
this.hyperVer,
this.backbone,
this.inputSize,
this.cropSize,
this.epochCnt,
this.batchSize,
this.dropPathRate,
this.frozenStages,
this.neckPolicy,
this.decoderChannels,
this.classWeight,
this.numLayers,
this.learningRate,
this.weightDecay,
this.layerDecayRate,
this.ddpFindUnusedParams,
this.ignoreIndex,
this.trainNumWorkers,
this.valNumWorkers,
this.testNumWorkers,
this.trainShuffle,
this.trainPersistent,
this.valPersistent,
this.metrics,
this.saveBest,
this.saveBestRule,
this.valInterval,
this.logInterval,
this.visInterval,
this.rotProb,
this.flipProb,
this.rotDegree,
this.exchangeProb,
this.brightnessDelta,
this.contrastRange,
this.saturationRange,
this.hueDelta,
this.gpuCnt,
this.gpuIds,
this.masterPort,
this.memo,
this.delYn,
this.createdDttm,
this.createdUid,
this.updatedDttm,
this.updatedUid,
this.lastUsedDttm,
this.m1UseCnt,
this.m2UseCnt,
this.m3UseCnt);
}
}

View File

@@ -12,7 +12,6 @@ import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.Instant;
import java.time.ZonedDateTime;
import java.util.UUID;
import lombok.Getter;
@@ -112,7 +111,7 @@ public class ModelTrainMasterEntity {
private Integer progressRate;
@Column(name = "stop_dttm")
private Instant stopDttm;
private ZonedDateTime stopDttm;
@Column(name = "confirmed_best_epoch")
private Integer confirmedBestEpoch;
@@ -125,7 +124,7 @@ public class ModelTrainMasterEntity {
private String errorMsg;
@Column(name = "step2_start_dttm")
private Instant step2StartDttm;
private ZonedDateTime step2StartDttm;
@Size(max = 1000)
@Column(name = "train_log_path", length = 1000)

View File

@@ -0,0 +1,9 @@
package com.kamco.cd.training.postgres.repository.hyperparam;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface HyperParamRepository
extends JpaRepository<ModelHyperParamEntity, Long>, HyperParamRepositoryCustom {}

View File

@@ -0,0 +1,21 @@
package com.kamco.cd.training.postgres.repository.hyperparam;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.Optional;
import java.util.UUID;
import org.springframework.data.domain.Page;
public interface HyperParamRepositoryCustom {
/**
* 마지막 버전 조회
*
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamVer();
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req);
}

View File

@@ -0,0 +1,148 @@
package com.kamco.cd.training.postgres.repository.hyperparam;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.NumberExpression;
import com.querydsl.jpa.impl.JPAQuery;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
private final JPAQueryFactory queryFactory;
@Override
public Optional<ModelHyperParamEntity> findHyperParamVer() {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.eq("N"))
.orderBy(modelHyperParamEntity.hyperVer.desc())
.limit(1)
.fetchOne());
}
@Override
public Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid) {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.eq("N").and(modelHyperParamEntity.uuid.eq(uuid)))
.fetchOne());
}
@Override
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
builder.and(modelHyperParamEntity.delYn.eq("N"));
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
// 버전
builder.and(modelHyperParamEntity.hyperVer.contains(req.getHyperVer()));
}
if (req.getStartDate() != null && req.getEndDate() != null) {
ZoneId zoneId = ZoneId.systemDefault();
ZonedDateTime start = req.getStartDate().atStartOfDay(zoneId);
ZonedDateTime end = req.getEndDate().atTime(23, 59, 59).atZone(zoneId);
if (HyperType.CREATE_DATE.getId().equals(req.getType())) {
// 생성일
builder.and(modelHyperParamEntity.createdDttm.between(start, end));
} else if (HyperType.LAST_USED_DATE.getId().equals(req.getType())) {
// 최종 사용일
builder.and(modelHyperParamEntity.lastUsedDttm.between(start, end));
}
}
NumberExpression<Long> totalUseCnt =
modelHyperParamEntity
.m1UseCnt
.coalesce(0L)
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
JPAQuery<HyperParamDto.List> query =
queryFactory
.select(
Projections.constructor(
HyperParamDto.List.class,
modelHyperParamEntity.uuid,
modelHyperParamEntity.hyperVer,
modelHyperParamEntity.createdDttm,
modelHyperParamEntity.lastUsedDttm,
modelHyperParamEntity.m1UseCnt,
modelHyperParamEntity.m2UseCnt,
modelHyperParamEntity.m3UseCnt,
totalUseCnt.as("totalUseCnt")))
.from(modelHyperParamEntity)
.where(builder);
Sort.Order sortOrder = pageable.getSort().stream().findFirst().orElse(null);
if (sortOrder == null) {
// 기본값
query.orderBy(modelHyperParamEntity.createdDttm.desc());
} else {
String property = sortOrder.getProperty();
boolean asc = sortOrder.isAscending();
switch (property) {
case "createdDttm" ->
query.orderBy(
asc
? modelHyperParamEntity.createdDttm.asc()
: modelHyperParamEntity.createdDttm.desc());
case "lastUsedDttm" ->
query.orderBy(
asc
? modelHyperParamEntity.lastUsedDttm.asc()
: modelHyperParamEntity.lastUsedDttm.desc());
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
}
}
List<HyperParamDto.List> content =
query.offset(pageable.getOffset()).limit(pageable.getPageSize()).fetch();
Long total =
queryFactory
.select(modelHyperParamEntity.count())
.from(modelHyperParamEntity)
.where(builder)
.fetchOne();
long totalCount = (total != null) ? total : 0L;
return new PageImpl<>(content, pageable, totalCount);
}
}

View File

@@ -1,14 +0,0 @@
package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.List;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface ModelHyperParamRepository extends JpaRepository<ModelHyperParamEntity, String> {
List<ModelHyperParamEntity> findByDelYnOrderByCreatedDttmDesc(String delYn);
List<ModelHyperParamEntity> findByDelYnOrderByCreatedDttmAsc(String delYn);
}