diff --git a/gradlew b/gradlew old mode 100644 new mode 100755 diff --git a/src/main/java/com/kamco/cd/training/config/GlobalExceptionHandler.java b/src/main/java/com/kamco/cd/training/config/GlobalExceptionHandler.java index 6148346..ffc8542 100644 --- a/src/main/java/com/kamco/cd/training/config/GlobalExceptionHandler.java +++ b/src/main/java/com/kamco/cd/training/config/GlobalExceptionHandler.java @@ -482,37 +482,40 @@ public class GlobalExceptionHandler { @ExceptionHandler(CustomApiException.class) public ApiResponseDto handleCustomApiException( CustomApiException e, HttpServletRequest request) { + log.warn("[CustomApiException] resource : {}", e.getMessage()); this.errorLog(request, e); + String codeName = e.getCodeName(); HttpStatus status = e.getStatus(); - // String message = e.getMessage() == null ? ApiResponseCode.getMessage(codeName) : - // e.getMessage(); - // - // ApiResponseCode apiCode = ApiResponseCode.getCode(codeName); - // - // ErrorLogEntity errorLog = - // saveErrorLogData( - // request, apiCode, status, ErrorLogDto.LogErrorLevel.WARNING, e.getStackTrace()); - // - // ApiResponseDto body = - // ApiResponseDto.createException(apiCode, message, status, errorLog.getId()); + + // enum에 없는 codeName이면 IllegalArgumentException 방지 + ApiResponseDto.ApiResponseCode apiCode = null; + try { + apiCode = ApiResponseDto.ApiResponseCode.getCode(codeName); + } catch (Exception ignore) { + apiCode = null; + } + + // 메시지 우선순위: 예외 message > enum message > fallback + String message; + if (e.getMessage() != null && !e.getMessage().isBlank()) { + message = e.getMessage(); + } else if (apiCode != null) { + message = apiCode.getText(); // enum 기본 메시지 + } else { + message = "서버에 문제가 발생 하였습니다."; // 최후 fallback + } + + // 로그/응답 코드: enum 없으면 기본코드로 + ApiResponseDto.ApiResponseCode safeCode = + (apiCode != null) ? apiCode : ApiResponseDto.ApiResponseCode.INTERNAL_SERVER_ERROR; ErrorLogEntity errorLog = saveErrorLogData( - request, - ApiResponseCode.getCode(codeName), - HttpStatus.valueOf(status.value()), - ErrorLogDto.LogErrorLevel.WARNING, - e.getStackTrace()); + request, safeCode, status, ErrorLogDto.LogErrorLevel.WARNING, e.getStackTrace()); - return ApiResponseDto.createException( - ApiResponseCode.getCode(codeName), - ApiResponseCode.getMessage(codeName), - HttpStatus.valueOf(status.value()), - errorLog.getId()); - - // return new ResponseEntity<>(body, status); + return ApiResponseDto.createException(safeCode, message, status, errorLog.getId()); } private void errorLog(HttpServletRequest request, Throwable e) { diff --git a/src/main/java/com/kamco/cd/training/config/api/ApiResponseDto.java b/src/main/java/com/kamco/cd/training/config/api/ApiResponseDto.java index b9dc41b..886031c 100644 --- a/src/main/java/com/kamco/cd/training/config/api/ApiResponseDto.java +++ b/src/main/java/com/kamco/cd/training/config/api/ApiResponseDto.java @@ -189,7 +189,12 @@ public class ApiResponseDto { } public static ApiResponseCode getCode(String name) { - return ApiResponseCode.valueOf(name.toUpperCase()); + if (name == null || name.isBlank()) return null; + try { + return ApiResponseCode.valueOf(name.toUpperCase()); + } catch (IllegalArgumentException ex) { + return null; + } } public static String getMessage(String name) { diff --git a/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java b/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java new file mode 100644 index 0000000..758e841 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java @@ -0,0 +1,141 @@ +package com.kamco.cd.training.hyperparam; + +import com.kamco.cd.training.config.api.ApiResponseDto; +import com.kamco.cd.training.hyperparam.dto.HyperParamDto; +import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List; +import com.kamco.cd.training.hyperparam.service.HyperParamService; +import com.kamco.cd.training.model.dto.ModelMngDto; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.responses.ApiResponses; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.validation.Valid; +import java.time.LocalDate; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.data.domain.Page; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + +@Tag(name = "하이퍼파라미터 관리", description = "하이퍼파라미터 관리 API") +@RestController +@RequiredArgsConstructor +@RequestMapping("/api/hyper-param") +public class HyperParamApiController { + + private final HyperParamService hyperParamService; + + @Operation(summary = "하이퍼파라미터 등록", description = "파라미터를 신규 저장") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "등록 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PostMapping + public ApiResponseDto createHyperParam( + @Valid @RequestBody HyperParamDto.HyperParamCreateReq createReq) { + String newVersion = hyperParamService.createHyperParam(createReq); + return ApiResponseDto.ok(newVersion); + } + + @Operation(summary = "하이퍼파라미터 수정", description = "파라미터를 수정") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "등록 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PutMapping("/{uuid}") + public ApiResponseDto updateHyperParam( + @PathVariable UUID uuid, @Valid @RequestBody HyperParamDto.HyperParamCreateReq createReq) { + return ApiResponseDto.ok(hyperParamService.updateHyperParam(uuid, createReq)); + } + + @Operation(summary = "하이퍼파라미터 목록 조회", description = "하이퍼파라미터 목록 조회") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "조회 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))), + @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @GetMapping("/list") + public ApiResponseDto> getHyperParam( + @Parameter(description = "구분") @RequestParam(required = false) String type, + @Parameter(description = "시작일") @RequestParam(required = false) LocalDate startDate, + @Parameter(description = "종료일") @RequestParam(required = false) LocalDate endDate, + @Parameter(description = "버전명") @RequestParam(required = false) String hyperVer, + @Parameter( + description = "정렬", + example = "createdDttm desc", + schema = + @Schema( + allowableValues = { + "createdDttm,desc", + "lastUsedDttm,desc", + "totalUseCnt,desc" + })) + @RequestParam(required = false) + String sort, + @Parameter(description = "페이지 번호 (0부터 시작)", example = "0") @RequestParam(defaultValue = "0") + int page, + @Parameter(description = "페이지 크기", example = "20") @RequestParam(defaultValue = "20") + int size) { + HyperParamDto.SearchReq searchReq = new HyperParamDto.SearchReq(); + searchReq.setType(type); + searchReq.setStartDate(startDate); + searchReq.setEndDate(endDate); + searchReq.setHyperVer(hyperVer); + searchReq.setSort(sort); + searchReq.setPage(page); + searchReq.setSize(size); + Page list = hyperParamService.getHyperParamList(searchReq); + + return ApiResponseDto.ok(list); + } + + @Operation(summary = "하이퍼파라미터 삭제", description = "하이퍼파라미터 삭제") + @ApiResponses( + value = { + @ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content), + @ApiResponse(responseCode = "400", description = "H1은 삭제 불가", content = @Content), + @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @DeleteMapping("/{uuid}") + public ApiResponseDto deleteHyperParam( + @Parameter(description = "하이퍼파라미터 uuid", example = "7966dd64-004a-4596-89ef-001664bc4de2") + @PathVariable + UUID uuid) { + hyperParamService.deleteHyperParam(uuid); + return ApiResponseDto.ok(null); + } +} diff --git a/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java b/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java new file mode 100644 index 0000000..6824e60 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java @@ -0,0 +1,303 @@ +package com.kamco.cd.training.hyperparam.dto; + +import com.kamco.cd.training.common.utils.enums.CodeExpose; +import com.kamco.cd.training.common.utils.enums.EnumType; +import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; +import io.swagger.v3.oas.annotations.media.Schema; +import java.time.LocalDate; +import java.time.ZonedDateTime; +import java.util.UUID; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; + +public class HyperParamDto { + + @Schema(name = "Basic", description = "하이퍼파라미터 조회") + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class Basic { + private Long id; + private UUID uuid; + private String hyperVer; + + // ------------------------- + // Important + // ------------------------- + private String backbone; + private String inputSize; + private String cropSize; + private Integer epochCnt; + private Integer batchSize; + + // ------------------------- + // Model Architecture + // ------------------------- + private Double dropPathRate; + private Integer frozenStages; + private String neckPolicy; + private String decoderChannels; + private String classWeight; + private Integer numLayers; + + // ------------------------- + // Loss & Optimization + // ------------------------- + private Double learningRate; + private Double weightDecay; + private Double layerDecayRate; + private Boolean ddpFindUnusedParams; + private Integer ignoreIndex; + + // ------------------------- + // Data + // ------------------------- + private Integer trainNumWorkers; + private Integer valNumWorkers; + private Integer testNumWorkers; + private Boolean trainShuffle; + private Boolean trainPersistent; + private Boolean valPersistent; + + // ------------------------- + // Evaluation + // ------------------------- + private String metrics; + private String saveBest; + private String saveBestRule; + private Integer valInterval; + private Integer logInterval; + private Integer visInterval; + + // ------------------------- + // Augmentation + // ------------------------- + private Double rotProb; + private Double flipProb; + private String rotDegree; + private Double exchangeProb; + private Integer brightnessDelta; + private String contrastRange; + private String saturationRange; + private Integer hueDelta; + + // ------------------------- + // Hardware + // ------------------------- + private Integer gpuCnt; + private String gpuIds; + private Integer masterPort; + + // ------------------------- + // Common + // ------------------------- + private String memo; + private String delYn; + + @JsonFormatDttm private ZonedDateTime createdDttm; + private Long createdUid; + @JsonFormatDttm private ZonedDateTime updatedDttm; + private Long updatedUid; + @JsonFormatDttm private ZonedDateTime lastUsedDttm; + + private Long m1UseCnt; + private Long m2UseCnt; + private Long m3UseCnt; + } + + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class List { + private UUID uuid; + private String hyperVer; + @JsonFormatDttm private ZonedDateTime createDttm; + @JsonFormatDttm private ZonedDateTime lastUsedDttm; + private Long m1UseCnt; + private Long m2UseCnt; + private Long m3UseCnt; + private Long totalCnt; + } + + @Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청") + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class HyperParamCreateReq { + + @Schema(description = "백본 네트워크", example = "large") + private String backbone; // backbone + + @Schema(description = "입력 이미지 크기(H,W)", example = "256,256") + private String inputSize; // input_size + + @Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256") + private String cropSize; // crop_size + + @Schema(description = "총 학습 에폭 수", example = "200") + private Integer epochCnt; // epoch_cnt + + @Schema(description = "배치 크기(Per GPU)", example = "16") + private Integer batchSize; // batch_size + + @Schema(description = "Drop Path 비율", example = "0.3") + private Double dropPathRate; // drop_path_rate + + @Schema(description = "Freeze 단계(-1:None)", example = "-1") + private Integer frozenStages; // frozen_stages + + @Schema(description = "Neck 결합 정책", example = "abs_diff") + private String neckPolicy; // neck_policy + + @Schema(description = "디코더 채널 구성", example = "512,256,128,64") + private String decoderChannels; // decoder_channels + + @Schema(description = "클래스별 가중치", example = "1,10") + private String classWeight; // class_weight + + @Schema(description = "레이어 깊이", example = "24") + private Integer numLayers; // num_layers + + @Schema(description = "학습률", example = "0.00006") + private Double learningRate; // learning_rate + + @Schema(description = "Weight Decay", example = "0.05") + private Double weightDecay; // weight_decay + + @Schema(description = "Layer Decay Rate", example = "0.9") + private Double layerDecayRate; // layer_decay_rate + + @Schema(description = "DDP unused params 탐색 여부", example = "true") + private Boolean ddpFindUnusedParams; // ddp_find_unused_params + + @Schema(description = "Loss 계산 제외 인덱스", example = "255") + private Integer ignoreIndex; // ignore_index + + @Schema(description = "Train dataloader workers", example = "16") + private Integer trainNumWorkers; // train_num_workers + + @Schema(description = "Val dataloader workers", example = "8") + private Integer valNumWorkers; // val_num_workers + + @Schema(description = "Test dataloader workers", example = "8") + private Integer testNumWorkers; // test_num_workers + + @Schema(description = "Train shuffle 여부", example = "true") + private Boolean trainShuffle; // train_shuffle + + @Schema(description = "Train persistent workers 여부", example = "true") + private Boolean trainPersistent; // train_persistent + + @Schema(description = "Val persistent workers 여부", example = "true") + private Boolean valPersistent; // val_persistent + + @Schema(description = "평가 지표 목록", example = "mFscore,mIoU") + private String metrics; // metrics + + @Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore") + private String saveBest; // save_best + + @Schema(description = "Best 모델 선정 규칙", example = "greater") + private String saveBestRule; // save_best_rule + + @Schema(description = "검증 수행 주기(Epoch)", example = "10") + private Integer valInterval; // val_interval + + @Schema(description = "로그 기록 주기(Iteration)", example = "400") + private Integer logInterval; // log_interval + + @Schema(description = "시각화 저장 주기(Epoch)", example = "1") + private Integer visInterval; // vis_interval + + @Schema(description = "회전 적용 확률", example = "0.5") + private Double rotProb; // rot_prob + + @Schema(description = "반전 적용 확률", example = "0.5") + private Double flipProb; // flip_prob + + @Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20") + private String rotDegree; // rot_degree + + @Schema(description = "채널 교환 확률", example = "0.5") + private Double exchangeProb; // exchange_prob + + @Schema(description = "밝기 변화량", example = "10") + private Integer brightnessDelta; // brightness_delta + + @Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2") + private String contrastRange; // contrast_range + + @Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2") + private String saturationRange; // saturation_range + + @Schema(description = "색조 변화량", example = "10") + private Integer hueDelta; // hue_delta + + @Schema(description = "사용 GPU 개수", example = "4") + private Integer gpuCnt; // gpu_cnt + + @Schema(description = "사용 GPU ID 목록", example = "0,1,2,3") + private String gpuIds; // gpu_ids + + @Schema(description = "분산학습 마스터 포트", example = "1122") + private Integer masterPort; // master_port + + @Schema(description = "메모", example = "하이퍼파라미터 신규등록") + private String memo; // memo + } + + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class SearchReq { + private String type; + private LocalDate startDate; + private LocalDate endDate; + private String hyperVer; + + // 페이징 파라미터 + private int page = 0; + private int size = 20; + private String sort; + + public Pageable toPageable() { + if (sort != null && !sort.isEmpty()) { + String[] sortParams = sort.split(","); + String property = sortParams[0]; + Sort.Direction direction = + sortParams.length > 1 ? Sort.Direction.fromString(sortParams[1]) : Sort.Direction.ASC; + return PageRequest.of(page, size, Sort.by(direction, property)); + } + return PageRequest.of(page, size); + } + } + + @CodeExpose + @Getter + @AllArgsConstructor + public enum HyperType implements EnumType { + CREATE_DATE("생성일"), + LAST_USED_DATE("최근 사용일"); + + private final String desc; + + @Override + public String getId() { + return name(); + } + + @Override + public String getText() { + return desc; + } + } +} diff --git a/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java b/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java new file mode 100644 index 0000000..6574a6a --- /dev/null +++ b/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java @@ -0,0 +1,59 @@ +package com.kamco.cd.training.hyperparam.service; + +import com.kamco.cd.training.hyperparam.dto.HyperParamDto; +import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List; +import com.kamco.cd.training.postgres.core.HyperParamCoreService; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.data.domain.Page; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@Transactional(readOnly = true) +@RequiredArgsConstructor +public class HyperParamService { + + private final HyperParamCoreService hyperParamCoreService; + + /** + * 하이퍼 파라미터 목록 조회 + * + * @param req + * @return 목록 + */ + public Page getHyperParamList(HyperParamDto.SearchReq req) { + return hyperParamCoreService.findByHyperVerList(req); + } + + /** + * 하이퍼파라미터 등록 + * + * @param createReq 등록 요청 + * @return 생성된 버전명 + */ + @Transactional + public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) { + return hyperParamCoreService.createHyperParam(createReq); + } + + /** + * 하이퍼파라미터 수정 + * + * @param createReq + * @return + */ + @Transactional + public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) { + return hyperParamCoreService.updateHyperParam(uuid, createReq); + } + + /** + * 하이퍼파라미터 삭제 + * + * @param uuid + */ + public void deleteHyperParam(UUID uuid) { + hyperParamCoreService.deleteHyperParam(uuid); + } +} diff --git a/src/main/java/com/kamco/cd/training/members/dto/SignInRequest.java b/src/main/java/com/kamco/cd/training/members/dto/SignInRequest.java index f4b6571..b40b86c 100644 --- a/src/main/java/com/kamco/cd/training/members/dto/SignInRequest.java +++ b/src/main/java/com/kamco/cd/training/members/dto/SignInRequest.java @@ -11,10 +11,10 @@ import lombok.ToString; @ToString(exclude = "password") public class SignInRequest { - @Schema(description = "사용자 ID", example = "1234567") + @Schema(description = "사용자 ID", example = "123456") private String username; - @Schema(description = "비밀번호", example = "Admin2!@#") + @Schema(description = "비밀번호", example = "qwe123!@#") @JsonProperty(access = JsonProperty.Access.WRITE_ONLY) private String password; } diff --git a/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java b/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java index 286e302..cc19ac3 100644 --- a/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java +++ b/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java @@ -112,26 +112,6 @@ public class ModelMngApiController { return ApiResponseDto.ok(modelTrainService.getFormConfig()); } - @Operation(summary = "하이퍼파라미터 등록", description = "Step 1 에서 파라미터를 수정하여 신규 버전으로 저장합니다") - @ApiResponses( - value = { - @ApiResponse( - responseCode = "200", - description = "등록 성공", - content = - @Content( - mediaType = "application/json", - schema = @Schema(implementation = String.class))), - @ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content), - @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) - }) - @PostMapping("/hyper-params") - public ApiResponseDto createHyperParam( - @Valid @RequestBody ModelMngDto.HyperParamCreateReq createReq) { - String newVersion = modelTrainService.createHyperParam(createReq); - return ApiResponseDto.ok(newVersion); - } - @Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다") @ApiResponses( value = { @@ -151,22 +131,6 @@ public class ModelMngApiController { return ApiResponseDto.ok(modelTrainService.getHyperParam(hyperVer)); } - @Operation(summary = "하이퍼파라미터 삭제", description = "특정 버전의 하이퍼파라미터를 삭제합니다 (H1은 삭제 불가)") - @ApiResponses( - value = { - @ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content), - @ApiResponse(responseCode = "400", description = "H1은 삭제 불가", content = @Content), - @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), - @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) - }) - @DeleteMapping("/hyper-params/{hyperVer}") - public ApiResponseDto deleteHyperParam( - @Parameter(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518") @PathVariable - String hyperVer) { - modelTrainService.deleteHyperParam(hyperVer); - return ApiResponseDto.ok(null); - } - @Operation(summary = "학습 시작", description = "모든 설정(Step 1~3)을 마치고 최종적으로 학습 프로세스를 시작합니다") @ApiResponses( value = { diff --git a/src/main/java/com/kamco/cd/training/model/dto/HyperParamDto.java b/src/main/java/com/kamco/cd/training/model/dto/HyperParamDto.java deleted file mode 100644 index 84a0869..0000000 --- a/src/main/java/com/kamco/cd/training/model/dto/HyperParamDto.java +++ /dev/null @@ -1,218 +0,0 @@ -package com.kamco.cd.training.model.dto; - -import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; -import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; -import io.swagger.v3.oas.annotations.media.Schema; -import jakarta.validation.constraints.NotBlank; -import java.time.ZonedDateTime; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; - -public class HyperParamDto { - - @Schema(name = "HyperParam Basic", description = "하이퍼파라미터 기본 정보") - @Getter - @Setter - @NoArgsConstructor - @AllArgsConstructor - public static class Basic { - private String hyperVer; - - // Important - private String backbone; - private String inputSize; - private String cropSize; - private Integer epochCnt; - private Integer batchSize; - - // Architecture - private Double dropPathRate; - private Integer frozenStages; - private String neckPolicy; - private String decoderChannels; - private String classWeight; - private Integer numLayers; - - // Optimization - private Double learningRate; - private Double weightDecay; - private Double layerDecayRate; - private Boolean ddpFindUnusedParams; - private Integer ignoreIndex; - - // Data - private Integer trainNumWorkers; - private Integer valNumWorkers; - private Integer testNumWorkers; - private Boolean trainShuffle; - private Boolean trainPersistent; - private Boolean valPersistent; - - // Evaluation - private String metrics; - private String saveBest; - private String saveBestRule; - private Integer valInterval; - private Integer logInterval; - private Integer visInterval; - - // Hardware - private Integer gpuCnt; - private String gpuIds; - private Integer masterPort; - - // Augmentation - private Double rotProb; - private Double flipProb; - private String rotDegree; - private Double exchangeProb; - private Integer brightnessDelta; - private String contrastRange; - private String saturationRange; - private Integer hueDelta; - - // Legacy (deprecated) - private Double dropoutRatio; - private Integer cnnFilterCnt; - - // Common - private String memo; - @JsonFormatDttm private ZonedDateTime createdDttm; - - public Basic(ModelHyperParamEntity entity) { - this.hyperVer = entity.getHyperVer(); - - // Important - this.backbone = entity.getBackbone(); - this.inputSize = entity.getInputSize(); - this.cropSize = entity.getCropSize(); - this.epochCnt = entity.getEpochCnt(); - this.batchSize = entity.getBatchSize(); - - // Architecture - this.dropPathRate = entity.getDropPathRate(); - this.frozenStages = entity.getFrozenStages(); - this.neckPolicy = entity.getNeckPolicy(); - this.decoderChannels = entity.getDecoderChannels(); - this.classWeight = entity.getClassWeight(); - this.numLayers = entity.getNumLayers(); - - // Optimization - this.learningRate = entity.getLearningRate(); - this.weightDecay = entity.getWeightDecay(); - this.layerDecayRate = entity.getLayerDecayRate(); - this.ddpFindUnusedParams = entity.getDdpFindUnusedParams(); - this.ignoreIndex = entity.getIgnoreIndex(); - - // Data - this.trainNumWorkers = entity.getTrainNumWorkers(); - this.valNumWorkers = entity.getValNumWorkers(); - this.testNumWorkers = entity.getTestNumWorkers(); - this.trainShuffle = entity.getTrainShuffle(); - this.trainPersistent = entity.getTrainPersistent(); - this.valPersistent = entity.getValPersistent(); - - // Evaluation - this.metrics = entity.getMetrics(); - this.saveBest = entity.getSaveBest(); - this.saveBestRule = entity.getSaveBestRule(); - this.valInterval = entity.getValInterval(); - this.logInterval = entity.getLogInterval(); - this.visInterval = entity.getVisInterval(); - - // Hardware - this.gpuCnt = entity.getGpuCnt(); - this.gpuIds = entity.getGpuIds(); - this.masterPort = entity.getMasterPort(); - - // Augmentation - this.rotProb = entity.getRotProb(); - this.flipProb = entity.getFlipProb(); - this.rotDegree = entity.getRotDegree(); - this.exchangeProb = entity.getExchangeProb(); - this.brightnessDelta = entity.getBrightnessDelta(); - this.contrastRange = entity.getContrastRange(); - this.saturationRange = entity.getSaturationRange(); - this.hueDelta = entity.getHueDelta(); - - // Legacy - this.cnnFilterCnt = entity.getCnnFilterCnt(); - - // Common - this.memo = entity.getMemo(); - this.createdDttm = entity.getCreatedDttm(); - } - } - - @Schema(name = "HyperParam AddReq", description = "하이퍼파라미터 등록 요청") - @Getter - @Setter - @NoArgsConstructor - @AllArgsConstructor - public static class AddReq { - @NotBlank(message = "버전명은 필수입니다") - private String hyperVer; - - // Important - private String backbone; - private String inputSize; - private String cropSize; - private Integer epochCnt; - private Integer batchSize; - - // Architecture - private Double dropPathRate; - private Integer frozenStages; - private String neckPolicy; - private String decoderChannels; - private String classWeight; - private Integer numLayers; - - // Optimization - private Double learningRate; - private Double weightDecay; - private Double layerDecayRate; - private Boolean ddpFindUnusedParams; - private Integer ignoreIndex; - - // Data - private Integer trainNumWorkers; - private Integer valNumWorkers; - private Integer testNumWorkers; - private Boolean trainShuffle; - private Boolean trainPersistent; - private Boolean valPersistent; - - // Evaluation - private String metrics; - private String saveBest; - private String saveBestRule; - private Integer valInterval; - private Integer logInterval; - private Integer visInterval; - - // Hardware - private Integer gpuCnt; - private String gpuIds; - private Integer masterPort; - - // Augmentation - private Double rotProb; - private Double flipProb; - private String rotDegree; - private Double exchangeProb; - private Integer brightnessDelta; - private String contrastRange; - private String saturationRange; - private Integer hueDelta; - - // Legacy (deprecated) - private Double dropoutRatio; - private Integer cnnFilterCnt; - - // Common - private String memo; - } -} diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelMngDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelMngDto.java index 41c8558..700ac8f 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelMngDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelMngDto.java @@ -5,7 +5,6 @@ import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotNull; -import java.time.LocalDate; import java.time.ZonedDateTime; import java.util.List; import java.util.Map; @@ -303,155 +302,6 @@ public class ModelMngDto { @JsonFormatDttm private ZonedDateTime createdDttm; } - @Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청") - @Getter - @Setter - @NoArgsConstructor - @AllArgsConstructor - public class HyperParamCreateReq { - - @Schema(description = "새로운 하이파라미터 버전", example = "") - private String newHyperVer; - - @Schema(description = "불러온 하이파라미터 버전", example = "") - private String baseHyperVer; - - @Schema(description = "하이퍼파라미터 PK", example = "1") - private Long hyperParamId; // hyper_param_id - - @Schema(description = "하이퍼파라미터 UUID", example = "3fa85f64-5717-4562-b3fc-2c963f66afa6") - private UUID uuid; // uuid (또는 hyper_param_uuid 컬럼이면 이름 맞춰 주세요) - - @Schema(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518") - private String hyperVer; // hyper_ver - - @Schema(description = "백본 네트워크", example = "large") - private String backbone; // backbone - - @Schema(description = "입력 이미지 크기(H,W)", example = "256,256") - private String inputSize; // input_size - - @Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256") - private String cropSize; // crop_size - - @Schema(description = "총 학습 에폭 수", example = "200") - private Integer epochCnt; // epoch_cnt - - @Schema(description = "배치 크기(Per GPU)", example = "16") - private Integer batchSize; // batch_size - - @Schema(description = "Drop Path 비율", example = "0.3") - private Double dropPathRate; // drop_path_rate - - @Schema(description = "Freeze 단계(-1:None)", example = "-1") - private Integer frozenStages; // frozen_stages - - @Schema(description = "Neck 결합 정책", example = "abs_diff") - private String neckPolicy; // neck_policy - - @Schema(description = "디코더 채널 구성", example = "512,256,128,64") - private String decoderChannels; // decoder_channels - - @Schema(description = "클래스별 가중치", example = "1,10") - private String classWeight; // class_weight - - @Schema(description = "레이어 깊이", example = "24") - private Integer numLayers; // num_layers - - @Schema(description = "학습률", example = "0.00006") - private Double learningRate; // learning_rate - - @Schema(description = "Weight Decay", example = "0.05") - private Double weightDecay; // weight_decay - - @Schema(description = "Layer Decay Rate", example = "0.9") - private Double layerDecayRate; // layer_decay_rate - - @Schema(description = "DDP unused params 탐색 여부", example = "true") - private Boolean ddpFindUnusedParams; // ddp_find_unused_params - - @Schema(description = "Loss 계산 제외 인덱스", example = "255") - private Integer ignoreIndex; // ignore_index - - @Schema(description = "Train dataloader workers", example = "16") - private Integer trainNumWorkers; // train_num_workers - - @Schema(description = "Val dataloader workers", example = "8") - private Integer valNumWorkers; // val_num_workers - - @Schema(description = "Test dataloader workers", example = "8") - private Integer testNumWorkers; // test_num_workers - - @Schema(description = "Train shuffle 여부", example = "true") - private Boolean trainShuffle; // train_shuffle - - @Schema(description = "Train persistent workers 여부", example = "true") - private Boolean trainPersistent; // train_persistent - - @Schema(description = "Val persistent workers 여부", example = "true") - private Boolean valPersistent; // val_persistent - - @Schema(description = "평가 지표 목록", example = "mFscore,mIoU") - private String metrics; // metrics - - @Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore") - private String saveBest; // save_best - - @Schema(description = "Best 모델 선정 규칙", example = "greater") - private String saveBestRule; // save_best_rule - - @Schema(description = "검증 수행 주기(Epoch)", example = "10") - private Integer valInterval; // val_interval - - @Schema(description = "로그 기록 주기(Iteration)", example = "400") - private Integer logInterval; // log_interval - - @Schema(description = "시각화 저장 주기(Epoch)", example = "1") - private Integer visInterval; // vis_interval - - @Schema(description = "회전 적용 확률", example = "0.5") - private Double rotProb; // rot_prob - - @Schema(description = "반전 적용 확률", example = "0.5") - private Double flipProb; // flip_prob - - @Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20") - private String rotDegree; // rot_degree - - @Schema(description = "채널 교환 확률", example = "0.5") - private Double exchangeProb; // exchange_prob - - @Schema(description = "밝기 변화량", example = "10") - private Integer brightnessDelta; // brightness_delta - - @Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2") - private String contrastRange; // contrast_range - - @Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2") - private String saturationRange; // saturation_range - - @Schema(description = "색조 변화량", example = "10") - private Integer hueDelta; // hue_delta - - @Schema(description = "사용 GPU 개수", example = "4") - private Integer gpuCnt; // gpu_cnt - - @Schema(description = "사용 GPU ID 목록", example = "0,1,2,3") - private String gpuIds; // gpu_ids - - @Schema(description = "분산학습 마스터 포트", example = "1122") - private Integer masterPort; // master_port - - @Schema(description = "메모", example = "하이퍼파라미터 신규등록") - private String memo; // memo - - @Schema(description = "삭제 여부(Y/N)", example = "N") - private String delYn; // del_yn - - @Schema(description = "생성 일시") - private LocalDate createdDttm; // created_dttm - } - @Schema(name = "TrainStartReq", description = "학습 시작 요청") @Getter @Setter diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainService.java index b12c9fb..8c406ae 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainService.java @@ -52,7 +52,8 @@ public class ModelTrainService { } // 3. 하이퍼파라미터 목록 - List hyperParams = hyperParamCoreService.findAllActiveHyperParams(); + List hyperParams = + null; // hyperParamCoreService.findAllActiveHyperParams(); // 4. 데이터셋 목록 List datasets = datasetCoreService.findAllActiveDatasetsForTraining(); @@ -65,26 +66,6 @@ public class ModelTrainService { .build(); } - /** - * 하이퍼파라미터 등록 - * - * @param createReq 등록 요청 - * @return 생성된 버전명 - */ - @Transactional - public String createHyperParam(ModelMngDto.HyperParamCreateReq createReq) { - // 신규 버전 추가 시 baseHyperVer가 없으면 H1으로 설정 - if (createReq.getBaseHyperVer() == null || createReq.getBaseHyperVer().isEmpty()) { - String firstVersion = hyperParamCoreService.getFirstHyperParamVersion(); - createReq.setBaseHyperVer(firstVersion); - log.info("baseHyperVer가 없어 첫 번째 버전으로 설정: {}", firstVersion); - } - - String newVersion = hyperParamCoreService.createHyperParam(createReq); - log.info("하이퍼파라미터 등록 완료: {}", newVersion); - return newVersion; - } - /** * 하이퍼파라미터 단건 조회 * @@ -245,10 +226,10 @@ public class ModelTrainService { .lastEpoch(entity.getLastCheckpointEpoch()) .totalEpoch(entity.getEpochCnt()) .checkpointPath(entity.getCheckpointPath()) - .failedAt( - entity.getStopDttm() != null - ? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault()) - : null) + // .failedAt( + // entity.getStopDttm() != null + // ? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault()) + // : null) .build(); } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java index 9d8dfeb..d03929b 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java @@ -1,31 +1,24 @@ package com.kamco.cd.training.postgres.core; import com.kamco.cd.training.common.exception.BadRequestException; -import com.kamco.cd.training.common.exception.NotFoundException; +import com.kamco.cd.training.common.exception.CustomApiException; +import com.kamco.cd.training.common.utils.UserUtil; +import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.model.dto.ModelMngDto; import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; -import com.kamco.cd.training.postgres.repository.model.ModelHyperParamRepository; +import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository; import java.time.ZonedDateTime; -import java.util.List; +import java.util.UUID; import lombok.RequiredArgsConstructor; +import org.springframework.data.domain.Page; +import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; @Service @RequiredArgsConstructor public class HyperParamCoreService { - private final ModelHyperParamRepository hyperParamRepository; - - /** - * 하이퍼파라미터 전체 조회 (삭제되지 않은 것만) - * - * @return 하이퍼파라미터 목록 - */ - public List findAllActiveHyperParams() { - List entities = - hyperParamRepository.findByDelYnOrderByCreatedDttmDesc("N"); - - return entities.stream().map(this::mapToHyperParamInfo).toList(); - } + private final HyperParamRepository hyperParamRepository; + private final UserUtil userUtil; private ModelMngDto.HyperParamInfo mapToHyperParamInfo(ModelHyperParamEntity entity) { return ModelMngDto.HyperParamInfo.builder() @@ -76,8 +69,6 @@ public class HyperParamCoreService { .contrastRange(entity.getContrastRange()) .saturationRange(entity.getSaturationRange()) .hueDelta(entity.getHueDelta()) - // Legacy - .cnnFilterCnt(entity.getCnnFilterCnt()) // Common .memo(entity.getMemo()) .createdDttm(entity.getCreatedDttm()) @@ -90,168 +81,104 @@ public class HyperParamCoreService { * @param createReq 등록 요청 * @return 등록된 버전명 */ - public String createHyperParam(ModelMngDto.HyperParamCreateReq createReq) { - // 중복 체크 - if (hyperParamRepository.existsById(createReq.getNewHyperVer())) { - throw new BadRequestException("이미 존재하는 버전입니다: " + createReq.getNewHyperVer()); - } + public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) { + String firstVersion = getFirstHyperParamVersion(); - // 기준 버전 조회 - ModelHyperParamEntity baseEntity = - hyperParamRepository - .findById(createReq.getBaseHyperVer()) - .orElseThrow( - () -> new NotFoundException("기준 버전을 찾을 수 없습니다: " + createReq.getBaseHyperVer())); - - // 신규 엔티티 생성 (기준 값 복사 후 변경된 값만 적용) ModelHyperParamEntity entity = new ModelHyperParamEntity(); - entity.setHyperVer(createReq.getNewHyperVer()); + entity.setHyperVer(firstVersion); + applyHyperParam(entity, createReq); + + // user + entity.setCreatedUid(userUtil.getId()); + + ModelHyperParamEntity resultEntity = hyperParamRepository.save(entity); + return resultEntity.getHyperVer(); + } + + /** + * 하이퍼파라미터 수정 + * + * @param uuid uuid + * @param createReq 등록 요청 + * @return ver + */ + public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) { + ModelHyperParamEntity entity = + hyperParamRepository + .findHyperParamByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + + applyHyperParam(entity, createReq); + + // user + entity.setUpdatedUid(userUtil.getId()); + entity.setUpdatedDttm(ZonedDateTime.now()); + + return entity.getHyperVer(); + } + + private void applyHyperParam( + ModelHyperParamEntity entity, HyperParamDto.HyperParamCreateReq createReq) { // Important - entity.setBackbone( - createReq.getBackbone() != null ? createReq.getBackbone() : baseEntity.getBackbone()); - entity.setInputSize( - createReq.getInputSize() != null ? createReq.getInputSize() : baseEntity.getInputSize()); - entity.setCropSize( - createReq.getCropSize() != null ? createReq.getCropSize() : baseEntity.getCropSize()); - entity.setEpochCnt( - createReq.getEpochCnt() != null ? createReq.getEpochCnt() : baseEntity.getEpochCnt()); - entity.setBatchSize( - createReq.getBatchSize() != null ? createReq.getBatchSize() : baseEntity.getBatchSize()); - - // Architecture - entity.setDropPathRate( - createReq.getDropPathRate() != null - ? createReq.getDropPathRate() - : baseEntity.getDropPathRate()); - entity.setFrozenStages( - createReq.getFrozenStages() != null - ? createReq.getFrozenStages() - : baseEntity.getFrozenStages()); - entity.setNeckPolicy( - createReq.getNeckPolicy() != null ? createReq.getNeckPolicy() : baseEntity.getNeckPolicy()); - entity.setDecoderChannels( - createReq.getDecoderChannels() != null - ? createReq.getDecoderChannels() - : baseEntity.getDecoderChannels()); - entity.setClassWeight( - createReq.getClassWeight() != null - ? createReq.getClassWeight() - : baseEntity.getClassWeight()); - entity.setNumLayers( - createReq.getNumLayers() != null ? createReq.getNumLayers() : baseEntity.getNumLayers()); - - // Optimization - entity.setLearningRate( - createReq.getLearningRate() != null - ? createReq.getLearningRate() - : baseEntity.getLearningRate()); - entity.setWeightDecay( - createReq.getWeightDecay() != null - ? createReq.getWeightDecay() - : baseEntity.getWeightDecay()); - entity.setLayerDecayRate( - createReq.getLayerDecayRate() != null - ? createReq.getLayerDecayRate() - : baseEntity.getLayerDecayRate()); - entity.setDdpFindUnusedParams( - createReq.getDdpFindUnusedParams() != null - ? createReq.getDdpFindUnusedParams() - : baseEntity.getDdpFindUnusedParams()); - entity.setIgnoreIndex( - createReq.getIgnoreIndex() != null - ? createReq.getIgnoreIndex() - : baseEntity.getIgnoreIndex()); + entity.setBackbone(createReq.getBackbone()); + entity.setInputSize(createReq.getInputSize()); + entity.setCropSize(createReq.getCropSize()); + entity.setBatchSize(createReq.getBatchSize()); // Data - entity.setTrainNumWorkers( - createReq.getTrainNumWorkers() != null - ? createReq.getTrainNumWorkers() - : baseEntity.getTrainNumWorkers()); - entity.setValNumWorkers( - createReq.getValNumWorkers() != null - ? createReq.getValNumWorkers() - : baseEntity.getValNumWorkers()); - entity.setTestNumWorkers( - createReq.getTestNumWorkers() != null - ? createReq.getTestNumWorkers() - : baseEntity.getTestNumWorkers()); - entity.setTrainShuffle( - createReq.getTrainShuffle() != null - ? createReq.getTrainShuffle() - : baseEntity.getTrainShuffle()); - entity.setTrainPersistent( - createReq.getTrainPersistent() != null - ? createReq.getTrainPersistent() - : baseEntity.getTrainPersistent()); - entity.setValPersistent( - createReq.getValPersistent() != null - ? createReq.getValPersistent() - : baseEntity.getValPersistent()); + entity.setTrainNumWorkers(createReq.getTrainNumWorkers()); + entity.setValNumWorkers(createReq.getValNumWorkers()); + entity.setTestNumWorkers(createReq.getTestNumWorkers()); + entity.setTrainShuffle(createReq.getTrainShuffle()); + entity.setTrainPersistent(createReq.getTrainPersistent()); + entity.setValPersistent(createReq.getValPersistent()); + + // Model Architecture + entity.setDropPathRate(createReq.getDropPathRate()); + entity.setFrozenStages(createReq.getFrozenStages()); + entity.setNeckPolicy(createReq.getNeckPolicy()); + entity.setClassWeight(createReq.getClassWeight()); + entity.setDecoderChannels(createReq.getDecoderChannels()); + + // Loss & Optimization + entity.setLearningRate(createReq.getLearningRate()); + entity.setWeightDecay(createReq.getWeightDecay()); + entity.setLayerDecayRate(createReq.getLayerDecayRate()); + entity.setDdpFindUnusedParams(createReq.getDdpFindUnusedParams()); + entity.setIgnoreIndex(createReq.getIgnoreIndex()); + entity.setNumLayers(createReq.getNumLayers()); // Evaluation - entity.setMetrics( - createReq.getMetrics() != null ? createReq.getMetrics() : baseEntity.getMetrics()); - entity.setSaveBest( - createReq.getSaveBest() != null ? createReq.getSaveBest() : baseEntity.getSaveBest()); - entity.setSaveBestRule( - createReq.getSaveBestRule() != null - ? createReq.getSaveBestRule() - : baseEntity.getSaveBestRule()); - entity.setValInterval( - createReq.getValInterval() != null - ? createReq.getValInterval() - : baseEntity.getValInterval()); - entity.setLogInterval( - createReq.getLogInterval() != null - ? createReq.getLogInterval() - : baseEntity.getLogInterval()); - entity.setVisInterval( - createReq.getVisInterval() != null - ? createReq.getVisInterval() - : baseEntity.getVisInterval()); + entity.setMetrics(createReq.getMetrics()); + entity.setSaveBest(createReq.getSaveBest()); + entity.setSaveBestRule(createReq.getSaveBestRule()); + entity.setValInterval(createReq.getValInterval()); + entity.setLogInterval(createReq.getLogInterval()); + entity.setVisInterval(createReq.getVisInterval()); - // Hardware - entity.setGpuCnt( - createReq.getGpuCnt() != null ? createReq.getGpuCnt() : baseEntity.getGpuCnt()); - entity.setGpuIds( - createReq.getGpuIds() != null ? createReq.getGpuIds() : baseEntity.getGpuIds()); - entity.setMasterPort( - createReq.getMasterPort() != null ? createReq.getMasterPort() : baseEntity.getMasterPort()); - - // Augmentation - entity.setRotProb( - createReq.getRotProb() != null ? createReq.getRotProb() : baseEntity.getRotProb()); - entity.setFlipProb( - createReq.getFlipProb() != null ? createReq.getFlipProb() : baseEntity.getFlipProb()); - entity.setRotDegree( - createReq.getRotDegree() != null ? createReq.getRotDegree() : baseEntity.getRotDegree()); - entity.setExchangeProb( - createReq.getExchangeProb() != null - ? createReq.getExchangeProb() - : baseEntity.getExchangeProb()); - entity.setBrightnessDelta( - createReq.getBrightnessDelta() != null - ? createReq.getBrightnessDelta() - : baseEntity.getBrightnessDelta()); - entity.setContrastRange( - createReq.getContrastRange() != null - ? createReq.getContrastRange() - : baseEntity.getContrastRange()); - entity.setSaturationRange( - createReq.getSaturationRange() != null - ? createReq.getSaturationRange() - : baseEntity.getSaturationRange()); - entity.setHueDelta( - createReq.getHueDelta() != null ? createReq.getHueDelta() : baseEntity.getHueDelta()); - - // Common + // memo entity.setMemo(createReq.getMemo()); - entity.setDelYn("N"); - entity.setCreatedDttm(ZonedDateTime.now()); + } - ModelHyperParamEntity saved = hyperParamRepository.save(entity); - return saved.getHyperVer(); + /** + * 하이퍼파라미터 삭제 + * + * @param uuid + */ + public void deleteHyperParam(UUID uuid) { + ModelHyperParamEntity entity = + hyperParamRepository + .findHyperParamByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + + if (entity.getHyperVer().equals("HPs_0001")) { + throw new CustomApiException("CONFLICT", HttpStatus.CONFLICT, "HPs_0001 버전은 삭제할수 없습니다."); + } + + entity.setDelYn("Y"); + entity.setUpdatedUid(userUtil.getId()); + entity.setUpdatedDttm(ZonedDateTime.now()); } /** @@ -261,27 +188,16 @@ public class HyperParamCoreService { * @return 하이퍼파라미터 정보 */ public ModelMngDto.HyperParamInfo findByHyperVer(String hyperVer) { - ModelHyperParamEntity entity = - hyperParamRepository - .findById(hyperVer) - .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer)); + // ModelHyperParamEntity entity = + // hyperParamRepository + // .findById(hyperVer) + // .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer)); + // + // if ("Y".equals(entity.getDelYn())) { + // throw new NotFoundException("삭제된 하이퍼파라미터입니다: " + hyperVer); + // } - if ("Y".equals(entity.getDelYn())) { - throw new NotFoundException("삭제된 하이퍼파라미터입니다: " + hyperVer); - } - - return mapToHyperParamInfo(entity); - } - - /** - * 하이퍼파라미터 수정 (기존 버전은 수정 불가) - * - * @param hyperVer 하이퍼파라미터 버전 - * @param updateReq 수정 요청 - */ - public void updateHyperParam(String hyperVer, ModelMngDto.HyperParamCreateReq updateReq) { - // 기존 버전은 수정 불가 - throw new BadRequestException("기존 버전은 수정할 수 없습니다. 신규 버전을 생성해주세요."); + return mapToHyperParamInfo(null); } /** @@ -294,33 +210,47 @@ public class HyperParamCoreService { if ("H1".equals(hyperVer)) { throw new BadRequestException("H1은 디폴트 하이퍼파라미터 버전이므로 삭제할 수 없습니다."); } - - ModelHyperParamEntity entity = - hyperParamRepository - .findById(hyperVer) - .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer)); - - if ("Y".equals(entity.getDelYn())) { - throw new BadRequestException("이미 삭제된 하이퍼파라미터입니다: " + hyperVer); - } - - // 논리 삭제 처리 - entity.setDelYn("Y"); - hyperParamRepository.save(entity); + // + // ModelHyperParamEntity entity = + // hyperParamRepository + // .findById(hyperVer) + // .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer)); + // + // if ("Y".equals(entity.getDelYn())) { + // throw new BadRequestException("이미 삭제된 하이퍼파라미터입니다: " + hyperVer); + // } + // + // // 논리 삭제 처리 + // entity.setDelYn("Y"); + // hyperParamRepository.save(entity); } /** - * 첫 번째 하이퍼파라미터 버전 조회 (H1 확인용) + * 하이퍼파라미터 목록 조회 * - * @return 첫 번째 하이퍼파라미터 버전 + * @param req + * @return + */ + public Page findByHyperVerList(HyperParamDto.SearchReq req) { + return hyperParamRepository.findByHyperVerList(req); + } + + /** + * 하이퍼파라미터 버전 조회 + * + * @return ver */ public String getFirstHyperParamVersion() { - List entities = - hyperParamRepository.findByDelYnOrderByCreatedDttmDesc("N"); - if (entities.isEmpty()) { - throw new NotFoundException("하이퍼파라미터가 존재하지 않습니다."); - } - // 가장 오래된 것이 H1이므로 리스트의 마지막 요소 반환 - return entities.get(entities.size() - 1).getHyperVer(); + return hyperParamRepository + .findHyperParamVer() + .map(ModelHyperParamEntity::getHyperVer) + .map(this::increase) + .orElse("HPs_0001"); + } + + private String increase(String hyperVer) { + String prefix = "HPs_"; + int num = Integer.parseInt(hyperVer.substring(prefix.length())); + return prefix + String.format("%04d", num + 1); } } diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java index 95c0a8a..c32cfec 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java @@ -1,5 +1,6 @@ package com.kamco.cd.training.postgres.entity; +import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import jakarta.persistence.*; import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; @@ -28,7 +29,7 @@ public class ModelHyperParamEntity { @NotNull @UuidGenerator @Column(name = "uuid", nullable = false, updatable = false) - private UUID uuid; + private UUID uuid = UUID.randomUUID(); @Size(max = 50) @NotNull @@ -259,8 +260,7 @@ public class ModelHyperParamEntity { // ------------------------- /** Default: 4 */ - @NotNull - @Column(name = "gpu_cnt", nullable = false) + @Column(name = "gpu_cnt") private Integer gpuCnt; /** Default: 0,1,2,3 */ @@ -289,9 +289,86 @@ public class ModelHyperParamEntity { @Column(name = "created_dttm", nullable = false) private ZonedDateTime createdDttm = ZonedDateTime.now(); - @Column(name = "cnn_filter_cnt") - private Integer cnnFilterCnt; + @NotNull + @Column(name = "created_uid", nullable = false) + private Long createdUid; + + @ColumnDefault("CURRENT_TIMESTAMP") + @Column(name = "updated_dttm") + private ZonedDateTime updatedDttm; + + @Column(name = "updated_uid") + private Long updatedUid; + + @ColumnDefault("CURRENT_TIMESTAMP") + @Column(name = "last_used_dttm") + private ZonedDateTime lastUsedDttm; + + @Column(name = "m1_use_cnt") + private Long m1UseCnt = 0L; + + @Column(name = "m2_use_cnt") + private Long m2UseCnt = 0L; + + @Column(name = "m3_use_cnt") + private Long m3UseCnt = 0L; @OneToMany(mappedBy = "hyperParams", fetch = FetchType.LAZY) private Set trainMasters = new LinkedHashSet<>(); + + public HyperParamDto.Basic toDto() { + return new HyperParamDto.Basic( + this.id, + this.uuid, + this.hyperVer, + this.backbone, + this.inputSize, + this.cropSize, + this.epochCnt, + this.batchSize, + this.dropPathRate, + this.frozenStages, + this.neckPolicy, + this.decoderChannels, + this.classWeight, + this.numLayers, + this.learningRate, + this.weightDecay, + this.layerDecayRate, + this.ddpFindUnusedParams, + this.ignoreIndex, + this.trainNumWorkers, + this.valNumWorkers, + this.testNumWorkers, + this.trainShuffle, + this.trainPersistent, + this.valPersistent, + this.metrics, + this.saveBest, + this.saveBestRule, + this.valInterval, + this.logInterval, + this.visInterval, + this.rotProb, + this.flipProb, + this.rotDegree, + this.exchangeProb, + this.brightnessDelta, + this.contrastRange, + this.saturationRange, + this.hueDelta, + this.gpuCnt, + this.gpuIds, + this.masterPort, + this.memo, + this.delYn, + this.createdDttm, + this.createdUid, + this.updatedDttm, + this.updatedUid, + this.lastUsedDttm, + this.m1UseCnt, + this.m2UseCnt, + this.m3UseCnt); + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainMasterEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainMasterEntity.java index 43d8ec8..99a9fb9 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainMasterEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainMasterEntity.java @@ -12,7 +12,6 @@ import jakarta.persistence.ManyToOne; import jakarta.persistence.Table; import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; -import java.time.Instant; import java.time.ZonedDateTime; import java.util.UUID; import lombok.Getter; @@ -112,7 +111,7 @@ public class ModelTrainMasterEntity { private Integer progressRate; @Column(name = "stop_dttm") - private Instant stopDttm; + private ZonedDateTime stopDttm; @Column(name = "confirmed_best_epoch") private Integer confirmedBestEpoch; @@ -125,7 +124,7 @@ public class ModelTrainMasterEntity { private String errorMsg; @Column(name = "step2_start_dttm") - private Instant step2StartDttm; + private ZonedDateTime step2StartDttm; @Size(max = 1000) @Column(name = "train_log_path", length = 1000) diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepository.java b/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepository.java new file mode 100644 index 0000000..529bc81 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepository.java @@ -0,0 +1,9 @@ +package com.kamco.cd.training.postgres.repository.hyperparam; + +import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; +import org.springframework.data.jpa.repository.JpaRepository; +import org.springframework.stereotype.Repository; + +@Repository +public interface HyperParamRepository + extends JpaRepository, HyperParamRepositoryCustom {} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryCustom.java new file mode 100644 index 0000000..7cdf240 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryCustom.java @@ -0,0 +1,21 @@ +package com.kamco.cd.training.postgres.repository.hyperparam; + +import com.kamco.cd.training.hyperparam.dto.HyperParamDto; +import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; +import java.util.Optional; +import java.util.UUID; +import org.springframework.data.domain.Page; + +public interface HyperParamRepositoryCustom { + + /** + * 마지막 버전 조회 + * + * @return + */ + Optional findHyperParamVer(); + + Optional findHyperParamByUuid(UUID uuid); + + Page findByHyperVerList(HyperParamDto.SearchReq req); +} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryImpl.java new file mode 100644 index 0000000..d0aa583 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/repository/hyperparam/HyperParamRepositoryImpl.java @@ -0,0 +1,148 @@ +package com.kamco.cd.training.postgres.repository.hyperparam; + +import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity; + +import com.kamco.cd.training.hyperparam.dto.HyperParamDto; +import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType; +import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; +import com.querydsl.core.BooleanBuilder; +import com.querydsl.core.types.Projections; +import com.querydsl.core.types.dsl.NumberExpression; +import com.querydsl.jpa.impl.JPAQuery; +import com.querydsl.jpa.impl.JPAQueryFactory; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageImpl; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; +import org.springframework.stereotype.Repository; + +@Repository +@RequiredArgsConstructor +public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom { + + private final JPAQueryFactory queryFactory; + + @Override + public Optional findHyperParamVer() { + + return Optional.ofNullable( + queryFactory + .select(modelHyperParamEntity) + .from(modelHyperParamEntity) + .where(modelHyperParamEntity.delYn.eq("N")) + .orderBy(modelHyperParamEntity.hyperVer.desc()) + .limit(1) + .fetchOne()); + } + + @Override + public Optional findHyperParamByUuid(UUID uuid) { + return Optional.ofNullable( + queryFactory + .select(modelHyperParamEntity) + .from(modelHyperParamEntity) + .where(modelHyperParamEntity.delYn.eq("N").and(modelHyperParamEntity.uuid.eq(uuid))) + .fetchOne()); + } + + @Override + public Page findByHyperVerList(HyperParamDto.SearchReq req) { + Pageable pageable = req.toPageable(); + + BooleanBuilder builder = new BooleanBuilder(); + builder.and(modelHyperParamEntity.delYn.eq("N")); + + if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) { + // 버전 + builder.and(modelHyperParamEntity.hyperVer.contains(req.getHyperVer())); + } + + if (req.getStartDate() != null && req.getEndDate() != null) { + + ZoneId zoneId = ZoneId.systemDefault(); + + ZonedDateTime start = req.getStartDate().atStartOfDay(zoneId); + + ZonedDateTime end = req.getEndDate().atTime(23, 59, 59).atZone(zoneId); + + if (HyperType.CREATE_DATE.getId().equals(req.getType())) { + // 생성일 + builder.and(modelHyperParamEntity.createdDttm.between(start, end)); + } else if (HyperType.LAST_USED_DATE.getId().equals(req.getType())) { + // 최종 사용일 + builder.and(modelHyperParamEntity.lastUsedDttm.between(start, end)); + } + } + + NumberExpression totalUseCnt = + modelHyperParamEntity + .m1UseCnt + .coalesce(0L) + .add(modelHyperParamEntity.m2UseCnt.coalesce(0L)) + .add(modelHyperParamEntity.m3UseCnt.coalesce(0L)); + + JPAQuery query = + queryFactory + .select( + Projections.constructor( + HyperParamDto.List.class, + modelHyperParamEntity.uuid, + modelHyperParamEntity.hyperVer, + modelHyperParamEntity.createdDttm, + modelHyperParamEntity.lastUsedDttm, + modelHyperParamEntity.m1UseCnt, + modelHyperParamEntity.m2UseCnt, + modelHyperParamEntity.m3UseCnt, + totalUseCnt.as("totalUseCnt"))) + .from(modelHyperParamEntity) + .where(builder); + + Sort.Order sortOrder = pageable.getSort().stream().findFirst().orElse(null); + + if (sortOrder == null) { + // 기본값 + query.orderBy(modelHyperParamEntity.createdDttm.desc()); + } else { + String property = sortOrder.getProperty(); + boolean asc = sortOrder.isAscending(); + + switch (property) { + case "createdDttm" -> + query.orderBy( + asc + ? modelHyperParamEntity.createdDttm.asc() + : modelHyperParamEntity.createdDttm.desc()); + + case "lastUsedDttm" -> + query.orderBy( + asc + ? modelHyperParamEntity.lastUsedDttm.asc() + : modelHyperParamEntity.lastUsedDttm.desc()); + + case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc()); + + default -> query.orderBy(modelHyperParamEntity.createdDttm.desc()); + } + } + + List content = + query.offset(pageable.getOffset()).limit(pageable.getPageSize()).fetch(); + + Long total = + queryFactory + .select(modelHyperParamEntity.count()) + .from(modelHyperParamEntity) + .where(builder) + .fetchOne(); + + long totalCount = (total != null) ? total : 0L; + + return new PageImpl<>(content, pageable, totalCount); + } +} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelHyperParamRepository.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelHyperParamRepository.java deleted file mode 100644 index 32fc47b..0000000 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelHyperParamRepository.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.kamco.cd.training.postgres.repository.model; - -import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; -import java.util.List; -import org.springframework.data.jpa.repository.JpaRepository; -import org.springframework.stereotype.Repository; - -@Repository -public interface ModelHyperParamRepository extends JpaRepository { - - List findByDelYnOrderByCreatedDttmDesc(String delYn); - - List findByDelYnOrderByCreatedDttmAsc(String delYn); -}