hyperparam_with_modeltype

This commit is contained in:
2026-02-12 18:56:32 +09:00
parent 37d776dd2c
commit 0bc4453c9c
7 changed files with 107 additions and 117 deletions

View File

@@ -17,7 +17,10 @@ public enum ModelType implements EnumType {
private String desc; private String desc;
public static ModelType getValueData(String modelNo) { public static ModelType getValueData(String modelNo) {
return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1); return Arrays.stream(ModelType.values())
.filter(m -> m.getId().equals(modelNo))
.findFirst()
.orElse(G1);
} }
@Override @Override

View File

@@ -182,10 +182,8 @@ public class HyperParamApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@GetMapping("/init/{model}") @GetMapping("/init/{model}")
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam( public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(@PathVariable ModelType model) {
@PathVariable ModelType model
) {
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model)); return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
} }
} }

View File

@@ -79,7 +79,8 @@ public class ModelTrainMngApiController {
@DeleteMapping("/{uuid}") @DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteModelTrain( public ApiResponseDto<Void> deleteModelTrain(
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79") @Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
@PathVariable UUID uuid) { @PathVariable
UUID uuid) {
modelTrainMngService.deleteModelTrain(uuid); modelTrainMngService.deleteModelTrain(uuid);
return ApiResponseDto.ok(null); return ApiResponseDto.ok(null);
} }

View File

@@ -50,15 +50,15 @@ public class HyperParamCoreService {
/** /**
* 하이퍼파라미터 수정 * 하이퍼파라미터 수정
* *
* @param uuid uuid * @param uuid uuid
* @param createReq 등록 요청 * @param createReq 등록 요청
* @return ver * @return ver
*/ */
public String updateHyperParam(UUID uuid, HyperParam createReq) { public String updateHyperParam(UUID uuid, HyperParam createReq) {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.findHyperParamByUuid(uuid) .findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getIsDefault()) { if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY); throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
@@ -72,7 +72,6 @@ public class HyperParamCoreService {
return entity.getHyperVer(); return entity.getHyperVer();
} }
/** /**
* 하이퍼파라미터 삭제 * 하이퍼파라미터 삭제
* *
@@ -80,15 +79,15 @@ public class HyperParamCoreService {
*/ */
public void deleteHyperParam(UUID uuid) { public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.findHyperParamByUuid(uuid) .findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// if (entity.getHyperVer().equals("HPs_0001")) { // if (entity.getHyperVer().equals("HPs_0001")) {
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY); // throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
// } // }
//디폴트면 삭제불가 // 디폴트면 삭제불가
if (entity.getIsDefault()) { if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY); throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
} }
@@ -105,9 +104,9 @@ public class HyperParamCoreService {
*/ */
public HyperParamDto.Basic getInitHyperParam(ModelType model) { public HyperParamDto.Basic getInitHyperParam(ModelType model) {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.getHyperparamByType(model) .getHyperparamByType(model)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto(); return entity.toDto();
} }
@@ -118,9 +117,9 @@ public class HyperParamCoreService {
*/ */
public HyperParamDto.Basic getHyperParam(UUID uuid) { public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.findHyperParamByUuid(uuid) .findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto(); return entity.toDto();
} }
@@ -143,16 +142,16 @@ public class HyperParamCoreService {
*/ */
public String getFirstHyperParamVersion(ModelType model) { public String getFirstHyperParamVersion(ModelType model) {
return hyperParamRepository return hyperParamRepository
.findHyperParamVerByModelType(model) .findHyperParamVerByModelType(model)
.map(ModelHyperParamEntity::getHyperVer) .map(ModelHyperParamEntity::getHyperVer)
.map(ver -> increase(ver, model)) .map(ver -> increase(ver, model))
.orElse(model.name() + "_000001"); .orElse(model.name() + "_000001");
} }
/** /**
* 하이퍼 파라미터의 버전을 증가시킨다. * 하이퍼 파라미터의 버전을 증가시킨다.
* *
* @param hyperVer 현재 버전 * @param hyperVer 현재 버전
* @param modelType 모델 타입 * @param modelType 모델 타입
* @return 증가된 버전 * @return 증가된 버전
*/ */
@@ -214,5 +213,4 @@ public class HyperParamCoreService {
// memo // memo
entity.setMemo(src.getMemo()); entity.setMemo(src.getMemo());
} }
} }

View File

@@ -65,9 +65,9 @@ public class ModelTrainMngCoreService {
*/ */
public void deleteModel(UUID uuid) { public void deleteModel(UUID uuid) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findByUuid(uuid) .findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
entity.setDelYn(true); entity.setDelYn(true);
entity.setUpdatedDttm(ZonedDateTime.now()); entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId()); entity.setUpdatedUid(userUtil.getId());
@@ -83,15 +83,15 @@ public class ModelTrainMngCoreService {
ModelMasterEntity entity = new ModelMasterEntity(); ModelMasterEntity entity = new ModelMasterEntity();
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity(); ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
// 최적화 파라미터는 모델 type의 디폴트사용 // 최적화 파라미터는 모델 type의 디폴트사용
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) { if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
ModelType modelType = ModelType.getValueData(addReq.getModelNo()); ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null); hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null); // hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
} else { } else {
hyperParamEntity = hyperParamEntity =
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null); hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
} }
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) { if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
@@ -99,8 +99,8 @@ public class ModelTrainMngCoreService {
} }
String modelVer = String modelVer =
String.join( String.join(
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString()); ".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
entity.setModelVer(modelVer); entity.setModelVer(modelVer);
entity.setHyperParamId(hyperParamEntity.getId()); entity.setHyperParamId(hyperParamEntity.getId());
entity.setModelNo(addReq.getModelNo()); entity.setModelNo(addReq.getModelNo());
@@ -132,7 +132,7 @@ public class ModelTrainMngCoreService {
* data set 저장 * data set 저장
* *
* @param modelId 저장한 모델 학습 id * @param modelId 저장한 모델 학습 id
* @param addReq 요청 파라미터 * @param addReq 요청 파라미터
*/ */
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) { public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
TrainingDataset dataset = addReq.getTrainingDataset(); TrainingDataset dataset = addReq.getTrainingDataset();
@@ -165,9 +165,9 @@ public class ModelTrainMngCoreService {
*/ */
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) { public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) { if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
entity.setRequestPath(req.getRequestPath()); entity.setRequestPath(req.getRequestPath());
} }
@@ -180,7 +180,7 @@ public class ModelTrainMngCoreService {
/** /**
* 모델 데이터셋 mapping 테이블 저장 * 모델 데이터셋 mapping 테이블 저장
* *
* @param modelId 모델학습 id * @param modelId 모델학습 id
* @param datasetList 선택한 data set * @param datasetList 선택한 data set
*/ */
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) { public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
@@ -197,7 +197,7 @@ public class ModelTrainMngCoreService {
* 모델학습 config 저장 * 모델학습 config 저장
* *
* @param modelId 모델학습 id * @param modelId 모델학습 id
* @param req 요청 파라미터 * @param req 요청 파라미터
* @return * @return
*/ */
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) { public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
@@ -217,7 +217,7 @@ public class ModelTrainMngCoreService {
/** /**
* 데이터셋 매핑 생성 * 데이터셋 매핑 생성
* *
* @param modelUid 모델 UID * @param modelUid 모델 UID
* @param datasetIds 데이터셋 ID 목록 * @param datasetIds 데이터셋 ID 목록
*/ */
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) { public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
@@ -239,8 +239,8 @@ public class ModelTrainMngCoreService {
public ModelMasterEntity findByUuid(UUID uuid) { public ModelMasterEntity findByUuid(UUID uuid) {
try { try {
return modelMngRepository return modelMngRepository
.findByUuid(uuid) .findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid); throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
} }
@@ -254,9 +254,9 @@ public class ModelTrainMngCoreService {
*/ */
public Long findModelIdByUuid(UUID uuid) { public Long findModelIdByUuid(UUID uuid) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findByUuid(uuid) .findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.getId(); return entity.getId();
} }
@@ -269,8 +269,8 @@ public class ModelTrainMngCoreService {
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) { public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid); ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository return modelConfigRepository
.findModelConfigByModelId(modelEntity.getId()) .findModelConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} }
/** /**
@@ -301,21 +301,19 @@ public class ModelTrainMngCoreService {
*/ */
public ModelTrainMngDto.Basic findModelById(Long id) { public ModelTrainMngDto.Basic findModelById(Long id) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(id) .findById(id)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
return entity.toDto(); return entity.toDto();
} }
/** /** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
*/
@Transactional @Transactional
public void markInProgress(Long modelId, Long jobId) { public void markInProgress(Long modelId, Long jobId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
master.setCurrentAttemptId(jobId); master.setCurrentAttemptId(jobId);
@@ -323,54 +321,46 @@ public class ModelTrainMngCoreService {
// 필요하면 시작시간도 여기서 찍어줌 // 필요하면 시작시간도 여기서 찍어줌
} }
/** /** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
*/
@Transactional @Transactional
public void clearLastError(Long modelId) { public void clearLastError(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setLastError(null); master.setLastError(null);
} }
/** /** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
*/
@Transactional @Transactional
public void markStopped(Long modelId) { public void markStopped(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.STOPPED.getId()); master.setStatusCd(TrainStatusType.STOPPED.getId());
} }
/** /** 완료 처리(옵션) - Worker가 성공 시 호출 */
* 완료 처리(옵션) - Worker가 성공 시 호출
*/
@Transactional @Transactional
public void markCompleted(Long modelId) { public void markCompleted(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.COMPLETED.getId()); master.setStatusCd(TrainStatusType.COMPLETED.getId());
} }
/** /** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*/
@Transactional @Transactional
public void markError(Long modelId, String errorMessage) { public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId()); master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep1State(TrainStatusType.ERROR.getId()); master.setStep1State(TrainStatusType.ERROR.getId());
@@ -379,15 +369,13 @@ public class ModelTrainMngCoreService {
master.setUpdatedDttm(ZonedDateTime.now()); master.setUpdatedDttm(ZonedDateTime.now());
} }
/** /** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*/
@Transactional @Transactional
public void markStep2Error(Long modelId, String errorMessage) { public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId()); master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep2State(TrainStatusType.ERROR.getId()); master.setStep2State(TrainStatusType.ERROR.getId());
@@ -399,9 +387,9 @@ public class ModelTrainMngCoreService {
@Transactional @Transactional
public void markSuccess(Long modelId) { public void markSuccess(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
// 모델 상태 완료 처리 // 모델 상태 완료 처리
master.setStatusCd(TrainStatusType.COMPLETED.getId()); master.setStatusCd(TrainStatusType.COMPLETED.getId());
@@ -429,9 +417,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW) @Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1InProgress(Long modelId, Long jobId) { public void markStep1InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep1StrtDttm(ZonedDateTime.now()); entity.setStep1StrtDttm(ZonedDateTime.now());
@@ -449,9 +437,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW) @Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2InProgress(Long modelId, Long jobId) { public void markStep2InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep2StrtDttm(ZonedDateTime.now()); entity.setStep2StrtDttm(ZonedDateTime.now());
@@ -469,9 +457,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW) @Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1Success(Long modelId) { public void markStep1Success(Long modelId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId()); entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep1State(TrainStatusType.COMPLETED.getId()); entity.setStep1State(TrainStatusType.COMPLETED.getId());
@@ -488,9 +476,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW) @Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2Success(Long modelId) { public void markStep2Success(Long modelId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId()); entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep2State(TrainStatusType.COMPLETED.getId()); entity.setStep2State(TrainStatusType.COMPLETED.getId());
@@ -501,9 +489,9 @@ public class ModelTrainMngCoreService {
public void updateModelMasterBestEpoch(Long modelId, int epoch) { public void updateModelMasterBestEpoch(Long modelId, int epoch) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setBestEpoch(epoch); entity.setBestEpoch(epoch);
} }

View File

@@ -316,7 +316,6 @@ public class ModelHyperParamEntity {
@Enumerated(EnumType.STRING) @Enumerated(EnumType.STRING)
private ModelType modelType; private ModelType modelType;
@Column(name = "default_param") @Column(name = "default_param")
private Boolean isDefault = false; private Boolean isDefault = false;
@@ -395,8 +394,7 @@ public class ModelHyperParamEntity {
// ------------------------- // -------------------------
this.gpuCnt, this.gpuCnt,
this.gpuIds, this.gpuIds,
this.masterPort this.masterPort,
, this.isDefault this.isDefault);
);
} }
} }

View File

@@ -185,10 +185,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
@Override @Override
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) { public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
return Optional.ofNullable( return Optional.ofNullable(
queryFactory queryFactory
.select(modelHyperParamEntity) .select(modelHyperParamEntity)
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType))) .where(
.fetchOne()); modelHyperParamEntity
.delYn
.isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType)))
.fetchOne());
} }
} }