From 6cdf4efda6ce007f39d041069fe741f50e149b19 Mon Sep 17 00:00:00 2001 From: teddy Date: Wed, 4 Feb 2026 18:00:56 +0900 Subject: [PATCH] =?UTF-8?q?=EB=8D=B0=EC=9D=B4=ED=84=B0=EC=85=8B=20?= =?UTF-8?q?=EB=93=B1=EB=A1=9D=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cd/training/common/dto/HyperParam.java | 154 ++++++++ .../cd/training/common/enums/ModelType.java | 27 ++ .../hyperparam/HyperParamApiController.java | 6 +- .../hyperparam/dto/HyperParamDto.java | 149 -------- .../hyperparam/service/HyperParamService.java | 7 +- ...r.java => ModelTrainMngApiController.java} | 25 +- .../cd/training/model/dto/ModelTrainDto.java | 336 ------------------ .../training/model/dto/ModelTrainMngDto.java | 195 ++++++++++ .../model/service/ModelMngService.java | 54 --- .../model/service/ModelTrainMngService.java | 87 +++++ .../model/service/ModelTrainService.java | 4 +- .../postgres/core/HyperParamCoreService.java | 72 ++-- ...ice.java => ModelTrainMngCoreService.java} | 121 ++++++- .../postgres/entity/ModelConfigEntity.java | 2 +- .../postgres/entity/ModelDatasetEntity.java | 9 +- .../postgres/entity/ModelMasterEntity.java | 12 +- .../model/ModelConfigRepository.java | 6 + .../model/ModelDatasetRepository.java | 7 + .../model/ModelDatasetRepositoryCustom.java | 3 + .../model/ModelMngRepositoryCustom.java | 4 +- .../model/ModelMngRepositoryImpl.java | 4 +- 21 files changed, 667 insertions(+), 617 deletions(-) create mode 100644 src/main/java/com/kamco/cd/training/common/dto/HyperParam.java create mode 100644 src/main/java/com/kamco/cd/training/common/enums/ModelType.java rename src/main/java/com/kamco/cd/training/model/{ModelMngApiController.java => ModelTrainMngApiController.java} (94%) delete mode 100644 src/main/java/com/kamco/cd/training/model/dto/ModelTrainDto.java create mode 100644 src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java delete mode 100644 src/main/java/com/kamco/cd/training/model/service/ModelMngService.java create mode 100644 src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java rename src/main/java/com/kamco/cd/training/postgres/core/{ModelMngCoreService.java => ModelTrainMngCoreService.java} (54%) create mode 100644 src/main/java/com/kamco/cd/training/postgres/repository/model/ModelConfigRepository.java create mode 100644 src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDatasetRepository.java create mode 100644 src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDatasetRepositoryCustom.java diff --git a/src/main/java/com/kamco/cd/training/common/dto/HyperParam.java b/src/main/java/com/kamco/cd/training/common/dto/HyperParam.java new file mode 100644 index 0000000..cfc78fc --- /dev/null +++ b/src/main/java/com/kamco/cd/training/common/dto/HyperParam.java @@ -0,0 +1,154 @@ +package com.kamco.cd.training.common.dto; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@AllArgsConstructor +@NoArgsConstructor +public class HyperParam { + // ------------------------- + // Important + // ------------------------- + @Schema(description = "백본 네트워크", example = "large") + private String backbone; // backbone + + @Schema(description = "입력 이미지 크기(H,W)", example = "512,512") + private String inputSize; // input_size + + @Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256") + private String cropSize; // crop_size + + @Schema(description = "배치 크기(Per GPU)", example = "16") + private Integer batchSize; // batch_size + + // ------------------------- + // Data + // ------------------------- + @Schema(description = "Train dataloader workers", example = "16") + private Integer trainNumWorkers; // train_num_workers + + @Schema(description = "Val dataloader workers", example = "8") + private Integer valNumWorkers; // val_num_workers + + @Schema(description = "Test dataloader workers", example = "8") + private Integer testNumWorkers; // test_num_workers + + @Schema(description = "Train shuffle 여부", example = "true") + private Boolean trainShuffle; // train_shuffle + + @Schema(description = "Train persistent workers 여부", example = "true") + private Boolean trainPersistent; // train_persistent + + @Schema(description = "Val persistent workers 여부", example = "true") + private Boolean valPersistent; // val_persistent + + // ------------------------- + // Model Architecture + // ------------------------- + @Schema(description = "Drop Path 비율", example = "0.3") + private Double dropPathRate; // drop_path_rate + + @Schema(description = "Freeze 단계(-1:None)", example = "-1") + private Integer frozenStages; // frozen_stages + + @Schema(description = "Neck 결합 정책", example = "abs_diff") + private String neckPolicy; // neck_policy + + @Schema(description = "디코더 채널 구성", example = "512,256,128,64") + private String decoderChannels; // decoder_channels + + @Schema(description = "클래스별 가중치", example = "1,10") + private String classWeight; // class_weight + + // ------------------------- + // Loss & Optimization + // ------------------------- + @Schema(description = "학습률", example = "0.00006") + private Double learningRate; // learning_rate + + @Schema(description = "Weight Decay", example = "0.05") + private Double weightDecay; // weight_decay + + @Schema(description = "Layer Decay Rate", example = "0.9") + private Double layerDecayRate; // layer_decay_rate + + @Schema(description = "DDP unused params 탐색 여부", example = "true") + private Boolean ddpFindUnusedParams; // ddp_find_unused_params + + @Schema(description = "Loss 계산 제외 인덱스", example = "255") + private Integer ignoreIndex; // ignore_index + + @Schema(description = "레이어 깊이", example = "24") + private Integer numLayers; // num_layers + + // ------------------------- + // Evaluation + // ------------------------- + @Schema(description = "평가 지표 목록", example = "mFscore,mIoU") + private String metrics; // metrics + + @Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore") + private String saveBest; // save_best + + @Schema(description = "Best 모델 선정 규칙", example = "less") + private String saveBestRule; // save_best_rule + + @Schema(description = "검증 수행 주기(Epoch)", example = "10") + private Integer valInterval; // val_interval + + @Schema(description = "로그 기록 주기(Iteration)", example = "400") + private Integer logInterval; // log_interval + + @Schema(description = "시각화 저장 주기(Epoch)", example = "1") + private Integer visInterval; // vis_interval + + // ------------------------- + // Augmentation + // ------------------------- + @Schema(description = "회전 적용 확률", example = "0.5") + private Double rotProb; // rot_prob + + @Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20") + private String rotDegree; // rot_degree + + @Schema(description = "반전 적용 확률", example = "0.5") + private Double flipProb; // flip_prob + + @Schema(description = "채널 교환 확률", example = "0.5") + private Double exchangeProb; // exchange_prob + + @Schema(description = "밝기 변화량", example = "10") + private Integer brightnessDelta; // brightness_delta + + @Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2") + private String contrastRange; // contrast_range + + @Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2") + private String saturationRange; // saturation_range + + @Schema(description = "색조 변화량", example = "10") + private Integer hueDelta; // hue_delta + + // ------------------------- + // Hardware + // ------------------------- + @Schema(description = "사용 GPU 개수", example = "4") + private Integer gpuCnt; // gpu_cnt + + @Schema(description = "사용 GPU ID 목록", example = "0,1,2,3") + private String gpuIds; // gpu_ids + + @Schema(description = "분산학습 마스터 포트", example = "1122") + private Integer masterPort; // master_port + + // ------------------------- + // Memo + // ------------------------- + @Schema(description = "메모", example = "하이퍼파라미터 신규등록") + private String memo; // memo +} diff --git a/src/main/java/com/kamco/cd/training/common/enums/ModelType.java b/src/main/java/com/kamco/cd/training/common/enums/ModelType.java new file mode 100644 index 0000000..b8e3292 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/common/enums/ModelType.java @@ -0,0 +1,27 @@ +package com.kamco.cd.training.common.enums; + +import com.kamco.cd.training.common.utils.enums.CodeExpose; +import com.kamco.cd.training.common.utils.enums.EnumType; +import lombok.AllArgsConstructor; +import lombok.Getter; + +@CodeExpose +@Getter +@AllArgsConstructor +public enum ModelType implements EnumType { + M1("M1"), + M2("M2"), + M3("M3"); + + private String desc; + + @Override + public String getId() { + return name(); + } + + @Override + public String getText() { + return desc; + } +} diff --git a/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java b/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java index 972df82..aacd39a 100644 --- a/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java +++ b/src/main/java/com/kamco/cd/training/hyperparam/HyperParamApiController.java @@ -1,5 +1,6 @@ package com.kamco.cd.training.hyperparam; +import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List; @@ -49,8 +50,7 @@ public class HyperParamApiController { @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping - public ApiResponseDto createHyperParam( - @Valid @RequestBody HyperParamDto.HyperParamCreateReq createReq) { + public ApiResponseDto createHyperParam(@Valid @RequestBody HyperParam createReq) { String newVersion = hyperParamService.createHyperParam(createReq); return ApiResponseDto.ok(newVersion); } @@ -70,7 +70,7 @@ public class HyperParamApiController { }) @PutMapping("/{uuid}") public ApiResponseDto updateHyperParam( - @PathVariable UUID uuid, @Valid @RequestBody HyperParamDto.HyperParamCreateReq createReq) { + @PathVariable UUID uuid, @Valid @RequestBody HyperParam createReq) { return ApiResponseDto.ok(hyperParamService.updateHyperParam(uuid, createReq)); } diff --git a/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java b/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java index f6f096d..2ed4934 100644 --- a/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java +++ b/src/main/java/com/kamco/cd/training/hyperparam/dto/HyperParamDto.java @@ -115,155 +115,6 @@ public class HyperParamDto { private Long totalCnt; } - @Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청") - @Getter - @Setter - @NoArgsConstructor - @AllArgsConstructor - public static class HyperParamCreateReq { - - // ------------------------- - // Important - // ------------------------- - @Schema(description = "백본 네트워크", example = "large") - private String backbone; // backbone - - @Schema(description = "입력 이미지 크기(H,W)", example = "512,512") - private String inputSize; // input_size - - @Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256") - private String cropSize; // crop_size - - @Schema(description = "배치 크기(Per GPU)", example = "16") - private Integer batchSize; // batch_size - - // ------------------------- - // Data - // ------------------------- - @Schema(description = "Train dataloader workers", example = "16") - private Integer trainNumWorkers; // train_num_workers - - @Schema(description = "Val dataloader workers", example = "8") - private Integer valNumWorkers; // val_num_workers - - @Schema(description = "Test dataloader workers", example = "8") - private Integer testNumWorkers; // test_num_workers - - @Schema(description = "Train shuffle 여부", example = "true") - private Boolean trainShuffle; // train_shuffle - - @Schema(description = "Train persistent workers 여부", example = "true") - private Boolean trainPersistent; // train_persistent - - @Schema(description = "Val persistent workers 여부", example = "true") - private Boolean valPersistent; // val_persistent - - // ------------------------- - // Model Architecture - // ------------------------- - @Schema(description = "Drop Path 비율", example = "0.3") - private Double dropPathRate; // drop_path_rate - - @Schema(description = "Freeze 단계(-1:None)", example = "-1") - private Integer frozenStages; // frozen_stages - - @Schema(description = "Neck 결합 정책", example = "abs_diff") - private String neckPolicy; // neck_policy - - @Schema(description = "디코더 채널 구성", example = "512,256,128,64") - private String decoderChannels; // decoder_channels - - @Schema(description = "클래스별 가중치", example = "1,10") - private String classWeight; // class_weight - - // ------------------------- - // Loss & Optimization - // ------------------------- - @Schema(description = "학습률", example = "0.00006") - private Double learningRate; // learning_rate - - @Schema(description = "Weight Decay", example = "0.05") - private Double weightDecay; // weight_decay - - @Schema(description = "Layer Decay Rate", example = "0.9") - private Double layerDecayRate; // layer_decay_rate - - @Schema(description = "DDP unused params 탐색 여부", example = "true") - private Boolean ddpFindUnusedParams; // ddp_find_unused_params - - @Schema(description = "Loss 계산 제외 인덱스", example = "255") - private Integer ignoreIndex; // ignore_index - - @Schema(description = "레이어 깊이", example = "24") - private Integer numLayers; // num_layers - - // ------------------------- - // Evaluation - // ------------------------- - @Schema(description = "평가 지표 목록", example = "mFscore,mIoU") - private String metrics; // metrics - - @Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore") - private String saveBest; // save_best - - @Schema(description = "Best 모델 선정 규칙", example = "less") - private String saveBestRule; // save_best_rule - - @Schema(description = "검증 수행 주기(Epoch)", example = "10") - private Integer valInterval; // val_interval - - @Schema(description = "로그 기록 주기(Iteration)", example = "400") - private Integer logInterval; // log_interval - - @Schema(description = "시각화 저장 주기(Epoch)", example = "1") - private Integer visInterval; // vis_interval - - // ------------------------- - // Augmentation - // ------------------------- - @Schema(description = "회전 적용 확률", example = "0.5") - private Double rotProb; // rot_prob - - @Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20") - private String rotDegree; // rot_degree - - @Schema(description = "반전 적용 확률", example = "0.5") - private Double flipProb; // flip_prob - - @Schema(description = "채널 교환 확률", example = "0.5") - private Double exchangeProb; // exchange_prob - - @Schema(description = "밝기 변화량", example = "10") - private Integer brightnessDelta; // brightness_delta - - @Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2") - private String contrastRange; // contrast_range - - @Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2") - private String saturationRange; // saturation_range - - @Schema(description = "색조 변화량", example = "10") - private Integer hueDelta; // hue_delta - - // ------------------------- - // Hardware - // ------------------------- - @Schema(description = "사용 GPU 개수", example = "4") - private Integer gpuCnt; // gpu_cnt - - @Schema(description = "사용 GPU ID 목록", example = "0,1,2,3") - private String gpuIds; // gpu_ids - - @Schema(description = "분산학습 마스터 포트", example = "1122") - private Integer masterPort; // master_port - - // ------------------------- - // Memo - // ------------------------- - @Schema(description = "메모", example = "하이퍼파라미터 신규등록") - private String memo; // memo - } - @Getter @Setter @NoArgsConstructor diff --git a/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java b/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java index da7a742..d738af3 100644 --- a/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java +++ b/src/main/java/com/kamco/cd/training/hyperparam/service/HyperParamService.java @@ -1,5 +1,6 @@ package com.kamco.cd.training.hyperparam.service; +import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List; import com.kamco.cd.training.postgres.core.HyperParamCoreService; @@ -33,8 +34,8 @@ public class HyperParamService { * @return 생성된 버전명 */ @Transactional - public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) { - return hyperParamCoreService.createHyperParam(createReq); + public String createHyperParam(HyperParam createReq) { + return hyperParamCoreService.createHyperParam(createReq).getHyperVer(); } /** @@ -44,7 +45,7 @@ public class HyperParamService { * @return */ @Transactional - public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) { + public String updateHyperParam(UUID uuid, HyperParam createReq) { return hyperParamCoreService.updateHyperParam(uuid, createReq); } diff --git a/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java b/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java similarity index 94% rename from src/main/java/com/kamco/cd/training/model/ModelMngApiController.java rename to src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java index 4bc2c6c..88ce615 100644 --- a/src/main/java/com/kamco/cd/training/model/ModelMngApiController.java +++ b/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java @@ -1,10 +1,9 @@ package com.kamco.cd.training.model; import com.kamco.cd.training.config.api.ApiResponseDto; -import com.kamco.cd.training.model.dto.ModelTrainDto; -import com.kamco.cd.training.model.dto.ModelTrainDto.Basic; -import com.kamco.cd.training.model.service.ModelMngService; -import com.kamco.cd.training.model.service.ModelTrainService; +import com.kamco.cd.training.model.dto.ModelTrainMngDto; +import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; +import com.kamco.cd.training.model.service.ModelTrainMngService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.media.Content; @@ -28,9 +27,8 @@ import org.springframework.web.bind.annotation.RestController; @RequiredArgsConstructor @Tag(name = "모델학습 관리", description = "어드민 홈 > 모델학습관리 > 모델관리 > 목록") @RequestMapping("/api/models") -public class ModelMngApiController { - private final ModelMngService modelMngService; - private final ModelTrainService modelTrainService; +public class ModelTrainMngApiController { + private final ModelTrainMngService modelTrainMngService; @Operation(summary = "모델학습 목록 조회", description = "모델학습 목록 조회 API") @ApiResponses( @@ -55,8 +53,8 @@ public class ModelMngApiController { String status, @Parameter(description = "페이지 번호") @RequestParam(defaultValue = "0") int page, @Parameter(description = "페이지 크기") @RequestParam(defaultValue = "20") int size) { - ModelTrainDto.SearchReq searchReq = new ModelTrainDto.SearchReq(status, page, size); - return ApiResponseDto.ok(modelMngService.getModelList(searchReq)); + ModelTrainMngDto.SearchReq searchReq = new ModelTrainMngDto.SearchReq(status, page, size); + return ApiResponseDto.ok(modelTrainMngService.getModelList(searchReq)); } @Operation(summary = "학습 모델 삭제", description = "학습 모델 삭제 API") @@ -70,7 +68,7 @@ public class ModelMngApiController { @Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79") @PathVariable UUID uuid) { - modelMngService.deleteModelTrain(uuid); + modelTrainMngService.deleteModelTrain(uuid); return ApiResponseDto.ok(null); } @@ -81,8 +79,9 @@ public class ModelMngApiController { @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping - public ApiResponseDto createModelTrain(@RequestBody ModelTrainDto.AddReq modelTrainDto) { - return ApiResponseDto.ok(null); + public ApiResponseDto createModelTrain(@RequestBody ModelTrainMngDto.AddReq req) { + modelTrainMngService.createModelTrain(req); + return ApiResponseDto.ok("ok"); } // @@ -104,7 +103,7 @@ public class ModelMngApiController { // @Parameter(description = "모델 UUID", example = "b7e99739-6736-45f9-a224-8161ecddf287") // @PathVariable // String uuid) { - // return ApiResponseDto.ok(modelMngService.getModelDetailByUuid(uuid)); + // return ApiResponseDto.ok(modelTrainMngService.getModelDetailByUuid(uuid)); // } // // // ==================== 학습 모델학습관리 API (5종) ==================== diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDto.java deleted file mode 100644 index 5ef6c08..0000000 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainDto.java +++ /dev/null @@ -1,336 +0,0 @@ -package com.kamco.cd.training.model.dto; - -import com.kamco.cd.training.common.enums.TrainStatusType; -import com.kamco.cd.training.common.enums.TrainType; -import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; -import io.swagger.v3.oas.annotations.media.Schema; -import jakarta.validation.constraints.NotNull; -import java.time.Duration; -import java.time.ZonedDateTime; -import java.util.List; -import java.util.UUID; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; -import org.springframework.data.domain.PageRequest; -import org.springframework.data.domain.Pageable; - -public class ModelTrainDto { - @Schema(name = "모델학습관리 목록", description = "모델학습관리 목록") - @Getter - @Setter - @NoArgsConstructor - @AllArgsConstructor - @Builder - public static class Basic { - - private Long id; - private UUID uuid; - private String modelVer; - @JsonFormatDttm private ZonedDateTime startDttm; - @JsonFormatDttm private ZonedDateTime step1StrtDttm; - @JsonFormatDttm private ZonedDateTime step1EndDttm; - @JsonFormatDttm private ZonedDateTime step2StrtDttm; - @JsonFormatDttm private ZonedDateTime step2EndDttm; - private String step1Status; - private String step2Status; - private String statusCd; - private String trainType; - - public String getStatusName() { - if (this.statusCd == null || this.statusCd.isBlank()) return null; - try { - return TrainStatusType.valueOf(this.statusCd).getText(); // 또는 getName() - } catch (IllegalArgumentException e) { - return this.statusCd; // 매핑 못하면 코드 그대로 반환(원하면 null 처리) - } - } - - public String getStep1StatusName() { - if (this.step1Status == null || this.step1Status.isBlank()) return null; - try { - return TrainStatusType.valueOf(this.step1Status).getText(); // 또는 getName() - } catch (IllegalArgumentException e) { - return this.step1Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리) - } - } - - public String getStep2StatusNAme() { - if (this.step2Status == null || this.step2Status.isBlank()) return null; - try { - return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName() - } catch (IllegalArgumentException e) { - return this.step2Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리) - } - } - - public String getTrainTypeName() { - if (this.trainType == null || this.trainType.isBlank()) return null; - try { - return TrainType.valueOf(this.trainType).getText(); // 또는 getName() - } catch (IllegalArgumentException e) { - return this.trainType; // 매핑 못하면 코드 그대로 반환(원하면 null 처리) - } - } - - private String formatDuration(ZonedDateTime start, ZonedDateTime end) { - if (start == null || end == null) { - return null; - } - - long totalSeconds = Math.abs(Duration.between(start, end).getSeconds()); - - long hours = totalSeconds / 3600; - long minutes = (totalSeconds % 3600) / 60; - long seconds = totalSeconds % 60; - - return String.format("%d시간 %d분 %d초", hours, minutes, seconds); - } - - public String getStep1Duration() { - return formatDuration(this.step1StrtDttm, this.step1EndDttm); - } - - public String getStep2Duration() { - return formatDuration(this.step2StrtDttm, this.step2EndDttm); - } - } - - @Schema(name = "searchReq", description = "모델학습 관리 목록조회 파라미터") - @Getter - @Setter - @NoArgsConstructor - @AllArgsConstructor - public static class SearchReq { - - private String status; - // 페이징 파라미터 - private int page = 0; - private int size = 20; - - public Pageable toPageable() { - return PageRequest.of(page, size); - } - } - - @Schema(name = "addReq", description = "모델학습 관리 등록 파라미터") - @Getter - @Setter - @NoArgsConstructor - @AllArgsConstructor - public static class AddReq { - HyperParamDto hyperParam; - TrainingDataConfigDto trainingDataConfig; - EtcConfig etcConfig; - } - - /** 학습실행설정 하이퍼파라미터 설정 */ - @Schema(name = "하이퍼파라미터 설정", description = "학습실행설정 > 하이퍼파라미터 설정") - @Getter - @Setter - public static class HyperParamDto { - - @NotNull - @Schema( - description = "OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)", - example = "EXISTING") - private String hyperParamType; - - @Schema(description = "기존 파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10") - private String hyperUuid; - - // ------------------------- - // Important - // ------------------------- - @Schema(description = "백본 네트워크", example = "large") - private String backbone; // backbone - - @Schema(description = "입력 이미지 크기(H,W)", example = "512,512") - private String inputSize; // input_size - - @Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256") - private String cropSize; // crop_size - - @Schema(description = "배치 크기(Per GPU)", example = "16") - private Integer batchSize; // batch_size - - // ------------------------- - // Data - // ------------------------- - @Schema(description = "Train dataloader workers", example = "16") - private Integer trainNumWorkers; // train_num_workers - - @Schema(description = "Val dataloader workers", example = "8") - private Integer valNumWorkers; // val_num_workers - - @Schema(description = "Test dataloader workers", example = "8") - private Integer testNumWorkers; // test_num_workers - - @Schema(description = "Train shuffle 여부", example = "true") - private Boolean trainShuffle; // train_shuffle - - @Schema(description = "Train persistent workers 여부", example = "true") - private Boolean trainPersistent; // train_persistent - - @Schema(description = "Val persistent workers 여부", example = "true") - private Boolean valPersistent; // val_persistent - - // ------------------------- - // Model Architecture - // ------------------------- - @Schema(description = "Drop Path 비율", example = "0.3") - private Double dropPathRate; // drop_path_rate - - @Schema(description = "Freeze 단계(-1:None)", example = "-1") - private Integer frozenStages; // frozen_stages - - @Schema(description = "Neck 결합 정책", example = "abs_diff") - private String neckPolicy; // neck_policy - - @Schema(description = "디코더 채널 구성", example = "512,256,128,64") - private String decoderChannels; // decoder_channels - - @Schema(description = "클래스별 가중치", example = "1,10") - private String classWeight; // class_weight - - // ------------------------- - // Loss & Optimization - // ------------------------- - @Schema(description = "학습률", example = "0.00006") - private Double learningRate; // learning_rate - - @Schema(description = "Weight Decay", example = "0.05") - private Double weightDecay; // weight_decay - - @Schema(description = "Layer Decay Rate", example = "0.9") - private Double layerDecayRate; // layer_decay_rate - - @Schema(description = "DDP unused params 탐색 여부", example = "true") - private Boolean ddpFindUnusedParams; // ddp_find_unused_params - - @Schema(description = "Loss 계산 제외 인덱스", example = "255") - private Integer ignoreIndex; // ignore_index - - @Schema(description = "레이어 깊이", example = "24") - private Integer numLayers; // num_layers - - // ------------------------- - // Evaluation - // ------------------------- - @Schema(description = "평가 지표 목록", example = "mFscore,mIoU") - private String metrics; // metrics - - @Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore") - private String saveBest; // save_best - - @Schema(description = "Best 모델 선정 규칙", example = "less") - private String saveBestRule; // save_best_rule - - @Schema(description = "검증 수행 주기(Epoch)", example = "10") - private Integer valInterval; // val_interval - - @Schema(description = "로그 기록 주기(Iteration)", example = "400") - private Integer logInterval; // log_interval - - @Schema(description = "시각화 저장 주기(Epoch)", example = "1") - private Integer visInterval; // vis_interval - - // ------------------------- - // Augmentation - // ------------------------- - @Schema(description = "회전 적용 확률", example = "0.5") - private Double rotProb; // rot_prob - - @Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20") - private String rotDegree; // rot_degree - - @Schema(description = "반전 적용 확률", example = "0.5") - private Double flipProb; // flip_prob - - @Schema(description = "채널 교환 확률", example = "0.5") - private Double exchangeProb; // exchange_prob - - @Schema(description = "밝기 변화량", example = "10") - private Integer brightnessDelta; // brightness_delta - - @Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2") - private String contrastRange; // contrast_range - - @Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2") - private String saturationRange; // saturation_range - - @Schema(description = "색조 변화량", example = "10") - private Integer hueDelta; // hue_delta - - // ------------------------- - // Hardware - // ------------------------- - @Schema(description = "사용 GPU 개수", example = "4") - private Integer gpuCnt; // gpu_cnt - - @Schema(description = "사용 GPU ID 목록", example = "0,1,2,3") - private String gpuIds; // gpu_ids - - @Schema(description = "분산학습 마스터 포트", example = "1122") - private Integer masterPort; // master_port - - // ------------------------- - // Memo - // ------------------------- - @Schema(description = "메모", example = "하이퍼파라미터 신규등록") - private String memo; // memo - } - - @Getter - @Setter - public static class TrainingDataConfigDto { - Summary summary; - List datasetList; - } - - @Getter - @Setter - public static class Summary { - @Schema(description = "건물", example = "0") - private Long buildingCnt; - - @Schema(description = "컨테이너", example = "0") - private Long containerCnt; - - @Schema(description = "폐기물", example = "0") - private Long wasteCnt; - - @Schema( - description = "도로, 비닐하우스, 밭, 과수원, 초지, 숲, 물, 모재/자갈, 토분(무덤), 일반토지, 태양광, 기타", - example = "0") - private Long LandCoverCnt; - } - - @Getter - @Setter - public static class Dataset { - @Schema(description = "데이터셋 uuid") - private UUID uuid; - } - - @Getter - @Setter - public static class EtcConfig { - @Schema(description = "에폭 횟수", example = "0") - private Long epochCnt; - - @Schema(description = "학습데이터셋 비율 Training", example = "0") - private Integer trainingCnt; - - @Schema(description = "학습데이터셋 비율 Validation", example = "0") - private Integer validationCnt; - - @Schema(description = "학습데이터셋 비율 Test", example = "0") - private Integer testCnt; - - @Schema(description = "메모", example = "메모 입니다.") - private String memo; - } -} diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java new file mode 100644 index 0000000..4996f0d --- /dev/null +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java @@ -0,0 +1,195 @@ +package com.kamco.cd.training.model.dto; + +import com.kamco.cd.training.common.dto.HyperParam; +import com.kamco.cd.training.common.enums.TrainStatusType; +import com.kamco.cd.training.common.enums.TrainType; +import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; +import java.time.Duration; +import java.time.ZonedDateTime; +import java.util.List; +import java.util.UUID; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Pageable; + +public class ModelTrainMngDto { + @Schema(name = "모델학습관리 목록", description = "모델학습관리 목록") + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + @Builder + public static class Basic { + + private Long id; + private UUID uuid; + private String modelVer; + @JsonFormatDttm private ZonedDateTime startDttm; + @JsonFormatDttm private ZonedDateTime step1StrtDttm; + @JsonFormatDttm private ZonedDateTime step1EndDttm; + @JsonFormatDttm private ZonedDateTime step2StrtDttm; + @JsonFormatDttm private ZonedDateTime step2EndDttm; + private String step1Status; + private String step2Status; + private String statusCd; + private String trainType; + + public String getStatusName() { + if (this.statusCd == null || this.statusCd.isBlank()) return null; + try { + return TrainStatusType.valueOf(this.statusCd).getText(); // 또는 getName() + } catch (IllegalArgumentException e) { + return this.statusCd; // 매핑 못하면 코드 그대로 반환(원하면 null 처리) + } + } + + public String getStep1StatusName() { + if (this.step1Status == null || this.step1Status.isBlank()) return null; + try { + return TrainStatusType.valueOf(this.step1Status).getText(); // 또는 getName() + } catch (IllegalArgumentException e) { + return this.step1Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리) + } + } + + public String getStep2StatusNAme() { + if (this.step2Status == null || this.step2Status.isBlank()) return null; + try { + return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName() + } catch (IllegalArgumentException e) { + return this.step2Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리) + } + } + + public String getTrainTypeName() { + if (this.trainType == null || this.trainType.isBlank()) return null; + try { + return TrainType.valueOf(this.trainType).getText(); // 또는 getName() + } catch (IllegalArgumentException e) { + return this.trainType; // 매핑 못하면 코드 그대로 반환(원하면 null 처리) + } + } + + private String formatDuration(ZonedDateTime start, ZonedDateTime end) { + if (start == null || end == null) { + return null; + } + + long totalSeconds = Math.abs(Duration.between(start, end).getSeconds()); + + long hours = totalSeconds / 3600; + long minutes = (totalSeconds % 3600) / 60; + long seconds = totalSeconds % 60; + + return String.format("%d시간 %d분 %d초", hours, minutes, seconds); + } + + public String getStep1Duration() { + return formatDuration(this.step1StrtDttm, this.step1EndDttm); + } + + public String getStep2Duration() { + return formatDuration(this.step2StrtDttm, this.step2EndDttm); + } + } + + @Schema(name = "searchReq", description = "모델학습 관리 목록조회 파라미터") + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class SearchReq { + + private String status; + // 페이징 파라미터 + private int page = 0; + private int size = 20; + + public Pageable toPageable() { + return PageRequest.of(page, size); + } + } + + @Schema(name = "addReq", description = "모델학습 관리 등록 파라미터") + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class AddReq { + + @NotNull + @Schema(description = "모델 종류 M1, M2, M3", example = "M1") + private String modelNo; + + @NotNull + @Schema(description = "모델학습 실행 여부", example = "false") + private Boolean isStart; + + @NotNull + @Schema(description = "학습타입 GENERAL(일반), TRANSFER(전이)", example = "GENERAL") + private String trainType; + + @NotNull + @Schema( + description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)", + example = "EXISTING") + private String hyperParamType; + + @Schema(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10") + private UUID hyperUuid; + + HyperParam hyperParam; + TrainingDataset trainingDataset; + ModelConfig modelConfig; + } + + @Getter + @Setter + public static class TrainingDataset { + Summary summary; + List datasetList; + } + + @Getter + @Setter + public static class Summary { + @Schema(description = "건물", example = "0") + private Long buildingCnt; + + @Schema(description = "컨테이너", example = "0") + private Long containerCnt; + + @Schema(description = "폐기물", example = "0") + private Long wasteCnt; + + @Schema( + description = "도로, 비닐하우스, 밭, 과수원, 초지, 숲, 물, 모재/자갈, 토분(무덤), 일반토지, 태양광, 기타", + example = "0") + private Long LandCoverCnt; + } + + @Getter + @Setter + public static class ModelConfig { + @Schema(description = "에폭 횟수", example = "0") + private Integer epochCnt; + + @Schema(description = "학습데이터셋 비율 Training", example = "0") + private Float trainingCnt; + + @Schema(description = "학습데이터셋 비율 Validation", example = "0") + private Float validationCnt; + + @Schema(description = "학습데이터셋 비율 Test", example = "0") + private Float testCnt; + + @Schema(description = "메모", example = "메모 입니다.") + private String memo; + } +} diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelMngService.java b/src/main/java/com/kamco/cd/training/model/service/ModelMngService.java deleted file mode 100644 index 7123c3f..0000000 --- a/src/main/java/com/kamco/cd/training/model/service/ModelMngService.java +++ /dev/null @@ -1,54 +0,0 @@ -package com.kamco.cd.training.model.service; - -import com.kamco.cd.training.model.dto.ModelMngDto; -import com.kamco.cd.training.model.dto.ModelTrainDto; -import com.kamco.cd.training.model.dto.ModelTrainDto.SearchReq; -import com.kamco.cd.training.postgres.core.ModelMngCoreService; -import java.util.UUID; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.springframework.data.domain.Page; -import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; - -@Service -@RequiredArgsConstructor -@Transactional(readOnly = true) -@Slf4j -public class ModelMngService { - - private final ModelMngCoreService modelMngCoreService; - - /** - * 모델 목록 조회 - * - * @param searchReq 검색 조건 - * @return 페이징 처리된 모델 목록 - */ - public Page getModelList(SearchReq searchReq) { - return modelMngCoreService.findByModelList(searchReq); - } - - /** - * 학습모델 삭제 - * - * @param uuid - */ - public void deleteModelTrain(UUID uuid) { - modelMngCoreService.deleteModel(uuid); - } - - public String createModelTrain(ModelMngDto modelMngDto) { - return null; - } - - /** - * 모델 상세 조회 - * - * @param modelUid 모델 UID - * @return 모델 상세 정보 - */ - public ModelMngDto.Detail getModelDetail(Long modelUid) { - return modelMngCoreService.getModelDetail(modelUid); - } -} diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java new file mode 100644 index 0000000..05b2b72 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java @@ -0,0 +1,87 @@ +package com.kamco.cd.training.model.service; + +import com.kamco.cd.training.common.dto.HyperParam; +import com.kamco.cd.training.common.enums.HyperParamSelectType; +import com.kamco.cd.training.hyperparam.dto.HyperParamDto; +import com.kamco.cd.training.model.dto.ModelMngDto; +import com.kamco.cd.training.model.dto.ModelTrainMngDto; +import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq; +import com.kamco.cd.training.postgres.core.HyperParamCoreService; +import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.data.domain.Page; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +@Transactional(readOnly = true) +@Slf4j +public class ModelTrainMngService { + + private final ModelTrainMngCoreService modelTrainMngCoreService; + private final HyperParamCoreService hyperParamCoreService; + + /** + * 모델학습 조회 + * + * @param searchReq 검색 조건 + * @return 페이징 처리된 모델 목록 + */ + public Page getModelList(SearchReq searchReq) { + return modelTrainMngCoreService.findByModelList(searchReq); + } + + /** + * 모델학습 삭제 + * + * @param uuid + */ + public void deleteModelTrain(UUID uuid) { + modelTrainMngCoreService.deleteModel(uuid); + } + + /** + * 모델학습 등록 + * + * @param req + * @return + */ + @Transactional + public void createModelTrain(ModelTrainMngDto.AddReq req) { + HyperParam hyperParam = req.getHyperParam(); + HyperParamDto.Basic hyper = new HyperParamDto.Basic(); + + /** OPTIMIZED(최적화 파라미터), EXISTING(기존 파라미터), NEW(신규 파라미터) * */ + if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) { + // 하이퍼파라미터 등록 + hyper = hyperParamCoreService.createHyperParam(hyperParam); + req.setHyperUuid(hyper.getUuid()); + } + + // 모델학습 테이블 저장 + Long modelId = modelTrainMngCoreService.saveModel(req); + + // 모델학습 데이터셋 저장 + modelTrainMngCoreService.saveModelDataset(modelId, req); + + // 모델 데이터셋 mapping 저장 + modelTrainMngCoreService.saveModelDatasetMap( + modelId, req.getTrainingDataset().getDatasetList()); + + // 모델 config 저장 + modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig()); + } + + /** + * 모델학습 상세 조회 + * + * @param modelUid 모델 UID + * @return 모델 상세 정보 + */ + public ModelMngDto.Detail getModelDetail(Long modelUid) { + return modelTrainMngCoreService.getModelDetail(modelUid); + } +} diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainService.java index cd77122..8b5b73b 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainService.java @@ -4,7 +4,7 @@ import com.kamco.cd.training.common.exception.BadRequestException; import com.kamco.cd.training.common.exception.NotFoundException; import com.kamco.cd.training.model.dto.ModelMngDto; import com.kamco.cd.training.postgres.core.DatasetCoreService; -import com.kamco.cd.training.postgres.core.ModelMngCoreService; +import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.postgres.core.SystemMetricsCoreService; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import java.util.List; @@ -19,7 +19,7 @@ import org.springframework.transaction.annotation.Transactional; @Slf4j public class ModelTrainService { - private final ModelMngCoreService modelMngCoreService; + private final ModelTrainMngCoreService modelMngCoreService; private final DatasetCoreService datasetCoreService; private final SystemMetricsCoreService systemMetricsCoreService; diff --git a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java index 725c3f3..d3b1f4b 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/HyperParamCoreService.java @@ -1,8 +1,10 @@ package com.kamco.cd.training.postgres.core; +import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.common.utils.UserUtil; import com.kamco.cd.training.hyperparam.dto.HyperParamDto; +import com.kamco.cd.training.hyperparam.dto.HyperParamDto.Basic; import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository; import java.time.ZonedDateTime; @@ -21,10 +23,10 @@ public class HyperParamCoreService { /** * 하이퍼파라미터 등록 * - * @param createReq 등록 요청 + * @param createReq ModelTrainMngDto.HyperParamDto * @return 등록된 버전명 */ - public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) { + public Basic createHyperParam(HyperParam createReq) { String firstVersion = getFirstHyperParamVersion(); ModelHyperParamEntity entity = new ModelHyperParamEntity(); @@ -36,7 +38,10 @@ public class HyperParamCoreService { entity.setCreatedUid(userUtil.getId()); ModelHyperParamEntity resultEntity = hyperParamRepository.save(entity); - return resultEntity.getHyperVer(); + Basic basic = new Basic(); + basic.setUuid(resultEntity.getUuid()); + basic.setHyperVer(resultEntity.getHyperVer()); + return basic; } /** @@ -46,7 +51,7 @@ public class HyperParamCoreService { * @param createReq 등록 요청 * @return ver */ - public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) { + public String updateHyperParam(UUID uuid, HyperParam createReq) { ModelHyperParamEntity entity = hyperParamRepository .findHyperParamByUuid(uuid) @@ -61,47 +66,46 @@ public class HyperParamCoreService { return entity.getHyperVer(); } - private void applyHyperParam( - ModelHyperParamEntity entity, HyperParamDto.HyperParamCreateReq createReq) { + private void applyHyperParam(ModelHyperParamEntity entity, HyperParam src) { // Important - entity.setBackbone(createReq.getBackbone()); - entity.setInputSize(createReq.getInputSize()); - entity.setCropSize(createReq.getCropSize()); - entity.setBatchSize(createReq.getBatchSize()); + entity.setBackbone(src.getBackbone()); + entity.setInputSize(src.getInputSize()); + entity.setCropSize(src.getCropSize()); + entity.setBatchSize(src.getBatchSize()); // Data - entity.setTrainNumWorkers(createReq.getTrainNumWorkers()); - entity.setValNumWorkers(createReq.getValNumWorkers()); - entity.setTestNumWorkers(createReq.getTestNumWorkers()); - entity.setTrainShuffle(createReq.getTrainShuffle()); - entity.setTrainPersistent(createReq.getTrainPersistent()); - entity.setValPersistent(createReq.getValPersistent()); + entity.setTrainNumWorkers(src.getTrainNumWorkers()); + entity.setValNumWorkers(src.getValNumWorkers()); + entity.setTestNumWorkers(src.getTestNumWorkers()); + entity.setTrainShuffle(src.getTrainShuffle()); + entity.setTrainPersistent(src.getTrainPersistent()); + entity.setValPersistent(src.getValPersistent()); // Model Architecture - entity.setDropPathRate(createReq.getDropPathRate()); - entity.setFrozenStages(createReq.getFrozenStages()); - entity.setNeckPolicy(createReq.getNeckPolicy()); - entity.setClassWeight(createReq.getClassWeight()); - entity.setDecoderChannels(createReq.getDecoderChannels()); + entity.setDropPathRate(src.getDropPathRate()); + entity.setFrozenStages(src.getFrozenStages()); + entity.setNeckPolicy(src.getNeckPolicy()); + entity.setClassWeight(src.getClassWeight()); + entity.setDecoderChannels(src.getDecoderChannels()); // Loss & Optimization - entity.setLearningRate(createReq.getLearningRate()); - entity.setWeightDecay(createReq.getWeightDecay()); - entity.setLayerDecayRate(createReq.getLayerDecayRate()); - entity.setDdpFindUnusedParams(createReq.getDdpFindUnusedParams()); - entity.setIgnoreIndex(createReq.getIgnoreIndex()); - entity.setNumLayers(createReq.getNumLayers()); + entity.setLearningRate(src.getLearningRate()); + entity.setWeightDecay(src.getWeightDecay()); + entity.setLayerDecayRate(src.getLayerDecayRate()); + entity.setDdpFindUnusedParams(src.getDdpFindUnusedParams()); + entity.setIgnoreIndex(src.getIgnoreIndex()); + entity.setNumLayers(src.getNumLayers()); // Evaluation - entity.setMetrics(createReq.getMetrics()); - entity.setSaveBest(createReq.getSaveBest()); - entity.setSaveBestRule(createReq.getSaveBestRule()); - entity.setValInterval(createReq.getValInterval()); - entity.setLogInterval(createReq.getLogInterval()); - entity.setVisInterval(createReq.getVisInterval()); + entity.setMetrics(src.getMetrics()); + entity.setSaveBest(src.getSaveBest()); + entity.setSaveBestRule(src.getSaveBestRule()); + entity.setValInterval(src.getValInterval()); + entity.setLogInterval(src.getLogInterval()); + entity.setVisInterval(src.getVisInterval()); // memo - entity.setMemo(createReq.getMemo()); + entity.setMemo(src.getMemo()); } /** diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java similarity index 54% rename from src/main/java/com/kamco/cd/training/postgres/core/ModelMngCoreService.java rename to src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index 35a4fcd..1f5f084 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -1,15 +1,24 @@ package com.kamco.cd.training.postgres.core; +import com.kamco.cd.training.common.enums.ModelType; +import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.exception.BadRequestException; import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.common.exception.NotFoundException; import com.kamco.cd.training.common.utils.UserUtil; import com.kamco.cd.training.model.dto.ModelMngDto; -import com.kamco.cd.training.model.dto.ModelTrainDto; -import com.kamco.cd.training.model.dto.ModelTrainDto.Basic; +import com.kamco.cd.training.model.dto.ModelTrainMngDto; +import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; +import com.kamco.cd.training.model.dto.ModelTrainMngDto.TrainingDataset; +import com.kamco.cd.training.postgres.entity.ModelConfigEntity; +import com.kamco.cd.training.postgres.entity.ModelDatasetEntity; import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity; +import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; +import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository; +import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository; import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository; +import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository; import com.kamco.cd.training.postgres.repository.model.ModelMngRepository; import java.time.ZonedDateTime; import java.util.List; @@ -21,9 +30,12 @@ import org.springframework.stereotype.Service; @Service @RequiredArgsConstructor -public class ModelMngCoreService { +public class ModelTrainMngCoreService { private final ModelMngRepository modelMngRepository; - private final ModelDatasetMappRepository modelDatasetMappRepository; + private final ModelDatasetRepository modelDatasetRepository; + private final ModelDatasetMappRepository modelDatasetMapRepository; + private final ModelConfigRepository modelConfigRepository; + private final HyperParamRepository hyperParamRepository; private final UserUtil userUtil; /** @@ -32,7 +44,7 @@ public class ModelMngCoreService { * @param searchReq 검색 조건 * @return 페이징 처리된 모델 목록 */ - public Page findByModelList(ModelTrainDto.SearchReq searchReq) { + public Page findByModelList(ModelTrainMngDto.SearchReq searchReq) { Page entityPage = modelMngRepository.findByModels(searchReq); return entityPage.map(ModelMasterEntity::toDto); } @@ -52,6 +64,103 @@ public class ModelMngCoreService { entity.setUpdatedUid(userUtil.getId()); } + /** + * 모델학습 저장 + * + * @param addReq + * @return + */ + public Long saveModel(ModelTrainMngDto.AddReq addReq) { + ModelMasterEntity entity = new ModelMasterEntity(); + ModelHyperParamEntity hyperParamEntity = + hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null); + + entity.setModelNo(addReq.getModelNo()); + entity.setTrainType(addReq.getTrainType()); + + if (hyperParamEntity != null) { + entity.setHyperParamId(hyperParamEntity.getId()); + } + + if (addReq.getIsStart()) { + entity.setModelStep((short) 1); + entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); + entity.setStrtDttm(ZonedDateTime.now()); + entity.setStep1StrtDttm(ZonedDateTime.now()); + entity.setStep1State(TrainStatusType.IN_PROGRESS.getId()); + } + + entity.setCreatedUid(userUtil.getId()); + ModelMasterEntity resultEntity = modelMngRepository.save(entity); + return resultEntity.getId(); + } + + /** + * data set 저장 + * + * @param modelId 저장한 모델 학습 id + * @param addReq 요청 파라미터 + */ + public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) { + TrainingDataset dataset = addReq.getTrainingDataset(); + ModelMasterEntity modelMasterEntity = new ModelMasterEntity(); + ModelDatasetEntity datasetEntity = new ModelDatasetEntity(); + + modelMasterEntity.setId(modelId); + datasetEntity.setModel(modelMasterEntity); + + if (addReq.getModelNo().equals(ModelType.M1.getId())) { + datasetEntity.setBuildingCnt(dataset.getSummary().getBuildingCnt()); + datasetEntity.setContainerCnt(dataset.getSummary().getContainerCnt()); + } else if (addReq.getModelNo().equals(ModelType.M2.getId())) { + datasetEntity.setWasteCnt(dataset.getSummary().getWasteCnt()); + } else if (addReq.getModelNo().equals(ModelType.M3.getId())) { + datasetEntity.setLandCoverCnt(dataset.getSummary().getLandCoverCnt()); + } + + datasetEntity.setCreatedUid(userUtil.getId()); + + // data set 저장 + modelDatasetRepository.save(datasetEntity); + } + + /** + * 모델 데이터셋 mapping 테이블 저장 + * + * @param modelId 모델학습 id + * @param datasetList 선택한 data set + */ + public void saveModelDatasetMap(Long modelId, List datasetList) { + + for (Long datasetId : datasetList) { + ModelDatasetMappEntity mapEntity = new ModelDatasetMappEntity(); + mapEntity.setModelUid(modelId); + mapEntity.setDatasetUid(datasetId); + modelDatasetMapRepository.save(mapEntity); + } + } + + /** + * 모델학습 config 저장 + * + * @param modelId 모델학습 id + * @param req 요청 파라미터 + * @return + */ + public Long saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) { + ModelMasterEntity modelMasterEntity = new ModelMasterEntity(); + ModelConfigEntity entity = new ModelConfigEntity(); + modelMasterEntity.setId(modelId); + entity.setModel(modelMasterEntity); + entity.setEpochCount(req.getEpochCnt()); + entity.setTrainPercent(req.getTrainingCnt()); + entity.setValidationPercent(req.getValidationCnt()); + entity.setTestPercent(req.getTestCnt()); + entity.setMemo(req.getMemo()); + + return modelConfigRepository.save(entity).getId(); + } + /** * 모델 상세 조회 * @@ -136,7 +245,7 @@ public class ModelMngCoreService { mapping.setModelUid(modelUid); mapping.setDatasetUid(datasetId); mapping.setDatasetType("TRAIN"); - modelDatasetMappRepository.save(mapping); + modelDatasetMapRepository.save(mapping); } } diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelConfigEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelConfigEntity.java index 0f5acbd..4df7493 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelConfigEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelConfigEntity.java @@ -24,7 +24,7 @@ public class ModelConfigEntity { @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "config_id", nullable = false) - private Integer id; + private Long id; @NotNull @ManyToOne(fetch = FetchType.LAZY, optional = false) diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelDatasetEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelDatasetEntity.java index e7c9255..0385a72 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelDatasetEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelDatasetEntity.java @@ -3,6 +3,8 @@ package com.kamco.cd.training.postgres.entity; import jakarta.persistence.Column; import jakarta.persistence.Entity; import jakarta.persistence.FetchType; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; import jakarta.persistence.Id; import jakarta.persistence.JoinColumn; import jakarta.persistence.ManyToOne; @@ -20,6 +22,7 @@ import org.hibernate.annotations.ColumnDefault; public class ModelDatasetEntity { @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name = "id", nullable = false) private Long id; @@ -28,10 +31,6 @@ public class ModelDatasetEntity { @JoinColumn(name = "model_id", nullable = false) private ModelMasterEntity model; - @NotNull - @Column(name = "data_id", nullable = false) - private Long dataId; - @Column(name = "building_cnt") private Long buildingCnt; @@ -46,7 +45,7 @@ public class ModelDatasetEntity { @ColumnDefault("now()") @Column(name = "created_dttm") - private ZonedDateTime createdDttm; + private ZonedDateTime createdDttm = ZonedDateTime.now(); @Column(name = "created_uid") private Long createdUid; diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java index 531e71b..1e9dee4 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java @@ -1,13 +1,12 @@ package com.kamco.cd.training.postgres.entity; -import com.kamco.cd.training.model.dto.ModelTrainDto; +import com.kamco.cd.training.model.dto.ModelTrainMngDto; import jakarta.persistence.Column; import jakarta.persistence.Entity; import jakarta.persistence.GeneratedValue; import jakarta.persistence.GenerationType; import jakarta.persistence.Id; import jakarta.persistence.Table; -import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; import java.time.ZonedDateTime; import java.util.UUID; @@ -26,8 +25,7 @@ public class ModelMasterEntity { @Column(name = "model_id", nullable = false) private Long id; - @NotNull - @Column(name = "hyper_param_id", nullable = false) + @Column(name = "hyper_param_id") private Long hyperParamId; @Size(max = 10) @@ -69,7 +67,7 @@ public class ModelMasterEntity { private String step2State; @Column(name = "del_yn") - private Boolean delYn; + private Boolean delYn = false; @ColumnDefault("now()") @Column(name = "created_dttm") @@ -90,8 +88,8 @@ public class ModelMasterEntity { @Column(name = "train_type") private String trainType; - public ModelTrainDto.Basic toDto() { - return new ModelTrainDto.Basic( + public ModelTrainMngDto.Basic toDto() { + return new ModelTrainMngDto.Basic( this.id, this.uuid, this.modelVer, diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelConfigRepository.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelConfigRepository.java new file mode 100644 index 0000000..8b0d02e --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelConfigRepository.java @@ -0,0 +1,6 @@ +package com.kamco.cd.training.postgres.repository.model; + +import com.kamco.cd.training.postgres.entity.ModelConfigEntity; +import org.springframework.data.jpa.repository.JpaRepository; + +public interface ModelConfigRepository extends JpaRepository {} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDatasetRepository.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDatasetRepository.java new file mode 100644 index 0000000..4ccfdbf --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDatasetRepository.java @@ -0,0 +1,7 @@ +package com.kamco.cd.training.postgres.repository.model; + +import com.kamco.cd.training.postgres.entity.ModelDatasetEntity; +import org.springframework.data.jpa.repository.JpaRepository; + +public interface ModelDatasetRepository + extends JpaRepository, ModelDatasetRepositoryCustom {} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDatasetRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDatasetRepositoryCustom.java new file mode 100644 index 0000000..d29c40c --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelDatasetRepositoryCustom.java @@ -0,0 +1,3 @@ +package com.kamco.cd.training.postgres.repository.model; + +public interface ModelDatasetRepositoryCustom {} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryCustom.java index 0abe3db..d664435 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryCustom.java @@ -1,6 +1,6 @@ package com.kamco.cd.training.postgres.repository.model; -import com.kamco.cd.training.model.dto.ModelTrainDto; +import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import java.util.Optional; import java.util.UUID; @@ -14,7 +14,7 @@ public interface ModelMngRepositoryCustom { * @param searchReq * @return */ - Page findByModels(ModelTrainDto.SearchReq searchReq); + Page findByModels(ModelTrainMngDto.SearchReq searchReq); Optional findByUuid(UUID uuid); diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java index a08b37d..70fabb7 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java @@ -2,7 +2,7 @@ package com.kamco.cd.training.postgres.repository.model; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; -import com.kamco.cd.training.model.dto.ModelTrainDto; +import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.querydsl.core.BooleanBuilder; import com.querydsl.jpa.impl.JPAQueryFactory; @@ -28,7 +28,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom { * @return */ @Override - public Page findByModels(ModelTrainDto.SearchReq req) { + public Page findByModels(ModelTrainMngDto.SearchReq req) { Pageable pageable = req.toPageable(); BooleanBuilder builder = new BooleanBuilder();