hyperparam_with_modeltype

This commit is contained in:
2026-02-12 19:14:01 +09:00
parent 0bc4453c9c
commit d5b2b8ecec
6 changed files with 186 additions and 168 deletions

View File

@@ -88,9 +88,8 @@ public class HyperParamApiController {
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@GetMapping("{model}/list") @GetMapping("/list")
public ApiResponseDto<Page<List>> getHyperParam( public ApiResponseDto<Page<List>> getHyperParam(
@PathVariable ModelType model,
@Parameter( @Parameter(
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)", description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
example = "CREATE_DATE") example = "CREATE_DATE")
@@ -102,7 +101,8 @@ public class HyperParamApiController {
LocalDate endDate, LocalDate endDate,
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false) @Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
String hyperVer, String hyperVer,
@Parameter( @Parameter(description = "버전명", example = "G1,G2,G3") @RequestParam(required = false) ModelType model
, @Parameter(
description = "정렬", description = "정렬",
example = "createdDttm desc", example = "createdDttm desc",
schema = schema =

View File

@@ -110,9 +110,11 @@ public class HyperParamDto {
@AllArgsConstructor @AllArgsConstructor
public static class List { public static class List {
private UUID uuid; private UUID uuid;
private ModelType model;
private String hyperVer; private String hyperVer;
@JsonFormatDttm private ZonedDateTime createDttm; @JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm; @JsonFormatDttm private ZonedDateTime lastUsedDttm;
private String memo;
private Long m1UseCnt; private Long m1UseCnt;
private Long m2UseCnt; private Long m2UseCnt;
private Long m3UseCnt; private Long m3UseCnt;

View File

@@ -50,15 +50,15 @@ public class HyperParamCoreService {
/** /**
* 하이퍼파라미터 수정 * 하이퍼파라미터 수정
* *
* @param uuid uuid * @param uuid uuid
* @param createReq 등록 요청 * @param createReq 등록 요청
* @return ver * @return ver
*/ */
public String updateHyperParam(UUID uuid, HyperParam createReq) { public String updateHyperParam(UUID uuid, HyperParam createReq) {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.findHyperParamByUuid(uuid) .findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getIsDefault()) { if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY); throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
@@ -79,9 +79,9 @@ public class HyperParamCoreService {
*/ */
public void deleteHyperParam(UUID uuid) { public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.findHyperParamByUuid(uuid) .findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// if (entity.getHyperVer().equals("HPs_0001")) { // if (entity.getHyperVer().equals("HPs_0001")) {
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY); // throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
@@ -104,9 +104,10 @@ public class HyperParamCoreService {
*/ */
public HyperParamDto.Basic getInitHyperParam(ModelType model) { public HyperParamDto.Basic getInitHyperParam(ModelType model) {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.getHyperparamByType(model) .getHyperparamByType(model)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .stream().filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst().orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto(); return entity.toDto();
} }
@@ -117,9 +118,9 @@ public class HyperParamCoreService {
*/ */
public HyperParamDto.Basic getHyperParam(UUID uuid) { public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.findHyperParamByUuid(uuid) .findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto(); return entity.toDto();
} }
@@ -142,16 +143,16 @@ public class HyperParamCoreService {
*/ */
public String getFirstHyperParamVersion(ModelType model) { public String getFirstHyperParamVersion(ModelType model) {
return hyperParamRepository return hyperParamRepository
.findHyperParamVerByModelType(model) .findHyperParamVerByModelType(model)
.map(ModelHyperParamEntity::getHyperVer) .map(ModelHyperParamEntity::getHyperVer)
.map(ver -> increase(ver, model)) .map(ver -> increase(ver, model))
.orElse(model.name() + "_000001"); .orElse(model.name() + "_000001");
} }
/** /**
* 하이퍼 파라미터의 버전을 증가시킨다. * 하이퍼 파라미터의 버전을 증가시킨다.
* *
* @param hyperVer 현재 버전 * @param hyperVer 현재 버전
* @param modelType 모델 타입 * @param modelType 모델 타입
* @return 증가된 버전 * @return 증가된 버전
*/ */

View File

@@ -65,9 +65,9 @@ public class ModelTrainMngCoreService {
*/ */
public void deleteModel(UUID uuid) { public void deleteModel(UUID uuid) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findByUuid(uuid) .findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
entity.setDelYn(true); entity.setDelYn(true);
entity.setUpdatedDttm(ZonedDateTime.now()); entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId()); entity.setUpdatedUid(userUtil.getId());
@@ -86,12 +86,15 @@ public class ModelTrainMngCoreService {
// 최적화 파라미터는 모델 type의 디폴트사용 // 최적화 파라미터는 모델 type의 디폴트사용
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) { if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
ModelType modelType = ModelType.getValueData(addReq.getModelNo()); ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null); hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType)
.stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst().orElse(null);
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null); // hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
} else { } else {
hyperParamEntity = hyperParamEntity =
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null); hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
} }
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) { if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
@@ -99,8 +102,8 @@ public class ModelTrainMngCoreService {
} }
String modelVer = String modelVer =
String.join( String.join(
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString()); ".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
entity.setModelVer(modelVer); entity.setModelVer(modelVer);
entity.setHyperParamId(hyperParamEntity.getId()); entity.setHyperParamId(hyperParamEntity.getId());
entity.setModelNo(addReq.getModelNo()); entity.setModelNo(addReq.getModelNo());
@@ -132,7 +135,7 @@ public class ModelTrainMngCoreService {
* data set 저장 * data set 저장
* *
* @param modelId 저장한 모델 학습 id * @param modelId 저장한 모델 학습 id
* @param addReq 요청 파라미터 * @param addReq 요청 파라미터
*/ */
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) { public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
TrainingDataset dataset = addReq.getTrainingDataset(); TrainingDataset dataset = addReq.getTrainingDataset();
@@ -165,9 +168,9 @@ public class ModelTrainMngCoreService {
*/ */
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) { public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) { if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
entity.setRequestPath(req.getRequestPath()); entity.setRequestPath(req.getRequestPath());
} }
@@ -180,7 +183,7 @@ public class ModelTrainMngCoreService {
/** /**
* 모델 데이터셋 mapping 테이블 저장 * 모델 데이터셋 mapping 테이블 저장
* *
* @param modelId 모델학습 id * @param modelId 모델학습 id
* @param datasetList 선택한 data set * @param datasetList 선택한 data set
*/ */
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) { public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
@@ -197,7 +200,7 @@ public class ModelTrainMngCoreService {
* 모델학습 config 저장 * 모델학습 config 저장
* *
* @param modelId 모델학습 id * @param modelId 모델학습 id
* @param req 요청 파라미터 * @param req 요청 파라미터
* @return * @return
*/ */
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) { public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
@@ -217,7 +220,7 @@ public class ModelTrainMngCoreService {
/** /**
* 데이터셋 매핑 생성 * 데이터셋 매핑 생성
* *
* @param modelUid 모델 UID * @param modelUid 모델 UID
* @param datasetIds 데이터셋 ID 목록 * @param datasetIds 데이터셋 ID 목록
*/ */
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) { public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
@@ -239,8 +242,8 @@ public class ModelTrainMngCoreService {
public ModelMasterEntity findByUuid(UUID uuid) { public ModelMasterEntity findByUuid(UUID uuid) {
try { try {
return modelMngRepository return modelMngRepository
.findByUuid(uuid) .findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid); throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
} }
@@ -254,9 +257,9 @@ public class ModelTrainMngCoreService {
*/ */
public Long findModelIdByUuid(UUID uuid) { public Long findModelIdByUuid(UUID uuid) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findByUuid(uuid) .findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.getId(); return entity.getId();
} }
@@ -269,8 +272,8 @@ public class ModelTrainMngCoreService {
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) { public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid); ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository return modelConfigRepository
.findModelConfigByModelId(modelEntity.getId()) .findModelConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} }
/** /**
@@ -301,19 +304,21 @@ public class ModelTrainMngCoreService {
*/ */
public ModelTrainMngDto.Basic findModelById(Long id) { public ModelTrainMngDto.Basic findModelById(Long id) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(id) .findById(id)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
return entity.toDto(); return entity.toDto();
} }
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */ /**
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
*/
@Transactional @Transactional
public void markInProgress(Long modelId, Long jobId) { public void markInProgress(Long modelId, Long jobId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
master.setCurrentAttemptId(jobId); master.setCurrentAttemptId(jobId);
@@ -321,46 +326,54 @@ public class ModelTrainMngCoreService {
// 필요하면 시작시간도 여기서 찍어줌 // 필요하면 시작시간도 여기서 찍어줌
} }
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */ /**
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
*/
@Transactional @Transactional
public void clearLastError(Long modelId) { public void clearLastError(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setLastError(null); master.setLastError(null);
} }
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */ /**
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
*/
@Transactional @Transactional
public void markStopped(Long modelId) { public void markStopped(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.STOPPED.getId()); master.setStatusCd(TrainStatusType.STOPPED.getId());
} }
/** 완료 처리(옵션) - Worker가 성공 시 호출 */ /**
* 완료 처리(옵션) - Worker가 성공 시 호출
*/
@Transactional @Transactional
public void markCompleted(Long modelId) { public void markCompleted(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.COMPLETED.getId()); master.setStatusCd(TrainStatusType.COMPLETED.getId());
} }
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */ /**
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*/
@Transactional @Transactional
public void markError(Long modelId, String errorMessage) { public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId()); master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep1State(TrainStatusType.ERROR.getId()); master.setStep1State(TrainStatusType.ERROR.getId());
@@ -369,13 +382,15 @@ public class ModelTrainMngCoreService {
master.setUpdatedDttm(ZonedDateTime.now()); master.setUpdatedDttm(ZonedDateTime.now());
} }
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */ /**
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*/
@Transactional @Transactional
public void markStep2Error(Long modelId, String errorMessage) { public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId()); master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep2State(TrainStatusType.ERROR.getId()); master.setStep2State(TrainStatusType.ERROR.getId());
@@ -387,9 +402,9 @@ public class ModelTrainMngCoreService {
@Transactional @Transactional
public void markSuccess(Long modelId) { public void markSuccess(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
// 모델 상태 완료 처리 // 모델 상태 완료 처리
master.setStatusCd(TrainStatusType.COMPLETED.getId()); master.setStatusCd(TrainStatusType.COMPLETED.getId());
@@ -417,9 +432,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW) @Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1InProgress(Long modelId, Long jobId) { public void markStep1InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep1StrtDttm(ZonedDateTime.now()); entity.setStep1StrtDttm(ZonedDateTime.now());
@@ -437,9 +452,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW) @Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2InProgress(Long modelId, Long jobId) { public void markStep2InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep2StrtDttm(ZonedDateTime.now()); entity.setStep2StrtDttm(ZonedDateTime.now());
@@ -457,9 +472,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW) @Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1Success(Long modelId) { public void markStep1Success(Long modelId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId()); entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep1State(TrainStatusType.COMPLETED.getId()); entity.setStep1State(TrainStatusType.COMPLETED.getId());
@@ -476,9 +491,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW) @Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2Success(Long modelId) { public void markStep2Success(Long modelId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId()); entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep2State(TrainStatusType.COMPLETED.getId()); entity.setStep2State(TrainStatusType.COMPLETED.getId());
@@ -489,9 +504,9 @@ public class ModelTrainMngCoreService {
public void updateModelMasterBestEpoch(Long modelId, int epoch) { public void updateModelMasterBestEpoch(Long modelId, int epoch) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setBestEpoch(epoch); entity.setBestEpoch(epoch);
} }

View File

@@ -4,6 +4,7 @@ import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq; import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
@@ -32,5 +33,5 @@ public interface HyperParamRepositoryCustom {
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req); Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType); List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
} }

View File

@@ -34,56 +34,56 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
public Optional<ModelHyperParamEntity> findHyperParamVer() { public Optional<ModelHyperParamEntity> findHyperParamVer() {
return Optional.ofNullable( return Optional.ofNullable(
queryFactory queryFactory
.select(modelHyperParamEntity) .select(modelHyperParamEntity)
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.isFalse()) .where(modelHyperParamEntity.delYn.isFalse())
.orderBy(modelHyperParamEntity.hyperVer.desc()) .orderBy(modelHyperParamEntity.hyperVer.desc())
.limit(1) .limit(1)
.fetchOne()); .fetchOne());
} }
@Override @Override
public Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType) { public Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType) {
return Optional.ofNullable( return Optional.ofNullable(
queryFactory queryFactory
.select(modelHyperParamEntity) .select(modelHyperParamEntity)
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where( .where(
modelHyperParamEntity modelHyperParamEntity
.delYn .delYn
.isFalse() .isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType))) .and(modelHyperParamEntity.modelType.eq(modelType)))
.orderBy(modelHyperParamEntity.hyperVer.desc()) .orderBy(modelHyperParamEntity.hyperVer.desc())
.limit(1) .limit(1)
.fetchOne()); .fetchOne());
} }
@Override @Override
public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) { public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) {
return Optional.ofNullable( return Optional.ofNullable(
queryFactory queryFactory
.select(modelHyperParamEntity) .select(modelHyperParamEntity)
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where( .where(
modelHyperParamEntity modelHyperParamEntity
.delYn .delYn
.isFalse() .isFalse()
.and(modelHyperParamEntity.hyperVer.eq(hyperVer))) .and(modelHyperParamEntity.hyperVer.eq(hyperVer)))
.limit(1) .limit(1)
.fetchOne()); .fetchOne());
} }
@Override @Override
public Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid) { public Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid) {
return Optional.ofNullable( return Optional.ofNullable(
queryFactory queryFactory
.select(modelHyperParamEntity) .select(modelHyperParamEntity)
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid))) .where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid)))
.fetchOne()); .fetchOne());
} }
@Override @Override
@@ -91,7 +91,9 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
Pageable pageable = req.toPageable(); Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
builder.and(modelHyperParamEntity.modelType.eq(model)); if (model != null) {
builder.and(modelHyperParamEntity.modelType.eq(model));
}
builder.and(modelHyperParamEntity.delYn.isFalse()); builder.and(modelHyperParamEntity.delYn.isFalse());
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) { if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
@@ -117,27 +119,28 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
} }
NumberExpression<Long> totalUseCnt = NumberExpression<Long> totalUseCnt =
modelHyperParamEntity modelHyperParamEntity
.m1UseCnt .m1UseCnt
.coalesce(0L) .coalesce(0L)
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L)) .add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L)); .add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
JPAQuery<HyperParamDto.List> query = JPAQuery<HyperParamDto.List> query =
queryFactory queryFactory
.select( .select(
Projections.constructor( Projections.constructor(
HyperParamDto.List.class, HyperParamDto.List.class,
modelHyperParamEntity.uuid, modelHyperParamEntity.uuid,
modelHyperParamEntity.hyperVer, modelHyperParamEntity.modelType.as("model"),
modelHyperParamEntity.createdDttm, modelHyperParamEntity.hyperVer,
modelHyperParamEntity.lastUsedDttm, modelHyperParamEntity.createdDttm,
modelHyperParamEntity.m1UseCnt, modelHyperParamEntity.lastUsedDttm,
modelHyperParamEntity.m2UseCnt, modelHyperParamEntity.m1UseCnt,
modelHyperParamEntity.m3UseCnt, modelHyperParamEntity.m2UseCnt,
totalUseCnt.as("totalUseCnt"))) modelHyperParamEntity.m3UseCnt,
.from(modelHyperParamEntity) totalUseCnt.as("totalUseCnt")))
.where(builder); .from(modelHyperParamEntity)
.where(builder);
Sort.Order sortOrder = pageable.getSort().stream().findFirst().orElse(null); Sort.Order sortOrder = pageable.getSort().stream().findFirst().orElse(null);
@@ -149,17 +152,15 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
boolean asc = sortOrder.isAscending(); boolean asc = sortOrder.isAscending();
switch (property) { switch (property) {
case "createdDttm" -> case "createdDttm" -> query.orderBy(
query.orderBy( asc
asc ? modelHyperParamEntity.createdDttm.asc()
? modelHyperParamEntity.createdDttm.asc() : modelHyperParamEntity.createdDttm.desc());
: modelHyperParamEntity.createdDttm.desc());
case "lastUsedDttm" -> case "lastUsedDttm" -> query.orderBy(
query.orderBy( asc
asc ? modelHyperParamEntity.lastUsedDttm.asc()
? modelHyperParamEntity.lastUsedDttm.asc() : modelHyperParamEntity.lastUsedDttm.desc());
: modelHyperParamEntity.lastUsedDttm.desc());
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc()); case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
@@ -168,14 +169,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
} }
List<HyperParamDto.List> content = List<HyperParamDto.List> content =
query.offset(pageable.getOffset()).limit(pageable.getPageSize()).fetch(); query.offset(pageable.getOffset()).limit(pageable.getPageSize()).fetch();
Long total = Long total =
queryFactory queryFactory
.select(modelHyperParamEntity.count()) .select(modelHyperParamEntity.count())
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where(builder) .where(builder)
.fetchOne(); .fetchOne();
long totalCount = (total != null) ? total : 0L; long totalCount = (total != null) ? total : 0L;
@@ -183,16 +184,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
} }
@Override @Override
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) { public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
return Optional.ofNullable( return queryFactory
queryFactory .select(modelHyperParamEntity)
.select(modelHyperParamEntity) .from(modelHyperParamEntity)
.from(modelHyperParamEntity) .where(
.where( modelHyperParamEntity
modelHyperParamEntity .delYn.isFalse()
.delYn .and(modelHyperParamEntity.modelType.eq(modelType)))
.isFalse() .fetch();
.and(modelHyperParamEntity.modelType.eq(modelType)))
.fetchOne());
} }
} }