Merge pull request 'feat/training_260202' (#16) from feat/training_260202 into develop
Reviewed-on: #16
This commit was merged in pull request #16.
This commit is contained in:
154
src/main/java/com/kamco/cd/training/common/dto/HyperParam.java
Normal file
154
src/main/java/com/kamco/cd/training/common/dto/HyperParam.java
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package com.kamco.cd.training.common.dto;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class HyperParam {
|
||||||
|
// -------------------------
|
||||||
|
// Important
|
||||||
|
// -------------------------
|
||||||
|
@Schema(description = "백본 네트워크", example = "large")
|
||||||
|
private String backbone; // backbone
|
||||||
|
|
||||||
|
@Schema(description = "입력 이미지 크기(H,W)", example = "512,512")
|
||||||
|
private String inputSize; // input_size
|
||||||
|
|
||||||
|
@Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256")
|
||||||
|
private String cropSize; // crop_size
|
||||||
|
|
||||||
|
@Schema(description = "배치 크기(Per GPU)", example = "16")
|
||||||
|
private Integer batchSize; // batch_size
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Data
|
||||||
|
// -------------------------
|
||||||
|
@Schema(description = "Train dataloader workers", example = "16")
|
||||||
|
private Integer trainNumWorkers; // train_num_workers
|
||||||
|
|
||||||
|
@Schema(description = "Val dataloader workers", example = "8")
|
||||||
|
private Integer valNumWorkers; // val_num_workers
|
||||||
|
|
||||||
|
@Schema(description = "Test dataloader workers", example = "8")
|
||||||
|
private Integer testNumWorkers; // test_num_workers
|
||||||
|
|
||||||
|
@Schema(description = "Train shuffle 여부", example = "true")
|
||||||
|
private Boolean trainShuffle; // train_shuffle
|
||||||
|
|
||||||
|
@Schema(description = "Train persistent workers 여부", example = "true")
|
||||||
|
private Boolean trainPersistent; // train_persistent
|
||||||
|
|
||||||
|
@Schema(description = "Val persistent workers 여부", example = "true")
|
||||||
|
private Boolean valPersistent; // val_persistent
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Model Architecture
|
||||||
|
// -------------------------
|
||||||
|
@Schema(description = "Drop Path 비율", example = "0.3")
|
||||||
|
private Double dropPathRate; // drop_path_rate
|
||||||
|
|
||||||
|
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
|
||||||
|
private Integer frozenStages; // frozen_stages
|
||||||
|
|
||||||
|
@Schema(description = "Neck 결합 정책", example = "abs_diff")
|
||||||
|
private String neckPolicy; // neck_policy
|
||||||
|
|
||||||
|
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
|
||||||
|
private String decoderChannels; // decoder_channels
|
||||||
|
|
||||||
|
@Schema(description = "클래스별 가중치", example = "1,10")
|
||||||
|
private String classWeight; // class_weight
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Loss & Optimization
|
||||||
|
// -------------------------
|
||||||
|
@Schema(description = "학습률", example = "0.00006")
|
||||||
|
private Double learningRate; // learning_rate
|
||||||
|
|
||||||
|
@Schema(description = "Weight Decay", example = "0.05")
|
||||||
|
private Double weightDecay; // weight_decay
|
||||||
|
|
||||||
|
@Schema(description = "Layer Decay Rate", example = "0.9")
|
||||||
|
private Double layerDecayRate; // layer_decay_rate
|
||||||
|
|
||||||
|
@Schema(description = "DDP unused params 탐색 여부", example = "true")
|
||||||
|
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
|
||||||
|
|
||||||
|
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
|
||||||
|
private Integer ignoreIndex; // ignore_index
|
||||||
|
|
||||||
|
@Schema(description = "레이어 깊이", example = "24")
|
||||||
|
private Integer numLayers; // num_layers
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Evaluation
|
||||||
|
// -------------------------
|
||||||
|
@Schema(description = "평가 지표 목록", example = "mFscore,mIoU")
|
||||||
|
private String metrics; // metrics
|
||||||
|
|
||||||
|
@Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore")
|
||||||
|
private String saveBest; // save_best
|
||||||
|
|
||||||
|
@Schema(description = "Best 모델 선정 규칙", example = "less")
|
||||||
|
private String saveBestRule; // save_best_rule
|
||||||
|
|
||||||
|
@Schema(description = "검증 수행 주기(Epoch)", example = "10")
|
||||||
|
private Integer valInterval; // val_interval
|
||||||
|
|
||||||
|
@Schema(description = "로그 기록 주기(Iteration)", example = "400")
|
||||||
|
private Integer logInterval; // log_interval
|
||||||
|
|
||||||
|
@Schema(description = "시각화 저장 주기(Epoch)", example = "1")
|
||||||
|
private Integer visInterval; // vis_interval
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Augmentation
|
||||||
|
// -------------------------
|
||||||
|
@Schema(description = "회전 적용 확률", example = "0.5")
|
||||||
|
private Double rotProb; // rot_prob
|
||||||
|
|
||||||
|
@Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20")
|
||||||
|
private String rotDegree; // rot_degree
|
||||||
|
|
||||||
|
@Schema(description = "반전 적용 확률", example = "0.5")
|
||||||
|
private Double flipProb; // flip_prob
|
||||||
|
|
||||||
|
@Schema(description = "채널 교환 확률", example = "0.5")
|
||||||
|
private Double exchangeProb; // exchange_prob
|
||||||
|
|
||||||
|
@Schema(description = "밝기 변화량", example = "10")
|
||||||
|
private Integer brightnessDelta; // brightness_delta
|
||||||
|
|
||||||
|
@Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2")
|
||||||
|
private String contrastRange; // contrast_range
|
||||||
|
|
||||||
|
@Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2")
|
||||||
|
private String saturationRange; // saturation_range
|
||||||
|
|
||||||
|
@Schema(description = "색조 변화량", example = "10")
|
||||||
|
private Integer hueDelta; // hue_delta
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Hardware
|
||||||
|
// -------------------------
|
||||||
|
@Schema(description = "사용 GPU 개수", example = "4")
|
||||||
|
private Integer gpuCnt; // gpu_cnt
|
||||||
|
|
||||||
|
@Schema(description = "사용 GPU ID 목록", example = "0,1,2,3")
|
||||||
|
private String gpuIds; // gpu_ids
|
||||||
|
|
||||||
|
@Schema(description = "분산학습 마스터 포트", example = "1122")
|
||||||
|
private Integer masterPort; // master_port
|
||||||
|
|
||||||
|
// -------------------------
|
||||||
|
// Memo
|
||||||
|
// -------------------------
|
||||||
|
@Schema(description = "메모", example = "하이퍼파라미터 신규등록")
|
||||||
|
private String memo; // memo
|
||||||
|
}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package com.kamco.cd.training.common.enums;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.utils.enums.CodeExpose;
|
||||||
|
import com.kamco.cd.training.common.utils.enums.EnumType;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
@CodeExpose
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
public enum ModelType implements EnumType {
|
||||||
|
M1("M1"),
|
||||||
|
M2("M2"),
|
||||||
|
M3("M3");
|
||||||
|
|
||||||
|
private String desc;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getId() {
|
||||||
|
return name();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getText() {
|
||||||
|
return desc;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.kamco.cd.training.hyperparam;
|
package com.kamco.cd.training.hyperparam;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.dto.HyperParam;
|
||||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
|
||||||
@@ -49,8 +50,7 @@ public class HyperParamApiController {
|
|||||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
})
|
})
|
||||||
@PostMapping
|
@PostMapping
|
||||||
public ApiResponseDto<String> createHyperParam(
|
public ApiResponseDto<String> createHyperParam(@Valid @RequestBody HyperParam createReq) {
|
||||||
@Valid @RequestBody HyperParamDto.HyperParamCreateReq createReq) {
|
|
||||||
String newVersion = hyperParamService.createHyperParam(createReq);
|
String newVersion = hyperParamService.createHyperParam(createReq);
|
||||||
return ApiResponseDto.ok(newVersion);
|
return ApiResponseDto.ok(newVersion);
|
||||||
}
|
}
|
||||||
@@ -70,7 +70,7 @@ public class HyperParamApiController {
|
|||||||
})
|
})
|
||||||
@PutMapping("/{uuid}")
|
@PutMapping("/{uuid}")
|
||||||
public ApiResponseDto<String> updateHyperParam(
|
public ApiResponseDto<String> updateHyperParam(
|
||||||
@PathVariable UUID uuid, @Valid @RequestBody HyperParamDto.HyperParamCreateReq createReq) {
|
@PathVariable UUID uuid, @Valid @RequestBody HyperParam createReq) {
|
||||||
return ApiResponseDto.ok(hyperParamService.updateHyperParam(uuid, createReq));
|
return ApiResponseDto.ok(hyperParamService.updateHyperParam(uuid, createReq));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -115,155 +115,6 @@ public class HyperParamDto {
|
|||||||
private Long totalCnt;
|
private Long totalCnt;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청")
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@NoArgsConstructor
|
|
||||||
@AllArgsConstructor
|
|
||||||
public static class HyperParamCreateReq {
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Important
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "백본 네트워크", example = "large")
|
|
||||||
private String backbone; // backbone
|
|
||||||
|
|
||||||
@Schema(description = "입력 이미지 크기(H,W)", example = "512,512")
|
|
||||||
private String inputSize; // input_size
|
|
||||||
|
|
||||||
@Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256")
|
|
||||||
private String cropSize; // crop_size
|
|
||||||
|
|
||||||
@Schema(description = "배치 크기(Per GPU)", example = "16")
|
|
||||||
private Integer batchSize; // batch_size
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Data
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "Train dataloader workers", example = "16")
|
|
||||||
private Integer trainNumWorkers; // train_num_workers
|
|
||||||
|
|
||||||
@Schema(description = "Val dataloader workers", example = "8")
|
|
||||||
private Integer valNumWorkers; // val_num_workers
|
|
||||||
|
|
||||||
@Schema(description = "Test dataloader workers", example = "8")
|
|
||||||
private Integer testNumWorkers; // test_num_workers
|
|
||||||
|
|
||||||
@Schema(description = "Train shuffle 여부", example = "true")
|
|
||||||
private Boolean trainShuffle; // train_shuffle
|
|
||||||
|
|
||||||
@Schema(description = "Train persistent workers 여부", example = "true")
|
|
||||||
private Boolean trainPersistent; // train_persistent
|
|
||||||
|
|
||||||
@Schema(description = "Val persistent workers 여부", example = "true")
|
|
||||||
private Boolean valPersistent; // val_persistent
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Model Architecture
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "Drop Path 비율", example = "0.3")
|
|
||||||
private Double dropPathRate; // drop_path_rate
|
|
||||||
|
|
||||||
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
|
|
||||||
private Integer frozenStages; // frozen_stages
|
|
||||||
|
|
||||||
@Schema(description = "Neck 결합 정책", example = "abs_diff")
|
|
||||||
private String neckPolicy; // neck_policy
|
|
||||||
|
|
||||||
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
|
|
||||||
private String decoderChannels; // decoder_channels
|
|
||||||
|
|
||||||
@Schema(description = "클래스별 가중치", example = "1,10")
|
|
||||||
private String classWeight; // class_weight
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Loss & Optimization
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "학습률", example = "0.00006")
|
|
||||||
private Double learningRate; // learning_rate
|
|
||||||
|
|
||||||
@Schema(description = "Weight Decay", example = "0.05")
|
|
||||||
private Double weightDecay; // weight_decay
|
|
||||||
|
|
||||||
@Schema(description = "Layer Decay Rate", example = "0.9")
|
|
||||||
private Double layerDecayRate; // layer_decay_rate
|
|
||||||
|
|
||||||
@Schema(description = "DDP unused params 탐색 여부", example = "true")
|
|
||||||
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
|
|
||||||
|
|
||||||
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
|
|
||||||
private Integer ignoreIndex; // ignore_index
|
|
||||||
|
|
||||||
@Schema(description = "레이어 깊이", example = "24")
|
|
||||||
private Integer numLayers; // num_layers
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Evaluation
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "평가 지표 목록", example = "mFscore,mIoU")
|
|
||||||
private String metrics; // metrics
|
|
||||||
|
|
||||||
@Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore")
|
|
||||||
private String saveBest; // save_best
|
|
||||||
|
|
||||||
@Schema(description = "Best 모델 선정 규칙", example = "less")
|
|
||||||
private String saveBestRule; // save_best_rule
|
|
||||||
|
|
||||||
@Schema(description = "검증 수행 주기(Epoch)", example = "10")
|
|
||||||
private Integer valInterval; // val_interval
|
|
||||||
|
|
||||||
@Schema(description = "로그 기록 주기(Iteration)", example = "400")
|
|
||||||
private Integer logInterval; // log_interval
|
|
||||||
|
|
||||||
@Schema(description = "시각화 저장 주기(Epoch)", example = "1")
|
|
||||||
private Integer visInterval; // vis_interval
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Augmentation
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "회전 적용 확률", example = "0.5")
|
|
||||||
private Double rotProb; // rot_prob
|
|
||||||
|
|
||||||
@Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20")
|
|
||||||
private String rotDegree; // rot_degree
|
|
||||||
|
|
||||||
@Schema(description = "반전 적용 확률", example = "0.5")
|
|
||||||
private Double flipProb; // flip_prob
|
|
||||||
|
|
||||||
@Schema(description = "채널 교환 확률", example = "0.5")
|
|
||||||
private Double exchangeProb; // exchange_prob
|
|
||||||
|
|
||||||
@Schema(description = "밝기 변화량", example = "10")
|
|
||||||
private Integer brightnessDelta; // brightness_delta
|
|
||||||
|
|
||||||
@Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2")
|
|
||||||
private String contrastRange; // contrast_range
|
|
||||||
|
|
||||||
@Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2")
|
|
||||||
private String saturationRange; // saturation_range
|
|
||||||
|
|
||||||
@Schema(description = "색조 변화량", example = "10")
|
|
||||||
private Integer hueDelta; // hue_delta
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Hardware
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "사용 GPU 개수", example = "4")
|
|
||||||
private Integer gpuCnt; // gpu_cnt
|
|
||||||
|
|
||||||
@Schema(description = "사용 GPU ID 목록", example = "0,1,2,3")
|
|
||||||
private String gpuIds; // gpu_ids
|
|
||||||
|
|
||||||
@Schema(description = "분산학습 마스터 포트", example = "1122")
|
|
||||||
private Integer masterPort; // master_port
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Memo
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "메모", example = "하이퍼파라미터 신규등록")
|
|
||||||
private String memo; // memo
|
|
||||||
}
|
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.kamco.cd.training.hyperparam.service;
|
package com.kamco.cd.training.hyperparam.service;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.dto.HyperParam;
|
||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
|
||||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||||
@@ -33,8 +34,8 @@ public class HyperParamService {
|
|||||||
* @return 생성된 버전명
|
* @return 생성된 버전명
|
||||||
*/
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) {
|
public String createHyperParam(HyperParam createReq) {
|
||||||
return hyperParamCoreService.createHyperParam(createReq);
|
return hyperParamCoreService.createHyperParam(createReq).getHyperVer();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -44,7 +45,7 @@ public class HyperParamService {
|
|||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) {
|
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
||||||
return hyperParamCoreService.updateHyperParam(uuid, createReq);
|
return hyperParamCoreService.updateHyperParam(uuid, createReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
package com.kamco.cd.training.model;
|
package com.kamco.cd.training.model;
|
||||||
|
|
||||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto.Basic;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||||
import com.kamco.cd.training.model.service.ModelMngService;
|
import com.kamco.cd.training.model.service.ModelTrainMngService;
|
||||||
import com.kamco.cd.training.model.service.ModelTrainService;
|
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.Parameter;
|
import io.swagger.v3.oas.annotations.Parameter;
|
||||||
import io.swagger.v3.oas.annotations.media.Content;
|
import io.swagger.v3.oas.annotations.media.Content;
|
||||||
@@ -28,9 +27,8 @@ import org.springframework.web.bind.annotation.RestController;
|
|||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@Tag(name = "모델학습 관리", description = "어드민 홈 > 모델학습관리 > 모델관리 > 목록")
|
@Tag(name = "모델학습 관리", description = "어드민 홈 > 모델학습관리 > 모델관리 > 목록")
|
||||||
@RequestMapping("/api/models")
|
@RequestMapping("/api/models")
|
||||||
public class ModelMngApiController {
|
public class ModelTrainMngApiController {
|
||||||
private final ModelMngService modelMngService;
|
private final ModelTrainMngService modelTrainMngService;
|
||||||
private final ModelTrainService modelTrainService;
|
|
||||||
|
|
||||||
@Operation(summary = "모델학습 목록 조회", description = "모델학습 목록 조회 API")
|
@Operation(summary = "모델학습 목록 조회", description = "모델학습 목록 조회 API")
|
||||||
@ApiResponses(
|
@ApiResponses(
|
||||||
@@ -55,8 +53,8 @@ public class ModelMngApiController {
|
|||||||
String status,
|
String status,
|
||||||
@Parameter(description = "페이지 번호") @RequestParam(defaultValue = "0") int page,
|
@Parameter(description = "페이지 번호") @RequestParam(defaultValue = "0") int page,
|
||||||
@Parameter(description = "페이지 크기") @RequestParam(defaultValue = "20") int size) {
|
@Parameter(description = "페이지 크기") @RequestParam(defaultValue = "20") int size) {
|
||||||
ModelTrainDto.SearchReq searchReq = new ModelTrainDto.SearchReq(status, page, size);
|
ModelTrainMngDto.SearchReq searchReq = new ModelTrainMngDto.SearchReq(status, page, size);
|
||||||
return ApiResponseDto.ok(modelMngService.getModelList(searchReq));
|
return ApiResponseDto.ok(modelTrainMngService.getModelList(searchReq));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Operation(summary = "학습 모델 삭제", description = "학습 모델 삭제 API")
|
@Operation(summary = "학습 모델 삭제", description = "학습 모델 삭제 API")
|
||||||
@@ -70,7 +68,7 @@ public class ModelMngApiController {
|
|||||||
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
||||||
@PathVariable
|
@PathVariable
|
||||||
UUID uuid) {
|
UUID uuid) {
|
||||||
modelMngService.deleteModelTrain(uuid);
|
modelTrainMngService.deleteModelTrain(uuid);
|
||||||
return ApiResponseDto.ok(null);
|
return ApiResponseDto.ok(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,8 +79,9 @@ public class ModelMngApiController {
|
|||||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
})
|
})
|
||||||
@PostMapping
|
@PostMapping
|
||||||
public ApiResponseDto<String> createModelTrain(@RequestBody ModelTrainDto.AddReq modelTrainDto) {
|
public ApiResponseDto<String> createModelTrain(@RequestBody ModelTrainMngDto.AddReq req) {
|
||||||
return ApiResponseDto.ok(null);
|
modelTrainMngService.createModelTrain(req);
|
||||||
|
return ApiResponseDto.ok("ok");
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -104,7 +103,7 @@ public class ModelMngApiController {
|
|||||||
// @Parameter(description = "모델 UUID", example = "b7e99739-6736-45f9-a224-8161ecddf287")
|
// @Parameter(description = "모델 UUID", example = "b7e99739-6736-45f9-a224-8161ecddf287")
|
||||||
// @PathVariable
|
// @PathVariable
|
||||||
// String uuid) {
|
// String uuid) {
|
||||||
// return ApiResponseDto.ok(modelMngService.getModelDetailByUuid(uuid));
|
// return ApiResponseDto.ok(modelTrainMngService.getModelDetailByUuid(uuid));
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// // ==================== 학습 모델학습관리 API (5종) ====================
|
// // ==================== 학습 모델학습관리 API (5종) ====================
|
||||||
@@ -1,336 +0,0 @@
|
|||||||
package com.kamco.cd.training.model.dto;
|
|
||||||
|
|
||||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
|
||||||
import com.kamco.cd.training.common.enums.TrainType;
|
|
||||||
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
|
|
||||||
import io.swagger.v3.oas.annotations.media.Schema;
|
|
||||||
import jakarta.validation.constraints.NotNull;
|
|
||||||
import java.time.Duration;
|
|
||||||
import java.time.ZonedDateTime;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.UUID;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.springframework.data.domain.PageRequest;
|
|
||||||
import org.springframework.data.domain.Pageable;
|
|
||||||
|
|
||||||
public class ModelTrainDto {
|
|
||||||
@Schema(name = "모델학습관리 목록", description = "모델학습관리 목록")
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@NoArgsConstructor
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Builder
|
|
||||||
public static class Basic {
|
|
||||||
|
|
||||||
private Long id;
|
|
||||||
private UUID uuid;
|
|
||||||
private String modelVer;
|
|
||||||
@JsonFormatDttm private ZonedDateTime startDttm;
|
|
||||||
@JsonFormatDttm private ZonedDateTime step1StrtDttm;
|
|
||||||
@JsonFormatDttm private ZonedDateTime step1EndDttm;
|
|
||||||
@JsonFormatDttm private ZonedDateTime step2StrtDttm;
|
|
||||||
@JsonFormatDttm private ZonedDateTime step2EndDttm;
|
|
||||||
private String step1Status;
|
|
||||||
private String step2Status;
|
|
||||||
private String statusCd;
|
|
||||||
private String trainType;
|
|
||||||
|
|
||||||
public String getStatusName() {
|
|
||||||
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
|
||||||
try {
|
|
||||||
return TrainStatusType.valueOf(this.statusCd).getText(); // 또는 getName()
|
|
||||||
} catch (IllegalArgumentException e) {
|
|
||||||
return this.statusCd; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getStep1StatusName() {
|
|
||||||
if (this.step1Status == null || this.step1Status.isBlank()) return null;
|
|
||||||
try {
|
|
||||||
return TrainStatusType.valueOf(this.step1Status).getText(); // 또는 getName()
|
|
||||||
} catch (IllegalArgumentException e) {
|
|
||||||
return this.step1Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getStep2StatusNAme() {
|
|
||||||
if (this.step2Status == null || this.step2Status.isBlank()) return null;
|
|
||||||
try {
|
|
||||||
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()
|
|
||||||
} catch (IllegalArgumentException e) {
|
|
||||||
return this.step2Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getTrainTypeName() {
|
|
||||||
if (this.trainType == null || this.trainType.isBlank()) return null;
|
|
||||||
try {
|
|
||||||
return TrainType.valueOf(this.trainType).getText(); // 또는 getName()
|
|
||||||
} catch (IllegalArgumentException e) {
|
|
||||||
return this.trainType; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private String formatDuration(ZonedDateTime start, ZonedDateTime end) {
|
|
||||||
if (start == null || end == null) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
long totalSeconds = Math.abs(Duration.between(start, end).getSeconds());
|
|
||||||
|
|
||||||
long hours = totalSeconds / 3600;
|
|
||||||
long minutes = (totalSeconds % 3600) / 60;
|
|
||||||
long seconds = totalSeconds % 60;
|
|
||||||
|
|
||||||
return String.format("%d시간 %d분 %d초", hours, minutes, seconds);
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getStep1Duration() {
|
|
||||||
return formatDuration(this.step1StrtDttm, this.step1EndDttm);
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getStep2Duration() {
|
|
||||||
return formatDuration(this.step2StrtDttm, this.step2EndDttm);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Schema(name = "searchReq", description = "모델학습 관리 목록조회 파라미터")
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@NoArgsConstructor
|
|
||||||
@AllArgsConstructor
|
|
||||||
public static class SearchReq {
|
|
||||||
|
|
||||||
private String status;
|
|
||||||
// 페이징 파라미터
|
|
||||||
private int page = 0;
|
|
||||||
private int size = 20;
|
|
||||||
|
|
||||||
public Pageable toPageable() {
|
|
||||||
return PageRequest.of(page, size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Schema(name = "addReq", description = "모델학습 관리 등록 파라미터")
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@NoArgsConstructor
|
|
||||||
@AllArgsConstructor
|
|
||||||
public static class AddReq {
|
|
||||||
HyperParamDto hyperParam;
|
|
||||||
TrainingDataConfigDto trainingDataConfig;
|
|
||||||
EtcConfig etcConfig;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 학습실행설정 하이퍼파라미터 설정 */
|
|
||||||
@Schema(name = "하이퍼파라미터 설정", description = "학습실행설정 > 하이퍼파라미터 설정")
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
public static class HyperParamDto {
|
|
||||||
|
|
||||||
@NotNull
|
|
||||||
@Schema(
|
|
||||||
description = "OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)",
|
|
||||||
example = "EXISTING")
|
|
||||||
private String hyperParamType;
|
|
||||||
|
|
||||||
@Schema(description = "기존 파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
|
|
||||||
private String hyperUuid;
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Important
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "백본 네트워크", example = "large")
|
|
||||||
private String backbone; // backbone
|
|
||||||
|
|
||||||
@Schema(description = "입력 이미지 크기(H,W)", example = "512,512")
|
|
||||||
private String inputSize; // input_size
|
|
||||||
|
|
||||||
@Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256")
|
|
||||||
private String cropSize; // crop_size
|
|
||||||
|
|
||||||
@Schema(description = "배치 크기(Per GPU)", example = "16")
|
|
||||||
private Integer batchSize; // batch_size
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Data
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "Train dataloader workers", example = "16")
|
|
||||||
private Integer trainNumWorkers; // train_num_workers
|
|
||||||
|
|
||||||
@Schema(description = "Val dataloader workers", example = "8")
|
|
||||||
private Integer valNumWorkers; // val_num_workers
|
|
||||||
|
|
||||||
@Schema(description = "Test dataloader workers", example = "8")
|
|
||||||
private Integer testNumWorkers; // test_num_workers
|
|
||||||
|
|
||||||
@Schema(description = "Train shuffle 여부", example = "true")
|
|
||||||
private Boolean trainShuffle; // train_shuffle
|
|
||||||
|
|
||||||
@Schema(description = "Train persistent workers 여부", example = "true")
|
|
||||||
private Boolean trainPersistent; // train_persistent
|
|
||||||
|
|
||||||
@Schema(description = "Val persistent workers 여부", example = "true")
|
|
||||||
private Boolean valPersistent; // val_persistent
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Model Architecture
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "Drop Path 비율", example = "0.3")
|
|
||||||
private Double dropPathRate; // drop_path_rate
|
|
||||||
|
|
||||||
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
|
|
||||||
private Integer frozenStages; // frozen_stages
|
|
||||||
|
|
||||||
@Schema(description = "Neck 결합 정책", example = "abs_diff")
|
|
||||||
private String neckPolicy; // neck_policy
|
|
||||||
|
|
||||||
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
|
|
||||||
private String decoderChannels; // decoder_channels
|
|
||||||
|
|
||||||
@Schema(description = "클래스별 가중치", example = "1,10")
|
|
||||||
private String classWeight; // class_weight
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Loss & Optimization
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "학습률", example = "0.00006")
|
|
||||||
private Double learningRate; // learning_rate
|
|
||||||
|
|
||||||
@Schema(description = "Weight Decay", example = "0.05")
|
|
||||||
private Double weightDecay; // weight_decay
|
|
||||||
|
|
||||||
@Schema(description = "Layer Decay Rate", example = "0.9")
|
|
||||||
private Double layerDecayRate; // layer_decay_rate
|
|
||||||
|
|
||||||
@Schema(description = "DDP unused params 탐색 여부", example = "true")
|
|
||||||
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
|
|
||||||
|
|
||||||
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
|
|
||||||
private Integer ignoreIndex; // ignore_index
|
|
||||||
|
|
||||||
@Schema(description = "레이어 깊이", example = "24")
|
|
||||||
private Integer numLayers; // num_layers
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Evaluation
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "평가 지표 목록", example = "mFscore,mIoU")
|
|
||||||
private String metrics; // metrics
|
|
||||||
|
|
||||||
@Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore")
|
|
||||||
private String saveBest; // save_best
|
|
||||||
|
|
||||||
@Schema(description = "Best 모델 선정 규칙", example = "less")
|
|
||||||
private String saveBestRule; // save_best_rule
|
|
||||||
|
|
||||||
@Schema(description = "검증 수행 주기(Epoch)", example = "10")
|
|
||||||
private Integer valInterval; // val_interval
|
|
||||||
|
|
||||||
@Schema(description = "로그 기록 주기(Iteration)", example = "400")
|
|
||||||
private Integer logInterval; // log_interval
|
|
||||||
|
|
||||||
@Schema(description = "시각화 저장 주기(Epoch)", example = "1")
|
|
||||||
private Integer visInterval; // vis_interval
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Augmentation
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "회전 적용 확률", example = "0.5")
|
|
||||||
private Double rotProb; // rot_prob
|
|
||||||
|
|
||||||
@Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20")
|
|
||||||
private String rotDegree; // rot_degree
|
|
||||||
|
|
||||||
@Schema(description = "반전 적용 확률", example = "0.5")
|
|
||||||
private Double flipProb; // flip_prob
|
|
||||||
|
|
||||||
@Schema(description = "채널 교환 확률", example = "0.5")
|
|
||||||
private Double exchangeProb; // exchange_prob
|
|
||||||
|
|
||||||
@Schema(description = "밝기 변화량", example = "10")
|
|
||||||
private Integer brightnessDelta; // brightness_delta
|
|
||||||
|
|
||||||
@Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2")
|
|
||||||
private String contrastRange; // contrast_range
|
|
||||||
|
|
||||||
@Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2")
|
|
||||||
private String saturationRange; // saturation_range
|
|
||||||
|
|
||||||
@Schema(description = "색조 변화량", example = "10")
|
|
||||||
private Integer hueDelta; // hue_delta
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Hardware
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "사용 GPU 개수", example = "4")
|
|
||||||
private Integer gpuCnt; // gpu_cnt
|
|
||||||
|
|
||||||
@Schema(description = "사용 GPU ID 목록", example = "0,1,2,3")
|
|
||||||
private String gpuIds; // gpu_ids
|
|
||||||
|
|
||||||
@Schema(description = "분산학습 마스터 포트", example = "1122")
|
|
||||||
private Integer masterPort; // master_port
|
|
||||||
|
|
||||||
// -------------------------
|
|
||||||
// Memo
|
|
||||||
// -------------------------
|
|
||||||
@Schema(description = "메모", example = "하이퍼파라미터 신규등록")
|
|
||||||
private String memo; // memo
|
|
||||||
}
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
public static class TrainingDataConfigDto {
|
|
||||||
Summary summary;
|
|
||||||
List<Dataset> datasetList;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
public static class Summary {
|
|
||||||
@Schema(description = "건물", example = "0")
|
|
||||||
private Long buildingCnt;
|
|
||||||
|
|
||||||
@Schema(description = "컨테이너", example = "0")
|
|
||||||
private Long containerCnt;
|
|
||||||
|
|
||||||
@Schema(description = "폐기물", example = "0")
|
|
||||||
private Long wasteCnt;
|
|
||||||
|
|
||||||
@Schema(
|
|
||||||
description = "도로, 비닐하우스, 밭, 과수원, 초지, 숲, 물, 모재/자갈, 토분(무덤), 일반토지, 태양광, 기타",
|
|
||||||
example = "0")
|
|
||||||
private Long LandCoverCnt;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
public static class Dataset {
|
|
||||||
@Schema(description = "데이터셋 uuid")
|
|
||||||
private UUID uuid;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
public static class EtcConfig {
|
|
||||||
@Schema(description = "에폭 횟수", example = "0")
|
|
||||||
private Long epochCnt;
|
|
||||||
|
|
||||||
@Schema(description = "학습데이터셋 비율 Training", example = "0")
|
|
||||||
private Integer trainingCnt;
|
|
||||||
|
|
||||||
@Schema(description = "학습데이터셋 비율 Validation", example = "0")
|
|
||||||
private Integer validationCnt;
|
|
||||||
|
|
||||||
@Schema(description = "학습데이터셋 비율 Test", example = "0")
|
|
||||||
private Integer testCnt;
|
|
||||||
|
|
||||||
@Schema(description = "메모", example = "메모 입니다.")
|
|
||||||
private String memo;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,195 @@
|
|||||||
|
package com.kamco.cd.training.model.dto;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.dto.HyperParam;
|
||||||
|
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||||
|
import com.kamco.cd.training.common.enums.TrainType;
|
||||||
|
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import jakarta.validation.constraints.NotNull;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.ZonedDateTime;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.Setter;
|
||||||
|
import org.springframework.data.domain.PageRequest;
|
||||||
|
import org.springframework.data.domain.Pageable;
|
||||||
|
|
||||||
|
public class ModelTrainMngDto {
|
||||||
|
@Schema(name = "모델학습관리 목록", description = "모델학습관리 목록")
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Builder
|
||||||
|
public static class Basic {
|
||||||
|
|
||||||
|
private Long id;
|
||||||
|
private UUID uuid;
|
||||||
|
private String modelVer;
|
||||||
|
@JsonFormatDttm private ZonedDateTime startDttm;
|
||||||
|
@JsonFormatDttm private ZonedDateTime step1StrtDttm;
|
||||||
|
@JsonFormatDttm private ZonedDateTime step1EndDttm;
|
||||||
|
@JsonFormatDttm private ZonedDateTime step2StrtDttm;
|
||||||
|
@JsonFormatDttm private ZonedDateTime step2EndDttm;
|
||||||
|
private String step1Status;
|
||||||
|
private String step2Status;
|
||||||
|
private String statusCd;
|
||||||
|
private String trainType;
|
||||||
|
|
||||||
|
public String getStatusName() {
|
||||||
|
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
||||||
|
try {
|
||||||
|
return TrainStatusType.valueOf(this.statusCd).getText(); // 또는 getName()
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
return this.statusCd; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getStep1StatusName() {
|
||||||
|
if (this.step1Status == null || this.step1Status.isBlank()) return null;
|
||||||
|
try {
|
||||||
|
return TrainStatusType.valueOf(this.step1Status).getText(); // 또는 getName()
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
return this.step1Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getStep2StatusNAme() {
|
||||||
|
if (this.step2Status == null || this.step2Status.isBlank()) return null;
|
||||||
|
try {
|
||||||
|
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
return this.step2Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getTrainTypeName() {
|
||||||
|
if (this.trainType == null || this.trainType.isBlank()) return null;
|
||||||
|
try {
|
||||||
|
return TrainType.valueOf(this.trainType).getText(); // 또는 getName()
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
return this.trainType; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String formatDuration(ZonedDateTime start, ZonedDateTime end) {
|
||||||
|
if (start == null || end == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
long totalSeconds = Math.abs(Duration.between(start, end).getSeconds());
|
||||||
|
|
||||||
|
long hours = totalSeconds / 3600;
|
||||||
|
long minutes = (totalSeconds % 3600) / 60;
|
||||||
|
long seconds = totalSeconds % 60;
|
||||||
|
|
||||||
|
return String.format("%d시간 %d분 %d초", hours, minutes, seconds);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getStep1Duration() {
|
||||||
|
return formatDuration(this.step1StrtDttm, this.step1EndDttm);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getStep2Duration() {
|
||||||
|
return formatDuration(this.step2StrtDttm, this.step2EndDttm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Schema(name = "searchReq", description = "모델학습 관리 목록조회 파라미터")
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public static class SearchReq {
|
||||||
|
|
||||||
|
private String status;
|
||||||
|
// 페이징 파라미터
|
||||||
|
private int page = 0;
|
||||||
|
private int size = 20;
|
||||||
|
|
||||||
|
public Pageable toPageable() {
|
||||||
|
return PageRequest.of(page, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Schema(name = "addReq", description = "모델학습 관리 등록 파라미터")
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public static class AddReq {
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
@Schema(description = "모델 종류 M1, M2, M3", example = "M1")
|
||||||
|
private String modelNo;
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
@Schema(description = "모델학습 실행 여부", example = "false")
|
||||||
|
private Boolean isStart;
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
@Schema(description = "학습타입 GENERAL(일반), TRANSFER(전이)", example = "GENERAL")
|
||||||
|
private String trainType;
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
@Schema(
|
||||||
|
description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)",
|
||||||
|
example = "EXISTING")
|
||||||
|
private String hyperParamType;
|
||||||
|
|
||||||
|
@Schema(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
|
||||||
|
private UUID hyperUuid;
|
||||||
|
|
||||||
|
HyperParam hyperParam;
|
||||||
|
TrainingDataset trainingDataset;
|
||||||
|
ModelConfig modelConfig;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
public static class TrainingDataset {
|
||||||
|
Summary summary;
|
||||||
|
List<Long> datasetList;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
public static class Summary {
|
||||||
|
@Schema(description = "건물", example = "0")
|
||||||
|
private Long buildingCnt;
|
||||||
|
|
||||||
|
@Schema(description = "컨테이너", example = "0")
|
||||||
|
private Long containerCnt;
|
||||||
|
|
||||||
|
@Schema(description = "폐기물", example = "0")
|
||||||
|
private Long wasteCnt;
|
||||||
|
|
||||||
|
@Schema(
|
||||||
|
description = "도로, 비닐하우스, 밭, 과수원, 초지, 숲, 물, 모재/자갈, 토분(무덤), 일반토지, 태양광, 기타",
|
||||||
|
example = "0")
|
||||||
|
private Long LandCoverCnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
public static class ModelConfig {
|
||||||
|
@Schema(description = "에폭 횟수", example = "0")
|
||||||
|
private Integer epochCnt;
|
||||||
|
|
||||||
|
@Schema(description = "학습데이터셋 비율 Training", example = "0")
|
||||||
|
private Float trainingCnt;
|
||||||
|
|
||||||
|
@Schema(description = "학습데이터셋 비율 Validation", example = "0")
|
||||||
|
private Float validationCnt;
|
||||||
|
|
||||||
|
@Schema(description = "학습데이터셋 비율 Test", example = "0")
|
||||||
|
private Float testCnt;
|
||||||
|
|
||||||
|
@Schema(description = "메모", example = "메모 입니다.")
|
||||||
|
private String memo;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
package com.kamco.cd.training.model.service;
|
|
||||||
|
|
||||||
import com.kamco.cd.training.model.dto.ModelMngDto;
|
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto.SearchReq;
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelMngCoreService;
|
|
||||||
import java.util.UUID;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.data.domain.Page;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
@Transactional(readOnly = true)
|
|
||||||
@Slf4j
|
|
||||||
public class ModelMngService {
|
|
||||||
|
|
||||||
private final ModelMngCoreService modelMngCoreService;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 모델 목록 조회
|
|
||||||
*
|
|
||||||
* @param searchReq 검색 조건
|
|
||||||
* @return 페이징 처리된 모델 목록
|
|
||||||
*/
|
|
||||||
public Page<ModelTrainDto.Basic> getModelList(SearchReq searchReq) {
|
|
||||||
return modelMngCoreService.findByModelList(searchReq);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 학습모델 삭제
|
|
||||||
*
|
|
||||||
* @param uuid
|
|
||||||
*/
|
|
||||||
public void deleteModelTrain(UUID uuid) {
|
|
||||||
modelMngCoreService.deleteModel(uuid);
|
|
||||||
}
|
|
||||||
|
|
||||||
public String createModelTrain(ModelMngDto modelMngDto) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 모델 상세 조회
|
|
||||||
*
|
|
||||||
* @param modelUid 모델 UID
|
|
||||||
* @return 모델 상세 정보
|
|
||||||
*/
|
|
||||||
public ModelMngDto.Detail getModelDetail(Long modelUid) {
|
|
||||||
return modelMngCoreService.getModelDetail(modelUid);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
package com.kamco.cd.training.model.service;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.dto.HyperParam;
|
||||||
|
import com.kamco.cd.training.common.enums.HyperParamSelectType;
|
||||||
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelMngDto;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
||||||
|
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||||
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
|
import java.util.UUID;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.data.domain.Page;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
@Transactional(readOnly = true)
|
||||||
|
@Slf4j
|
||||||
|
public class ModelTrainMngService {
|
||||||
|
|
||||||
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
|
private final HyperParamCoreService hyperParamCoreService;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 모델학습 조회
|
||||||
|
*
|
||||||
|
* @param searchReq 검색 조건
|
||||||
|
* @return 페이징 처리된 모델 목록
|
||||||
|
*/
|
||||||
|
public Page<ModelTrainMngDto.Basic> getModelList(SearchReq searchReq) {
|
||||||
|
return modelTrainMngCoreService.findByModelList(searchReq);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 모델학습 삭제
|
||||||
|
*
|
||||||
|
* @param uuid
|
||||||
|
*/
|
||||||
|
public void deleteModelTrain(UUID uuid) {
|
||||||
|
modelTrainMngCoreService.deleteModel(uuid);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 모델학습 등록
|
||||||
|
*
|
||||||
|
* @param req
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Transactional
|
||||||
|
public void createModelTrain(ModelTrainMngDto.AddReq req) {
|
||||||
|
HyperParam hyperParam = req.getHyperParam();
|
||||||
|
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
|
||||||
|
|
||||||
|
/** OPTIMIZED(최적화 파라미터), EXISTING(기존 파라미터), NEW(신규 파라미터) * */
|
||||||
|
if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) {
|
||||||
|
// 하이퍼파라미터 등록
|
||||||
|
hyper = hyperParamCoreService.createHyperParam(hyperParam);
|
||||||
|
req.setHyperUuid(hyper.getUuid());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 모델학습 테이블 저장
|
||||||
|
Long modelId = modelTrainMngCoreService.saveModel(req);
|
||||||
|
|
||||||
|
// 모델학습 데이터셋 저장
|
||||||
|
modelTrainMngCoreService.saveModelDataset(modelId, req);
|
||||||
|
|
||||||
|
// 모델 데이터셋 mapping 저장
|
||||||
|
modelTrainMngCoreService.saveModelDatasetMap(
|
||||||
|
modelId, req.getTrainingDataset().getDatasetList());
|
||||||
|
|
||||||
|
// 모델 config 저장
|
||||||
|
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 모델학습 상세 조회
|
||||||
|
*
|
||||||
|
* @param modelUid 모델 UID
|
||||||
|
* @return 모델 상세 정보
|
||||||
|
*/
|
||||||
|
public ModelMngDto.Detail getModelDetail(Long modelUid) {
|
||||||
|
return modelTrainMngCoreService.getModelDetail(modelUid);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ import com.kamco.cd.training.common.exception.BadRequestException;
|
|||||||
import com.kamco.cd.training.common.exception.NotFoundException;
|
import com.kamco.cd.training.common.exception.NotFoundException;
|
||||||
import com.kamco.cd.training.model.dto.ModelMngDto;
|
import com.kamco.cd.training.model.dto.ModelMngDto;
|
||||||
import com.kamco.cd.training.postgres.core.DatasetCoreService;
|
import com.kamco.cd.training.postgres.core.DatasetCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.SystemMetricsCoreService;
|
import com.kamco.cd.training.postgres.core.SystemMetricsCoreService;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -19,7 +19,7 @@ import org.springframework.transaction.annotation.Transactional;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class ModelTrainService {
|
public class ModelTrainService {
|
||||||
|
|
||||||
private final ModelMngCoreService modelMngCoreService;
|
private final ModelTrainMngCoreService modelMngCoreService;
|
||||||
private final DatasetCoreService datasetCoreService;
|
private final DatasetCoreService datasetCoreService;
|
||||||
private final SystemMetricsCoreService systemMetricsCoreService;
|
private final SystemMetricsCoreService systemMetricsCoreService;
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package com.kamco.cd.training.postgres.core;
|
package com.kamco.cd.training.postgres.core;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.dto.HyperParam;
|
||||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||||
import com.kamco.cd.training.common.utils.UserUtil;
|
import com.kamco.cd.training.common.utils.UserUtil;
|
||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||||
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.Basic;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||||
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
|
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
|
||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
@@ -21,10 +23,10 @@ public class HyperParamCoreService {
|
|||||||
/**
|
/**
|
||||||
* 하이퍼파라미터 등록
|
* 하이퍼파라미터 등록
|
||||||
*
|
*
|
||||||
* @param createReq 등록 요청
|
* @param createReq ModelTrainMngDto.HyperParamDto
|
||||||
* @return 등록된 버전명
|
* @return 등록된 버전명
|
||||||
*/
|
*/
|
||||||
public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) {
|
public Basic createHyperParam(HyperParam createReq) {
|
||||||
String firstVersion = getFirstHyperParamVersion();
|
String firstVersion = getFirstHyperParamVersion();
|
||||||
|
|
||||||
ModelHyperParamEntity entity = new ModelHyperParamEntity();
|
ModelHyperParamEntity entity = new ModelHyperParamEntity();
|
||||||
@@ -36,7 +38,10 @@ public class HyperParamCoreService {
|
|||||||
entity.setCreatedUid(userUtil.getId());
|
entity.setCreatedUid(userUtil.getId());
|
||||||
|
|
||||||
ModelHyperParamEntity resultEntity = hyperParamRepository.save(entity);
|
ModelHyperParamEntity resultEntity = hyperParamRepository.save(entity);
|
||||||
return resultEntity.getHyperVer();
|
Basic basic = new Basic();
|
||||||
|
basic.setUuid(resultEntity.getUuid());
|
||||||
|
basic.setHyperVer(resultEntity.getHyperVer());
|
||||||
|
return basic;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -46,7 +51,7 @@ public class HyperParamCoreService {
|
|||||||
* @param createReq 등록 요청
|
* @param createReq 등록 요청
|
||||||
* @return ver
|
* @return ver
|
||||||
*/
|
*/
|
||||||
public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) {
|
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.findHyperParamByUuid(uuid)
|
.findHyperParamByUuid(uuid)
|
||||||
@@ -61,47 +66,46 @@ public class HyperParamCoreService {
|
|||||||
return entity.getHyperVer();
|
return entity.getHyperVer();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void applyHyperParam(
|
private void applyHyperParam(ModelHyperParamEntity entity, HyperParam src) {
|
||||||
ModelHyperParamEntity entity, HyperParamDto.HyperParamCreateReq createReq) {
|
|
||||||
// Important
|
// Important
|
||||||
entity.setBackbone(createReq.getBackbone());
|
entity.setBackbone(src.getBackbone());
|
||||||
entity.setInputSize(createReq.getInputSize());
|
entity.setInputSize(src.getInputSize());
|
||||||
entity.setCropSize(createReq.getCropSize());
|
entity.setCropSize(src.getCropSize());
|
||||||
entity.setBatchSize(createReq.getBatchSize());
|
entity.setBatchSize(src.getBatchSize());
|
||||||
|
|
||||||
// Data
|
// Data
|
||||||
entity.setTrainNumWorkers(createReq.getTrainNumWorkers());
|
entity.setTrainNumWorkers(src.getTrainNumWorkers());
|
||||||
entity.setValNumWorkers(createReq.getValNumWorkers());
|
entity.setValNumWorkers(src.getValNumWorkers());
|
||||||
entity.setTestNumWorkers(createReq.getTestNumWorkers());
|
entity.setTestNumWorkers(src.getTestNumWorkers());
|
||||||
entity.setTrainShuffle(createReq.getTrainShuffle());
|
entity.setTrainShuffle(src.getTrainShuffle());
|
||||||
entity.setTrainPersistent(createReq.getTrainPersistent());
|
entity.setTrainPersistent(src.getTrainPersistent());
|
||||||
entity.setValPersistent(createReq.getValPersistent());
|
entity.setValPersistent(src.getValPersistent());
|
||||||
|
|
||||||
// Model Architecture
|
// Model Architecture
|
||||||
entity.setDropPathRate(createReq.getDropPathRate());
|
entity.setDropPathRate(src.getDropPathRate());
|
||||||
entity.setFrozenStages(createReq.getFrozenStages());
|
entity.setFrozenStages(src.getFrozenStages());
|
||||||
entity.setNeckPolicy(createReq.getNeckPolicy());
|
entity.setNeckPolicy(src.getNeckPolicy());
|
||||||
entity.setClassWeight(createReq.getClassWeight());
|
entity.setClassWeight(src.getClassWeight());
|
||||||
entity.setDecoderChannels(createReq.getDecoderChannels());
|
entity.setDecoderChannels(src.getDecoderChannels());
|
||||||
|
|
||||||
// Loss & Optimization
|
// Loss & Optimization
|
||||||
entity.setLearningRate(createReq.getLearningRate());
|
entity.setLearningRate(src.getLearningRate());
|
||||||
entity.setWeightDecay(createReq.getWeightDecay());
|
entity.setWeightDecay(src.getWeightDecay());
|
||||||
entity.setLayerDecayRate(createReq.getLayerDecayRate());
|
entity.setLayerDecayRate(src.getLayerDecayRate());
|
||||||
entity.setDdpFindUnusedParams(createReq.getDdpFindUnusedParams());
|
entity.setDdpFindUnusedParams(src.getDdpFindUnusedParams());
|
||||||
entity.setIgnoreIndex(createReq.getIgnoreIndex());
|
entity.setIgnoreIndex(src.getIgnoreIndex());
|
||||||
entity.setNumLayers(createReq.getNumLayers());
|
entity.setNumLayers(src.getNumLayers());
|
||||||
|
|
||||||
// Evaluation
|
// Evaluation
|
||||||
entity.setMetrics(createReq.getMetrics());
|
entity.setMetrics(src.getMetrics());
|
||||||
entity.setSaveBest(createReq.getSaveBest());
|
entity.setSaveBest(src.getSaveBest());
|
||||||
entity.setSaveBestRule(createReq.getSaveBestRule());
|
entity.setSaveBestRule(src.getSaveBestRule());
|
||||||
entity.setValInterval(createReq.getValInterval());
|
entity.setValInterval(src.getValInterval());
|
||||||
entity.setLogInterval(createReq.getLogInterval());
|
entity.setLogInterval(src.getLogInterval());
|
||||||
entity.setVisInterval(createReq.getVisInterval());
|
entity.setVisInterval(src.getVisInterval());
|
||||||
|
|
||||||
// memo
|
// memo
|
||||||
entity.setMemo(createReq.getMemo());
|
entity.setMemo(src.getMemo());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -1,15 +1,24 @@
|
|||||||
package com.kamco.cd.training.postgres.core;
|
package com.kamco.cd.training.postgres.core;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.enums.ModelType;
|
||||||
|
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||||
import com.kamco.cd.training.common.exception.BadRequestException;
|
import com.kamco.cd.training.common.exception.BadRequestException;
|
||||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||||
import com.kamco.cd.training.common.exception.NotFoundException;
|
import com.kamco.cd.training.common.exception.NotFoundException;
|
||||||
import com.kamco.cd.training.common.utils.UserUtil;
|
import com.kamco.cd.training.common.utils.UserUtil;
|
||||||
import com.kamco.cd.training.model.dto.ModelMngDto;
|
import com.kamco.cd.training.model.dto.ModelMngDto;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto.Basic;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto.TrainingDataset;
|
||||||
|
import com.kamco.cd.training.postgres.entity.ModelConfigEntity;
|
||||||
|
import com.kamco.cd.training.postgres.entity.ModelDatasetEntity;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
|
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
|
||||||
|
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||||
|
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
|
||||||
|
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
||||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
|
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
|
||||||
|
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
|
||||||
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
|
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
|
||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -21,9 +30,12 @@ import org.springframework.stereotype.Service;
|
|||||||
|
|
||||||
@Service
|
@Service
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class ModelMngCoreService {
|
public class ModelTrainMngCoreService {
|
||||||
private final ModelMngRepository modelMngRepository;
|
private final ModelMngRepository modelMngRepository;
|
||||||
private final ModelDatasetMappRepository modelDatasetMappRepository;
|
private final ModelDatasetRepository modelDatasetRepository;
|
||||||
|
private final ModelDatasetMappRepository modelDatasetMapRepository;
|
||||||
|
private final ModelConfigRepository modelConfigRepository;
|
||||||
|
private final HyperParamRepository hyperParamRepository;
|
||||||
private final UserUtil userUtil;
|
private final UserUtil userUtil;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -32,7 +44,7 @@ public class ModelMngCoreService {
|
|||||||
* @param searchReq 검색 조건
|
* @param searchReq 검색 조건
|
||||||
* @return 페이징 처리된 모델 목록
|
* @return 페이징 처리된 모델 목록
|
||||||
*/
|
*/
|
||||||
public Page<Basic> findByModelList(ModelTrainDto.SearchReq searchReq) {
|
public Page<Basic> findByModelList(ModelTrainMngDto.SearchReq searchReq) {
|
||||||
Page<ModelMasterEntity> entityPage = modelMngRepository.findByModels(searchReq);
|
Page<ModelMasterEntity> entityPage = modelMngRepository.findByModels(searchReq);
|
||||||
return entityPage.map(ModelMasterEntity::toDto);
|
return entityPage.map(ModelMasterEntity::toDto);
|
||||||
}
|
}
|
||||||
@@ -52,6 +64,103 @@ public class ModelMngCoreService {
|
|||||||
entity.setUpdatedUid(userUtil.getId());
|
entity.setUpdatedUid(userUtil.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 모델학습 저장
|
||||||
|
*
|
||||||
|
* @param addReq
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Long saveModel(ModelTrainMngDto.AddReq addReq) {
|
||||||
|
ModelMasterEntity entity = new ModelMasterEntity();
|
||||||
|
ModelHyperParamEntity hyperParamEntity =
|
||||||
|
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
||||||
|
|
||||||
|
entity.setModelNo(addReq.getModelNo());
|
||||||
|
entity.setTrainType(addReq.getTrainType());
|
||||||
|
|
||||||
|
if (hyperParamEntity != null) {
|
||||||
|
entity.setHyperParamId(hyperParamEntity.getId());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (addReq.getIsStart()) {
|
||||||
|
entity.setModelStep((short) 1);
|
||||||
|
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
|
entity.setStrtDttm(ZonedDateTime.now());
|
||||||
|
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||||
|
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
|
||||||
|
}
|
||||||
|
|
||||||
|
entity.setCreatedUid(userUtil.getId());
|
||||||
|
ModelMasterEntity resultEntity = modelMngRepository.save(entity);
|
||||||
|
return resultEntity.getId();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* data set 저장
|
||||||
|
*
|
||||||
|
* @param modelId 저장한 모델 학습 id
|
||||||
|
* @param addReq 요청 파라미터
|
||||||
|
*/
|
||||||
|
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
|
||||||
|
TrainingDataset dataset = addReq.getTrainingDataset();
|
||||||
|
ModelMasterEntity modelMasterEntity = new ModelMasterEntity();
|
||||||
|
ModelDatasetEntity datasetEntity = new ModelDatasetEntity();
|
||||||
|
|
||||||
|
modelMasterEntity.setId(modelId);
|
||||||
|
datasetEntity.setModel(modelMasterEntity);
|
||||||
|
|
||||||
|
if (addReq.getModelNo().equals(ModelType.M1.getId())) {
|
||||||
|
datasetEntity.setBuildingCnt(dataset.getSummary().getBuildingCnt());
|
||||||
|
datasetEntity.setContainerCnt(dataset.getSummary().getContainerCnt());
|
||||||
|
} else if (addReq.getModelNo().equals(ModelType.M2.getId())) {
|
||||||
|
datasetEntity.setWasteCnt(dataset.getSummary().getWasteCnt());
|
||||||
|
} else if (addReq.getModelNo().equals(ModelType.M3.getId())) {
|
||||||
|
datasetEntity.setLandCoverCnt(dataset.getSummary().getLandCoverCnt());
|
||||||
|
}
|
||||||
|
|
||||||
|
datasetEntity.setCreatedUid(userUtil.getId());
|
||||||
|
|
||||||
|
// data set 저장
|
||||||
|
modelDatasetRepository.save(datasetEntity);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 모델 데이터셋 mapping 테이블 저장
|
||||||
|
*
|
||||||
|
* @param modelId 모델학습 id
|
||||||
|
* @param datasetList 선택한 data set
|
||||||
|
*/
|
||||||
|
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
|
||||||
|
|
||||||
|
for (Long datasetId : datasetList) {
|
||||||
|
ModelDatasetMappEntity mapEntity = new ModelDatasetMappEntity();
|
||||||
|
mapEntity.setModelUid(modelId);
|
||||||
|
mapEntity.setDatasetUid(datasetId);
|
||||||
|
modelDatasetMapRepository.save(mapEntity);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 모델학습 config 저장
|
||||||
|
*
|
||||||
|
* @param modelId 모델학습 id
|
||||||
|
* @param req 요청 파라미터
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Long saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
|
||||||
|
ModelMasterEntity modelMasterEntity = new ModelMasterEntity();
|
||||||
|
ModelConfigEntity entity = new ModelConfigEntity();
|
||||||
|
modelMasterEntity.setId(modelId);
|
||||||
|
entity.setModel(modelMasterEntity);
|
||||||
|
entity.setEpochCount(req.getEpochCnt());
|
||||||
|
entity.setTrainPercent(req.getTrainingCnt());
|
||||||
|
entity.setValidationPercent(req.getValidationCnt());
|
||||||
|
entity.setTestPercent(req.getTestCnt());
|
||||||
|
entity.setMemo(req.getMemo());
|
||||||
|
|
||||||
|
return modelConfigRepository.save(entity).getId();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 모델 상세 조회
|
* 모델 상세 조회
|
||||||
*
|
*
|
||||||
@@ -136,7 +245,7 @@ public class ModelMngCoreService {
|
|||||||
mapping.setModelUid(modelUid);
|
mapping.setModelUid(modelUid);
|
||||||
mapping.setDatasetUid(datasetId);
|
mapping.setDatasetUid(datasetId);
|
||||||
mapping.setDatasetType("TRAIN");
|
mapping.setDatasetType("TRAIN");
|
||||||
modelDatasetMappRepository.save(mapping);
|
modelDatasetMapRepository.save(mapping);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +121,7 @@ public class DatasetEntity {
|
|||||||
@JdbcTypeCode(SqlTypes.JSON)
|
@JdbcTypeCode(SqlTypes.JSON)
|
||||||
private Map<String, Integer> classCounts;
|
private Map<String, Integer> classCounts;
|
||||||
|
|
||||||
@Size(max = 255)
|
@Size(max = 32)
|
||||||
@Column(name = "uid")
|
@Column(name = "uid")
|
||||||
private String uid;
|
private String uid;
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import jakarta.persistence.Id;
|
|||||||
import jakarta.persistence.Table;
|
import jakarta.persistence.Table;
|
||||||
import jakarta.validation.constraints.NotNull;
|
import jakarta.validation.constraints.NotNull;
|
||||||
import jakarta.validation.constraints.Size;
|
import jakarta.validation.constraints.Size;
|
||||||
|
import java.math.BigDecimal;
|
||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
@@ -79,6 +80,19 @@ public class DatasetObjEntity {
|
|||||||
@Column(name = "uuid")
|
@Column(name = "uuid")
|
||||||
private UUID uuid;
|
private UUID uuid;
|
||||||
|
|
||||||
|
@Size(max = 32)
|
||||||
|
@Column(name = "uid")
|
||||||
|
private String uid;
|
||||||
|
|
||||||
|
@Column(precision = 5, scale = 2)
|
||||||
|
private BigDecimal chnDtctP;
|
||||||
|
|
||||||
|
@Column(precision = 5, scale = 2)
|
||||||
|
private BigDecimal bfClsPro;
|
||||||
|
|
||||||
|
@Column(precision = 5, scale = 2)
|
||||||
|
private BigDecimal afClsPro;
|
||||||
|
|
||||||
public Basic toDto() {
|
public Basic toDto() {
|
||||||
return new DatasetObjDto.Basic(
|
return new DatasetObjDto.Basic(
|
||||||
this.objId,
|
this.objId,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ public class ModelConfigEntity {
|
|||||||
@Id
|
@Id
|
||||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||||
@Column(name = "config_id", nullable = false)
|
@Column(name = "config_id", nullable = false)
|
||||||
private Integer id;
|
private Long id;
|
||||||
|
|
||||||
@NotNull
|
@NotNull
|
||||||
@ManyToOne(fetch = FetchType.LAZY, optional = false)
|
@ManyToOne(fetch = FetchType.LAZY, optional = false)
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package com.kamco.cd.training.postgres.entity;
|
|||||||
import jakarta.persistence.Column;
|
import jakarta.persistence.Column;
|
||||||
import jakarta.persistence.Entity;
|
import jakarta.persistence.Entity;
|
||||||
import jakarta.persistence.FetchType;
|
import jakarta.persistence.FetchType;
|
||||||
|
import jakarta.persistence.GeneratedValue;
|
||||||
|
import jakarta.persistence.GenerationType;
|
||||||
import jakarta.persistence.Id;
|
import jakarta.persistence.Id;
|
||||||
import jakarta.persistence.JoinColumn;
|
import jakarta.persistence.JoinColumn;
|
||||||
import jakarta.persistence.ManyToOne;
|
import jakarta.persistence.ManyToOne;
|
||||||
@@ -20,6 +22,7 @@ import org.hibernate.annotations.ColumnDefault;
|
|||||||
public class ModelDatasetEntity {
|
public class ModelDatasetEntity {
|
||||||
|
|
||||||
@Id
|
@Id
|
||||||
|
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||||
@Column(name = "id", nullable = false)
|
@Column(name = "id", nullable = false)
|
||||||
private Long id;
|
private Long id;
|
||||||
|
|
||||||
@@ -28,10 +31,6 @@ public class ModelDatasetEntity {
|
|||||||
@JoinColumn(name = "model_id", nullable = false)
|
@JoinColumn(name = "model_id", nullable = false)
|
||||||
private ModelMasterEntity model;
|
private ModelMasterEntity model;
|
||||||
|
|
||||||
@NotNull
|
|
||||||
@Column(name = "data_id", nullable = false)
|
|
||||||
private Long dataId;
|
|
||||||
|
|
||||||
@Column(name = "building_cnt")
|
@Column(name = "building_cnt")
|
||||||
private Long buildingCnt;
|
private Long buildingCnt;
|
||||||
|
|
||||||
@@ -46,7 +45,7 @@ public class ModelDatasetEntity {
|
|||||||
|
|
||||||
@ColumnDefault("now()")
|
@ColumnDefault("now()")
|
||||||
@Column(name = "created_dttm")
|
@Column(name = "created_dttm")
|
||||||
private ZonedDateTime createdDttm;
|
private ZonedDateTime createdDttm = ZonedDateTime.now();
|
||||||
|
|
||||||
@Column(name = "created_uid")
|
@Column(name = "created_uid")
|
||||||
private Long createdUid;
|
private Long createdUid;
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
package com.kamco.cd.training.postgres.entity;
|
package com.kamco.cd.training.postgres.entity;
|
||||||
|
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
import jakarta.persistence.Column;
|
import jakarta.persistence.Column;
|
||||||
import jakarta.persistence.Entity;
|
import jakarta.persistence.Entity;
|
||||||
import jakarta.persistence.GeneratedValue;
|
import jakarta.persistence.GeneratedValue;
|
||||||
import jakarta.persistence.GenerationType;
|
import jakarta.persistence.GenerationType;
|
||||||
import jakarta.persistence.Id;
|
import jakarta.persistence.Id;
|
||||||
import jakarta.persistence.Table;
|
import jakarta.persistence.Table;
|
||||||
import jakarta.validation.constraints.NotNull;
|
|
||||||
import jakarta.validation.constraints.Size;
|
import jakarta.validation.constraints.Size;
|
||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
@@ -26,8 +25,7 @@ public class ModelMasterEntity {
|
|||||||
@Column(name = "model_id", nullable = false)
|
@Column(name = "model_id", nullable = false)
|
||||||
private Long id;
|
private Long id;
|
||||||
|
|
||||||
@NotNull
|
@Column(name = "hyper_param_id")
|
||||||
@Column(name = "hyper_param_id", nullable = false)
|
|
||||||
private Long hyperParamId;
|
private Long hyperParamId;
|
||||||
|
|
||||||
@Size(max = 10)
|
@Size(max = 10)
|
||||||
@@ -69,7 +67,7 @@ public class ModelMasterEntity {
|
|||||||
private String step2State;
|
private String step2State;
|
||||||
|
|
||||||
@Column(name = "del_yn")
|
@Column(name = "del_yn")
|
||||||
private Boolean delYn;
|
private Boolean delYn = false;
|
||||||
|
|
||||||
@ColumnDefault("now()")
|
@ColumnDefault("now()")
|
||||||
@Column(name = "created_dttm")
|
@Column(name = "created_dttm")
|
||||||
@@ -90,8 +88,8 @@ public class ModelMasterEntity {
|
|||||||
@Column(name = "train_type")
|
@Column(name = "train_type")
|
||||||
private String trainType;
|
private String trainType;
|
||||||
|
|
||||||
public ModelTrainDto.Basic toDto() {
|
public ModelTrainMngDto.Basic toDto() {
|
||||||
return new ModelTrainDto.Basic(
|
return new ModelTrainMngDto.Basic(
|
||||||
this.id,
|
this.id,
|
||||||
this.uuid,
|
this.uuid,
|
||||||
this.modelVer,
|
this.modelVer,
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
package com.kamco.cd.training.postgres.repository.model;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.postgres.entity.ModelConfigEntity;
|
||||||
|
import org.springframework.data.jpa.repository.JpaRepository;
|
||||||
|
|
||||||
|
public interface ModelConfigRepository extends JpaRepository<ModelConfigEntity, Long> {}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package com.kamco.cd.training.postgres.repository.model;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.postgres.entity.ModelDatasetEntity;
|
||||||
|
import org.springframework.data.jpa.repository.JpaRepository;
|
||||||
|
|
||||||
|
public interface ModelDatasetRepository
|
||||||
|
extends JpaRepository<ModelDatasetEntity, Long>, ModelDatasetRepositoryCustom {}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package com.kamco.cd.training.postgres.repository.model;
|
||||||
|
|
||||||
|
public interface ModelDatasetRepositoryCustom {}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.model;
|
package com.kamco.cd.training.postgres.repository.model;
|
||||||
|
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
@@ -14,7 +14,7 @@ public interface ModelMngRepositoryCustom {
|
|||||||
* @param searchReq
|
* @param searchReq
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
Page<ModelMasterEntity> findByModels(ModelTrainDto.SearchReq searchReq);
|
Page<ModelMasterEntity> findByModels(ModelTrainMngDto.SearchReq searchReq);
|
||||||
|
|
||||||
Optional<ModelMasterEntity> findByUuid(UUID uuid);
|
Optional<ModelMasterEntity> findByUuid(UUID uuid);
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
|||||||
|
|
||||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||||
|
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainDto;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||||
import com.querydsl.core.BooleanBuilder;
|
import com.querydsl.core.BooleanBuilder;
|
||||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||||
@@ -28,7 +28,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
|||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Page<ModelMasterEntity> findByModels(ModelTrainDto.SearchReq req) {
|
public Page<ModelMasterEntity> findByModels(ModelTrainMngDto.SearchReq req) {
|
||||||
Pageable pageable = req.toPageable();
|
Pageable pageable = req.toPageable();
|
||||||
BooleanBuilder builder = new BooleanBuilder();
|
BooleanBuilder builder = new BooleanBuilder();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user