diff --git a/src/main/java/com/kamco/cd/training/common/enums/ModelType.java b/src/main/java/com/kamco/cd/training/common/enums/ModelType.java index 63d5e94..c78d45b 100644 --- a/src/main/java/com/kamco/cd/training/common/enums/ModelType.java +++ b/src/main/java/com/kamco/cd/training/common/enums/ModelType.java @@ -17,7 +17,10 @@ public enum ModelType implements EnumType { private String desc; public static ModelType getValueData(String modelNo) { - return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1); + return Arrays.stream(ModelType.values()) + .filter(m -> m.getId().equals(modelNo)) + .findFirst() + .orElse(G1); } @Override diff --git a/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java b/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java index 827148e..ca2a64c 100644 --- a/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java +++ b/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java @@ -182,10 +182,8 @@ public class HyperParamApiController { @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @GetMapping("/init/{model}") - public ApiResponseDto getInitHyperParam( - @PathVariable ModelType model + public ApiResponseDto getInitHyperParam(@PathVariable ModelType model) { - ) { return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model)); } } diff --git a/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java b/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java index 0a81e96..c6fd192 100644 --- a/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java +++ b/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java @@ -79,7 +79,8 @@ public class ModelTrainMngApiController { @DeleteMapping("/{uuid}") public ApiResponseDto deleteModelTrain( @Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79") - @PathVariable UUID uuid) { + @PathVariable + UUID uuid) { modelTrainMngService.deleteModelTrain(uuid); return ApiResponseDto.ok(null); } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java index 79ec4a8..c25c5c6 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java @@ -50,15 +50,15 @@ public class HyperParamCoreService { /** * 하이퍼파라미터 수정 * - * @param uuid uuid + * @param uuid uuid * @param createReq 등록 요청 * @return ver */ public String updateHyperParam(UUID uuid, HyperParam createReq) { ModelHyperParamEntity entity = - hyperParamRepository - .findHyperParamByUuid(uuid) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + hyperParamRepository + .findHyperParamByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); if (entity.getIsDefault()) { throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY); @@ -72,7 +72,6 @@ public class HyperParamCoreService { return entity.getHyperVer(); } - /** * 하이퍼파라미터 삭제 * @@ -80,15 +79,15 @@ public class HyperParamCoreService { */ public void deleteHyperParam(UUID uuid) { ModelHyperParamEntity entity = - hyperParamRepository - .findHyperParamByUuid(uuid) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + hyperParamRepository + .findHyperParamByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); -// if (entity.getHyperVer().equals("HPs_0001")) { -// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY); -// } + // if (entity.getHyperVer().equals("HPs_0001")) { + // throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY); + // } - //디폴트면 삭제불가 + // 디폴트면 삭제불가 if (entity.getIsDefault()) { throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY); } @@ -105,9 +104,9 @@ public class HyperParamCoreService { */ public HyperParamDto.Basic getInitHyperParam(ModelType model) { ModelHyperParamEntity entity = - hyperParamRepository - .getHyperparamByType(model) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + hyperParamRepository + .getHyperparamByType(model) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); return entity.toDto(); } @@ -118,9 +117,9 @@ public class HyperParamCoreService { */ public HyperParamDto.Basic getHyperParam(UUID uuid) { ModelHyperParamEntity entity = - hyperParamRepository - .findHyperParamByUuid(uuid) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + hyperParamRepository + .findHyperParamByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); return entity.toDto(); } @@ -143,16 +142,16 @@ public class HyperParamCoreService { */ public String getFirstHyperParamVersion(ModelType model) { return hyperParamRepository - .findHyperParamVerByModelType(model) - .map(ModelHyperParamEntity::getHyperVer) - .map(ver -> increase(ver, model)) - .orElse(model.name() + "_000001"); + .findHyperParamVerByModelType(model) + .map(ModelHyperParamEntity::getHyperVer) + .map(ver -> increase(ver, model)) + .orElse(model.name() + "_000001"); } /** * 하이퍼 파라미터의 버전을 증가시킨다. * - * @param hyperVer 현재 버전 + * @param hyperVer 현재 버전 * @param modelType 모델 타입 * @return 증가된 버전 */ @@ -214,5 +213,4 @@ public class HyperParamCoreService { // memo entity.setMemo(src.getMemo()); } - } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index cc0a0ef..273f876 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -65,9 +65,9 @@ public class ModelTrainMngCoreService { */ public void deleteModel(UUID uuid) { ModelMasterEntity entity = - modelMngRepository - .findByUuid(uuid) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + modelMngRepository + .findByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); entity.setDelYn(true); entity.setUpdatedDttm(ZonedDateTime.now()); entity.setUpdatedUid(userUtil.getId()); @@ -83,15 +83,15 @@ public class ModelTrainMngCoreService { ModelMasterEntity entity = new ModelMasterEntity(); ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity(); -// 최적화 파라미터는 모델 type의 디폴트사용 + // 최적화 파라미터는 모델 type의 디폴트사용 if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) { ModelType modelType = ModelType.getValueData(addReq.getModelNo()); hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null); -// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null); + // hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null); } else { hyperParamEntity = - hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null); + hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null); } if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) { @@ -99,8 +99,8 @@ public class ModelTrainMngCoreService { } String modelVer = - String.join( - ".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString()); + String.join( + ".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString()); entity.setModelVer(modelVer); entity.setHyperParamId(hyperParamEntity.getId()); entity.setModelNo(addReq.getModelNo()); @@ -132,7 +132,7 @@ public class ModelTrainMngCoreService { * data set 저장 * * @param modelId 저장한 모델 학습 id - * @param addReq 요청 파라미터 + * @param addReq 요청 파라미터 */ public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) { TrainingDataset dataset = addReq.getTrainingDataset(); @@ -165,9 +165,9 @@ public class ModelTrainMngCoreService { */ public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) { ModelMasterEntity entity = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) { entity.setRequestPath(req.getRequestPath()); } @@ -180,7 +180,7 @@ public class ModelTrainMngCoreService { /** * 모델 데이터셋 mapping 테이블 저장 * - * @param modelId 모델학습 id + * @param modelId 모델학습 id * @param datasetList 선택한 data set */ public void saveModelDatasetMap(Long modelId, List datasetList) { @@ -197,7 +197,7 @@ public class ModelTrainMngCoreService { * 모델학습 config 저장 * * @param modelId 모델학습 id - * @param req 요청 파라미터 + * @param req 요청 파라미터 * @return */ public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) { @@ -217,7 +217,7 @@ public class ModelTrainMngCoreService { /** * 데이터셋 매핑 생성 * - * @param modelUid 모델 UID + * @param modelUid 모델 UID * @param datasetIds 데이터셋 ID 목록 */ public void createDatasetMappings(Long modelUid, List datasetIds) { @@ -239,8 +239,8 @@ public class ModelTrainMngCoreService { public ModelMasterEntity findByUuid(UUID uuid) { try { return modelMngRepository - .findByUuid(uuid) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + .findByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); } catch (IllegalArgumentException e) { throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid); } @@ -254,9 +254,9 @@ public class ModelTrainMngCoreService { */ public Long findModelIdByUuid(UUID uuid) { ModelMasterEntity entity = - modelMngRepository - .findByUuid(uuid) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + modelMngRepository + .findByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); return entity.getId(); } @@ -269,8 +269,8 @@ public class ModelTrainMngCoreService { public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) { ModelMasterEntity modelEntity = findByUuid(uuid); return modelConfigRepository - .findModelConfigByModelId(modelEntity.getId()) - .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + .findModelConfigByModelId(modelEntity.getId()) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); } /** @@ -301,21 +301,19 @@ public class ModelTrainMngCoreService { */ public ModelTrainMngDto.Basic findModelById(Long id) { ModelMasterEntity entity = - modelMngRepository - .findById(id) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + id)); + modelMngRepository + .findById(id) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + id)); return entity.toDto(); } - /** - * 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 - */ + /** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */ @Transactional public void markInProgress(Long modelId, Long jobId) { ModelMasterEntity master = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); master.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); master.setCurrentAttemptId(jobId); @@ -323,54 +321,46 @@ public class ModelTrainMngCoreService { // 필요하면 시작시간도 여기서 찍어줌 } - /** - * 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 - */ + /** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */ @Transactional public void clearLastError(Long modelId) { ModelMasterEntity master = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); master.setLastError(null); } - /** - * 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 - */ + /** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */ @Transactional public void markStopped(Long modelId) { ModelMasterEntity master = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); master.setStatusCd(TrainStatusType.STOPPED.getId()); } - /** - * 완료 처리(옵션) - Worker가 성공 시 호출 - */ + /** 완료 처리(옵션) - Worker가 성공 시 호출 */ @Transactional public void markCompleted(Long modelId) { ModelMasterEntity master = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); master.setStatusCd(TrainStatusType.COMPLETED.getId()); } - /** - * step 1오류 처리(옵션) - Worker가 실패 시 호출 - */ + /** step 1오류 처리(옵션) - Worker가 실패 시 호출 */ @Transactional public void markError(Long modelId, String errorMessage) { ModelMasterEntity master = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); master.setStatusCd(TrainStatusType.ERROR.getId()); master.setStep1State(TrainStatusType.ERROR.getId()); @@ -379,15 +369,13 @@ public class ModelTrainMngCoreService { master.setUpdatedDttm(ZonedDateTime.now()); } - /** - * step 2오류 처리(옵션) - Worker가 실패 시 호출 - */ + /** step 2오류 처리(옵션) - Worker가 실패 시 호출 */ @Transactional public void markStep2Error(Long modelId, String errorMessage) { ModelMasterEntity master = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); master.setStatusCd(TrainStatusType.ERROR.getId()); master.setStep2State(TrainStatusType.ERROR.getId()); @@ -399,9 +387,9 @@ public class ModelTrainMngCoreService { @Transactional public void markSuccess(Long modelId) { ModelMasterEntity master = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); // 모델 상태 완료 처리 master.setStatusCd(TrainStatusType.COMPLETED.getId()); @@ -429,9 +417,9 @@ public class ModelTrainMngCoreService { @Transactional(propagation = Propagation.REQUIRES_NEW) public void markStep1InProgress(Long modelId, Long jobId) { ModelMasterEntity entity = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); entity.setStep1StrtDttm(ZonedDateTime.now()); @@ -449,9 +437,9 @@ public class ModelTrainMngCoreService { @Transactional(propagation = Propagation.REQUIRES_NEW) public void markStep2InProgress(Long modelId, Long jobId) { ModelMasterEntity entity = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); entity.setStep2StrtDttm(ZonedDateTime.now()); @@ -469,9 +457,9 @@ public class ModelTrainMngCoreService { @Transactional(propagation = Propagation.REQUIRES_NEW) public void markStep1Success(Long modelId) { ModelMasterEntity entity = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); entity.setStatusCd(TrainStatusType.COMPLETED.getId()); entity.setStep1State(TrainStatusType.COMPLETED.getId()); @@ -488,9 +476,9 @@ public class ModelTrainMngCoreService { @Transactional(propagation = Propagation.REQUIRES_NEW) public void markStep2Success(Long modelId) { ModelMasterEntity entity = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); entity.setStatusCd(TrainStatusType.COMPLETED.getId()); entity.setStep2State(TrainStatusType.COMPLETED.getId()); @@ -501,9 +489,9 @@ public class ModelTrainMngCoreService { public void updateModelMasterBestEpoch(Long modelId, int epoch) { ModelMasterEntity entity = - modelMngRepository - .findById(modelId) - .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); + modelMngRepository + .findById(modelId) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); entity.setBestEpoch(epoch); } diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java index ec18265..7f0f215 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java @@ -316,7 +316,6 @@ public class ModelHyperParamEntity { @Enumerated(EnumType.STRING) private ModelType modelType; - @Column(name = "default_param") private Boolean isDefault = false; @@ -395,8 +394,7 @@ public class ModelHyperParamEntity { // ------------------------- this.gpuCnt, this.gpuIds, - this.masterPort - , this.isDefault - ); + this.masterPort, + this.isDefault); } } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryImpl.java index eed1036..80a7d14 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryImpl.java @@ -185,10 +185,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom { @Override public Optional getHyperparamByType(ModelType modelType) { return Optional.ofNullable( - queryFactory - .select(modelHyperParamEntity) - .from(modelHyperParamEntity) - .where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType))) - .fetchOne()); + queryFactory + .select(modelHyperParamEntity) + .from(modelHyperParamEntity) + .where( + modelHyperParamEntity + .delYn + .isFalse() + .and(modelHyperParamEntity.modelType.eq(modelType))) + .fetchOne()); } }