From 335e9d33d6e1a9b66a9161ef00c47417e2cefaf8 Mon Sep 17 00:00:00 2001 From: teddy Date: Tue, 3 Feb 2026 15:05:39 +0900 Subject: [PATCH] =?UTF-8?q?=ED=95=98=EC=9D=B4=ED=8D=BC=ED=8C=8C=EB=9D=BC?= =?UTF-8?q?=EB=AF=B8=ED=84=B0=20=EA=B8=B0=EB=8A=A5=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../hyperparam/HyperParamApiController.java | 44 ++++- .../hyperparam/dto/HyperParamDto.java | 168 +++++++++--------- .../hyperparam/service/HyperParamService.java | 15 ++ .../training/model/ModelMngApiController.java | 19 -- .../postgres/core/HyperParamCoreService.java | 23 +++ .../entity/ModelHyperParamEntity.java | 72 +++++--- 6 files changed, 213 insertions(+), 128 deletions(-) diff --git a/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java b/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java index 2d02307..760cfae 100644 --- a/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java +++ b/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java @@ -83,7 +83,7 @@ public class HyperParamApiController { content = @Content( mediaType = "application/json", - schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))), + schema = @Schema(implementation = Page.class))), @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @@ -133,9 +133,8 @@ public class HyperParamApiController { @ApiResponses( value = { @ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content), - @ApiResponse(responseCode = "400", description = "H1은 삭제 불가", content = @Content), + @ApiResponse(responseCode = "409", description = "HPs_0001 삭제 불가", content = @Content), @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), - @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @DeleteMapping("/{uuid}") public ApiResponseDto deleteHyperParam( @@ -145,4 +144,43 @@ public class HyperParamApiController { hyperParamService.deleteHyperParam(uuid); return ApiResponseDto.ok(null); } + + @Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "조회 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))), + @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @GetMapping("/{uuid}") + public ApiResponseDto getHyperParam( + @Parameter(description = "하이퍼파라미터 uuid", example = "9c91a20c-71e7-4e5f-a860-9626d2b2059c") + @PathVariable + UUID uuid) { + return ApiResponseDto.ok(hyperParamService.getHyperParam(uuid)); + } + + @Operation(summary = "하이퍼파라미터 최적화 값 조회", description = "하이퍼파라미터 최적화 값 조회 API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "조회 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = Page.class))), + @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @GetMapping("/init") + public ApiResponseDto getInitHyperParam() { + return ApiResponseDto.ok(hyperParamService.getInitHyperParam()); + } } diff --git a/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java b/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java index 6824e60..6b9c93c 100644 --- a/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java +++ b/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java @@ -23,9 +23,8 @@ public class HyperParamDto { @NoArgsConstructor @AllArgsConstructor public static class Basic { - private Long id; + private UUID uuid; - private String hyperVer; // ------------------------- // Important @@ -33,28 +32,8 @@ public class HyperParamDto { private String backbone; private String inputSize; private String cropSize; - private Integer epochCnt; private Integer batchSize; - // ------------------------- - // Model Architecture - // ------------------------- - private Double dropPathRate; - private Integer frozenStages; - private String neckPolicy; - private String decoderChannels; - private String classWeight; - private Integer numLayers; - - // ------------------------- - // Loss & Optimization - // ------------------------- - private Double learningRate; - private Double weightDecay; - private Double layerDecayRate; - private Boolean ddpFindUnusedParams; - private Integer ignoreIndex; - // ------------------------- // Data // ------------------------- @@ -65,6 +44,25 @@ public class HyperParamDto { private Boolean trainPersistent; private Boolean valPersistent; + // ------------------------- + // Model Architecture + // ------------------------- + private Double dropPathRate; + private Integer frozenStages; + private String neckPolicy; + private String decoderChannels; + private String classWeight; + + // ------------------------- + // Loss & Optimization + // ------------------------- + private Double learningRate; + private Double weightDecay; + private Double layerDecayRate; + private Boolean ddpFindUnusedParams; + private Integer ignoreIndex; + private Integer numLayers; + // ------------------------- // Evaluation // ------------------------- @@ -79,36 +77,25 @@ public class HyperParamDto { // Augmentation // ------------------------- private Double rotProb; - private Double flipProb; private String rotDegree; + private Double flipProb; private Double exchangeProb; private Integer brightnessDelta; private String contrastRange; private String saturationRange; private Integer hueDelta; + // ------------------------- + // Memo + // ------------------------- + private String memo; + // ------------------------- // Hardware // ------------------------- private Integer gpuCnt; private String gpuIds; private Integer masterPort; - - // ------------------------- - // Common - // ------------------------- - private String memo; - private String delYn; - - @JsonFormatDttm private ZonedDateTime createdDttm; - private Long createdUid; - @JsonFormatDttm private ZonedDateTime updatedDttm; - private Long updatedUid; - @JsonFormatDttm private ZonedDateTime lastUsedDttm; - - private Long m1UseCnt; - private Long m2UseCnt; - private Long m3UseCnt; } @Getter @@ -133,54 +120,24 @@ public class HyperParamDto { @AllArgsConstructor public static class HyperParamCreateReq { + // ------------------------- + // Important + // ------------------------- @Schema(description = "백본 네트워크", example = "large") private String backbone; // backbone - @Schema(description = "입력 이미지 크기(H,W)", example = "256,256") + @Schema(description = "입력 이미지 크기(H,W)", example = "512,512") private String inputSize; // input_size @Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256") private String cropSize; // crop_size - @Schema(description = "총 학습 에폭 수", example = "200") - private Integer epochCnt; // epoch_cnt - @Schema(description = "배치 크기(Per GPU)", example = "16") private Integer batchSize; // batch_size - @Schema(description = "Drop Path 비율", example = "0.3") - private Double dropPathRate; // drop_path_rate - - @Schema(description = "Freeze 단계(-1:None)", example = "-1") - private Integer frozenStages; // frozen_stages - - @Schema(description = "Neck 결합 정책", example = "abs_diff") - private String neckPolicy; // neck_policy - - @Schema(description = "디코더 채널 구성", example = "512,256,128,64") - private String decoderChannels; // decoder_channels - - @Schema(description = "클래스별 가중치", example = "1,10") - private String classWeight; // class_weight - - @Schema(description = "레이어 깊이", example = "24") - private Integer numLayers; // num_layers - - @Schema(description = "학습률", example = "0.00006") - private Double learningRate; // learning_rate - - @Schema(description = "Weight Decay", example = "0.05") - private Double weightDecay; // weight_decay - - @Schema(description = "Layer Decay Rate", example = "0.9") - private Double layerDecayRate; // layer_decay_rate - - @Schema(description = "DDP unused params 탐색 여부", example = "true") - private Boolean ddpFindUnusedParams; // ddp_find_unused_params - - @Schema(description = "Loss 계산 제외 인덱스", example = "255") - private Integer ignoreIndex; // ignore_index - + // ------------------------- + // Data + // ------------------------- @Schema(description = "Train dataloader workers", example = "16") private Integer trainNumWorkers; // train_num_workers @@ -199,13 +156,55 @@ public class HyperParamDto { @Schema(description = "Val persistent workers 여부", example = "true") private Boolean valPersistent; // val_persistent + // ------------------------- + // Model Architecture + // ------------------------- + @Schema(description = "Drop Path 비율", example = "0.3") + private Double dropPathRate; // drop_path_rate + + @Schema(description = "Freeze 단계(-1:None)", example = "-1") + private Integer frozenStages; // frozen_stages + + @Schema(description = "Neck 결합 정책", example = "abs_diff") + private String neckPolicy; // neck_policy + + @Schema(description = "디코더 채널 구성", example = "512,256,128,64") + private String decoderChannels; // decoder_channels + + @Schema(description = "클래스별 가중치", example = "1,10") + private String classWeight; // class_weight + + // ------------------------- + // Loss & Optimization + // ------------------------- + @Schema(description = "학습률", example = "0.00006") + private Double learningRate; // learning_rate + + @Schema(description = "Weight Decay", example = "0.05") + private Double weightDecay; // weight_decay + + @Schema(description = "Layer Decay Rate", example = "0.9") + private Double layerDecayRate; // layer_decay_rate + + @Schema(description = "DDP unused params 탐색 여부", example = "true") + private Boolean ddpFindUnusedParams; // ddp_find_unused_params + + @Schema(description = "Loss 계산 제외 인덱스", example = "255") + private Integer ignoreIndex; // ignore_index + + @Schema(description = "레이어 깊이", example = "24") + private Integer numLayers; // num_layers + + // ------------------------- + // Evaluation + // ------------------------- @Schema(description = "평가 지표 목록", example = "mFscore,mIoU") private String metrics; // metrics @Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore") private String saveBest; // save_best - @Schema(description = "Best 모델 선정 규칙", example = "greater") + @Schema(description = "Best 모델 선정 규칙", example = "less") private String saveBestRule; // save_best_rule @Schema(description = "검증 수행 주기(Epoch)", example = "10") @@ -217,15 +216,18 @@ public class HyperParamDto { @Schema(description = "시각화 저장 주기(Epoch)", example = "1") private Integer visInterval; // vis_interval + // ------------------------- + // Augmentation + // ------------------------- @Schema(description = "회전 적용 확률", example = "0.5") private Double rotProb; // rot_prob - @Schema(description = "반전 적용 확률", example = "0.5") - private Double flipProb; // flip_prob - @Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20") private String rotDegree; // rot_degree + @Schema(description = "반전 적용 확률", example = "0.5") + private Double flipProb; // flip_prob + @Schema(description = "채널 교환 확률", example = "0.5") private Double exchangeProb; // exchange_prob @@ -241,6 +243,9 @@ public class HyperParamDto { @Schema(description = "색조 변화량", example = "10") private Integer hueDelta; // hue_delta + // ------------------------- + // Hardware + // ------------------------- @Schema(description = "사용 GPU 개수", example = "4") private Integer gpuCnt; // gpu_cnt @@ -250,6 +255,9 @@ public class HyperParamDto { @Schema(description = "분산학습 마스터 포트", example = "1122") private Integer masterPort; // master_port + // ------------------------- + // Memo + // ------------------------- @Schema(description = "메모", example = "하이퍼파라미터 신규등록") private String memo; // memo } diff --git a/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java b/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java index 6574a6a..73235d9 100644 --- a/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java +++ b/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java @@ -56,4 +56,19 @@ public class HyperParamService { public void deleteHyperParam(UUID uuid) { hyperParamCoreService.deleteHyperParam(uuid); } + + /** 하이퍼파라미터 최적화 설정값 조회 */ + public HyperParamDto.Basic getInitHyperParam() { + return hyperParamCoreService.getInitHyperParam(); + } + + /** + * 하이퍼파라미터 단건 조회 + * + * @param uuid + * @return + */ + public HyperParamDto.Basic getHyperParam(UUID uuid) { + return hyperParamCoreService.getHyperParam(uuid); + } } diff --git a/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java b/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java index cc19ac3..5c0c065 100644 --- a/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java +++ b/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java @@ -112,25 +112,6 @@ public class ModelMngApiController { return ApiResponseDto.ok(modelTrainService.getFormConfig()); } - @Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다") - @ApiResponses( - value = { - @ApiResponse( - responseCode = "200", - description = "조회 성공", - content = - @Content( - mediaType = "application/json", - schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))), - @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), - @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) - }) - @GetMapping("/hyper-params/{hyperVer}") - public ApiResponseDto getHyperParam( - @Parameter(description = "하이퍼파라미터 버전", example = "H1") @PathVariable String hyperVer) { - return ApiResponseDto.ok(modelTrainService.getHyperParam(hyperVer)); - } - @Operation(summary = "학습 시작", description = "모든 설정(Step 1~3)을 마치고 최종적으로 학습 프로세스를 시작합니다") @ApiResponses( value = { diff --git a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java index d03929b..17b252c 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java @@ -181,6 +181,29 @@ public class HyperParamCoreService { entity.setUpdatedDttm(ZonedDateTime.now()); } + /** + * 하이퍼파라미터 최적화 설정값 조회 + * + * @return + */ + public HyperParamDto.Basic getInitHyperParam() { + ModelHyperParamEntity entity = new ModelHyperParamEntity(); + return entity.toDto(); + } + + /** + * 하이퍼파라미터 단건 조회 + * + * @return + */ + public HyperParamDto.Basic getHyperParam(UUID uuid) { + ModelHyperParamEntity entity = + hyperParamRepository + .findHyperParamByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + return entity.toDto(); + } + /** * 하이퍼파라미터 단건 조회 * diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java index c32cfec..b0d6ac6 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelHyperParamEntity.java @@ -318,57 +318,77 @@ public class ModelHyperParamEntity { public HyperParamDto.Basic toDto() { return new HyperParamDto.Basic( - this.id, this.uuid, - this.hyperVer, + + // ------------------------- + // Important + // ------------------------- this.backbone, this.inputSize, this.cropSize, - this.epochCnt, this.batchSize, - this.dropPathRate, - this.frozenStages, - this.neckPolicy, - this.decoderChannels, - this.classWeight, - this.numLayers, - this.learningRate, - this.weightDecay, - this.layerDecayRate, - this.ddpFindUnusedParams, - this.ignoreIndex, + + // ------------------------- + // Data + // ------------------------- this.trainNumWorkers, this.valNumWorkers, this.testNumWorkers, this.trainShuffle, this.trainPersistent, this.valPersistent, + + // ------------------------- + // Model Architecture + // ------------------------- + this.dropPathRate, + this.frozenStages, + this.neckPolicy, + this.decoderChannels, + this.classWeight, + + // ------------------------- + // Loss & Optimization + // ------------------------- + this.learningRate, + this.weightDecay, + this.layerDecayRate, + this.ddpFindUnusedParams, + this.ignoreIndex, + this.numLayers, + + // ------------------------- + // Evaluation + // ------------------------- this.metrics, this.saveBest, this.saveBestRule, this.valInterval, this.logInterval, this.visInterval, + + // ------------------------- + // Augmentation + // ------------------------- this.rotProb, - this.flipProb, this.rotDegree, + this.flipProb, this.exchangeProb, this.brightnessDelta, this.contrastRange, this.saturationRange, this.hueDelta, + + // ------------------------- + // Memo + // ------------------------- + this.memo, + + // ------------------------- + // Hardware + // ------------------------- this.gpuCnt, this.gpuIds, - this.masterPort, - this.memo, - this.delYn, - this.createdDttm, - this.createdUid, - this.updatedDttm, - this.updatedUid, - this.lastUsedDttm, - this.m1UseCnt, - this.m2UseCnt, - this.m3UseCnt); + this.masterPort); } }