하이퍼파라미터 기능 추가

This commit is contained in:
2026-02-03 14:31:53 +09:00
parent e2757d3ca0
commit 3a8d6e3ef0
18 changed files with 946 additions and 688 deletions

View File

@@ -0,0 +1,303 @@
package com.kamco.cd.training.hyperparam.dto;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDate;
import java.time.ZonedDateTime;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
public class HyperParamDto {
@Schema(name = "Basic", description = "하이퍼파라미터 조회")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class Basic {
private Long id;
private UUID uuid;
private String hyperVer;
// -------------------------
// Important
// -------------------------
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// -------------------------
// Model Architecture
// -------------------------
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// -------------------------
// Loss & Optimization
// -------------------------
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// -------------------------
// Data
// -------------------------
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// -------------------------
// Evaluation
// -------------------------
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// -------------------------
// Augmentation
// -------------------------
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// -------------------------
// Hardware
// -------------------------
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// -------------------------
// Common
// -------------------------
private String memo;
private String delYn;
@JsonFormatDttm private ZonedDateTime createdDttm;
private Long createdUid;
@JsonFormatDttm private ZonedDateTime updatedDttm;
private Long updatedUid;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Long m1UseCnt;
private Long m2UseCnt;
private Long m3UseCnt;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class List {
private UUID uuid;
private String hyperVer;
@JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Long m1UseCnt;
private Long m2UseCnt;
private Long m3UseCnt;
private Long totalCnt;
}
@Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class HyperParamCreateReq {
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone
@Schema(description = "입력 이미지 크기(H,W)", example = "256,256")
private String inputSize; // input_size
@Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256")
private String cropSize; // crop_size
@Schema(description = "총 학습 에폭 수", example = "200")
private Integer epochCnt; // epoch_cnt
@Schema(description = "배치 크기(Per GPU)", example = "16")
private Integer batchSize; // batch_size
@Schema(description = "Drop Path 비율", example = "0.3")
private Double dropPathRate; // drop_path_rate
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
private Integer frozenStages; // frozen_stages
@Schema(description = "Neck 결합 정책", example = "abs_diff")
private String neckPolicy; // neck_policy
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
private String decoderChannels; // decoder_channels
@Schema(description = "클래스별 가중치", example = "1,10")
private String classWeight; // class_weight
@Schema(description = "레이어 깊이", example = "24")
private Integer numLayers; // num_layers
@Schema(description = "학습률", example = "0.00006")
private Double learningRate; // learning_rate
@Schema(description = "Weight Decay", example = "0.05")
private Double weightDecay; // weight_decay
@Schema(description = "Layer Decay Rate", example = "0.9")
private Double layerDecayRate; // layer_decay_rate
@Schema(description = "DDP unused params 탐색 여부", example = "true")
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
private Integer ignoreIndex; // ignore_index
@Schema(description = "Train dataloader workers", example = "16")
private Integer trainNumWorkers; // train_num_workers
@Schema(description = "Val dataloader workers", example = "8")
private Integer valNumWorkers; // val_num_workers
@Schema(description = "Test dataloader workers", example = "8")
private Integer testNumWorkers; // test_num_workers
@Schema(description = "Train shuffle 여부", example = "true")
private Boolean trainShuffle; // train_shuffle
@Schema(description = "Train persistent workers 여부", example = "true")
private Boolean trainPersistent; // train_persistent
@Schema(description = "Val persistent workers 여부", example = "true")
private Boolean valPersistent; // val_persistent
@Schema(description = "평가 지표 목록", example = "mFscore,mIoU")
private String metrics; // metrics
@Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore")
private String saveBest; // save_best
@Schema(description = "Best 모델 선정 규칙", example = "greater")
private String saveBestRule; // save_best_rule
@Schema(description = "검증 수행 주기(Epoch)", example = "10")
private Integer valInterval; // val_interval
@Schema(description = "로그 기록 주기(Iteration)", example = "400")
private Integer logInterval; // log_interval
@Schema(description = "시각화 저장 주기(Epoch)", example = "1")
private Integer visInterval; // vis_interval
@Schema(description = "회전 적용 확률", example = "0.5")
private Double rotProb; // rot_prob
@Schema(description = "반전 적용 확률", example = "0.5")
private Double flipProb; // flip_prob
@Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20")
private String rotDegree; // rot_degree
@Schema(description = "채널 교환 확률", example = "0.5")
private Double exchangeProb; // exchange_prob
@Schema(description = "밝기 변화량", example = "10")
private Integer brightnessDelta; // brightness_delta
@Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2")
private String contrastRange; // contrast_range
@Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2")
private String saturationRange; // saturation_range
@Schema(description = "색조 변화량", example = "10")
private Integer hueDelta; // hue_delta
@Schema(description = "사용 GPU 개수", example = "4")
private Integer gpuCnt; // gpu_cnt
@Schema(description = "사용 GPU ID 목록", example = "0,1,2,3")
private String gpuIds; // gpu_ids
@Schema(description = "분산학습 마스터 포트", example = "1122")
private Integer masterPort; // master_port
@Schema(description = "메모", example = "하이퍼파라미터 신규등록")
private String memo; // memo
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class SearchReq {
private String type;
private LocalDate startDate;
private LocalDate endDate;
private String hyperVer;
// 페이징 파라미터
private int page = 0;
private int size = 20;
private String sort;
public Pageable toPageable() {
if (sort != null && !sort.isEmpty()) {
String[] sortParams = sort.split(",");
String property = sortParams[0];
Sort.Direction direction =
sortParams.length > 1 ? Sort.Direction.fromString(sortParams[1]) : Sort.Direction.ASC;
return PageRequest.of(page, size, Sort.by(direction, property));
}
return PageRequest.of(page, size);
}
}
@CodeExpose
@Getter
@AllArgsConstructor
public enum HyperType implements EnumType {
CREATE_DATE("생성일"),
LAST_USED_DATE("최근 사용일");
private final String desc;
@Override
public String getId() {
return name();
}
@Override
public String getText() {
return desc;
}
}
}