하이퍼파라미터 기능 추가

This commit is contained in:
2026-02-03 14:31:53 +09:00
parent e2757d3ca0
commit 3a8d6e3ef0
18 changed files with 946 additions and 688 deletions

View File

@@ -112,26 +112,6 @@ public class ModelMngApiController {
return ApiResponseDto.ok(modelTrainService.getFormConfig());
}
@Operation(summary = "하이퍼파라미터 등록", description = "Step 1 에서 파라미터를 수정하여 신규 버전으로 저장합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "등록 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/hyper-params")
public ApiResponseDto<String> createHyperParam(
@Valid @RequestBody ModelMngDto.HyperParamCreateReq createReq) {
String newVersion = modelTrainService.createHyperParam(createReq);
return ApiResponseDto.ok(newVersion);
}
@Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다")
@ApiResponses(
value = {
@@ -151,22 +131,6 @@ public class ModelMngApiController {
return ApiResponseDto.ok(modelTrainService.getHyperParam(hyperVer));
}
@Operation(summary = "하이퍼파라미터 삭제", description = "특정 버전의 하이퍼파라미터를 삭제합니다 (H1은 삭제 불가)")
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "400", description = "H1은 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@DeleteMapping("/hyper-params/{hyperVer}")
public ApiResponseDto<Void> deleteHyperParam(
@Parameter(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518") @PathVariable
String hyperVer) {
modelTrainService.deleteHyperParam(hyperVer);
return ApiResponseDto.ok(null);
}
@Operation(summary = "학습 시작", description = "모든 설정(Step 1~3)을 마치고 최종적으로 학습 프로세스를 시작합니다")
@ApiResponses(
value = {

View File

@@ -1,218 +0,0 @@
package com.kamco.cd.training.model.dto;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import java.time.ZonedDateTime;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
public class HyperParamDto {
@Schema(name = "HyperParam Basic", description = "하이퍼파라미터 기본 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class Basic {
private String hyperVer;
// Important
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// Architecture
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// Optimization
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// Data
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// Evaluation
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// Hardware
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// Augmentation
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// Legacy (deprecated)
private Double dropoutRatio;
private Integer cnnFilterCnt;
// Common
private String memo;
@JsonFormatDttm private ZonedDateTime createdDttm;
public Basic(ModelHyperParamEntity entity) {
this.hyperVer = entity.getHyperVer();
// Important
this.backbone = entity.getBackbone();
this.inputSize = entity.getInputSize();
this.cropSize = entity.getCropSize();
this.epochCnt = entity.getEpochCnt();
this.batchSize = entity.getBatchSize();
// Architecture
this.dropPathRate = entity.getDropPathRate();
this.frozenStages = entity.getFrozenStages();
this.neckPolicy = entity.getNeckPolicy();
this.decoderChannels = entity.getDecoderChannels();
this.classWeight = entity.getClassWeight();
this.numLayers = entity.getNumLayers();
// Optimization
this.learningRate = entity.getLearningRate();
this.weightDecay = entity.getWeightDecay();
this.layerDecayRate = entity.getLayerDecayRate();
this.ddpFindUnusedParams = entity.getDdpFindUnusedParams();
this.ignoreIndex = entity.getIgnoreIndex();
// Data
this.trainNumWorkers = entity.getTrainNumWorkers();
this.valNumWorkers = entity.getValNumWorkers();
this.testNumWorkers = entity.getTestNumWorkers();
this.trainShuffle = entity.getTrainShuffle();
this.trainPersistent = entity.getTrainPersistent();
this.valPersistent = entity.getValPersistent();
// Evaluation
this.metrics = entity.getMetrics();
this.saveBest = entity.getSaveBest();
this.saveBestRule = entity.getSaveBestRule();
this.valInterval = entity.getValInterval();
this.logInterval = entity.getLogInterval();
this.visInterval = entity.getVisInterval();
// Hardware
this.gpuCnt = entity.getGpuCnt();
this.gpuIds = entity.getGpuIds();
this.masterPort = entity.getMasterPort();
// Augmentation
this.rotProb = entity.getRotProb();
this.flipProb = entity.getFlipProb();
this.rotDegree = entity.getRotDegree();
this.exchangeProb = entity.getExchangeProb();
this.brightnessDelta = entity.getBrightnessDelta();
this.contrastRange = entity.getContrastRange();
this.saturationRange = entity.getSaturationRange();
this.hueDelta = entity.getHueDelta();
// Legacy
this.cnnFilterCnt = entity.getCnnFilterCnt();
// Common
this.memo = entity.getMemo();
this.createdDttm = entity.getCreatedDttm();
}
}
@Schema(name = "HyperParam AddReq", description = "하이퍼파라미터 등록 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class AddReq {
@NotBlank(message = "버전명은 필수입니다")
private String hyperVer;
// Important
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// Architecture
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// Optimization
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// Data
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// Evaluation
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// Hardware
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// Augmentation
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// Legacy (deprecated)
private Double dropoutRatio;
private Integer cnnFilterCnt;
// Common
private String memo;
}
}

View File

@@ -5,7 +5,6 @@ import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import java.time.LocalDate;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;
@@ -303,155 +302,6 @@ public class ModelMngDto {
@JsonFormatDttm private ZonedDateTime createdDttm;
}
@Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class HyperParamCreateReq {
@Schema(description = "새로운 하이파라미터 버전", example = "")
private String newHyperVer;
@Schema(description = "불러온 하이파라미터 버전", example = "")
private String baseHyperVer;
@Schema(description = "하이퍼파라미터 PK", example = "1")
private Long hyperParamId; // hyper_param_id
@Schema(description = "하이퍼파라미터 UUID", example = "3fa85f64-5717-4562-b3fc-2c963f66afa6")
private UUID uuid; // uuid (또는 hyper_param_uuid 컬럼이면 이름 맞춰 주세요)
@Schema(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518")
private String hyperVer; // hyper_ver
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone
@Schema(description = "입력 이미지 크기(H,W)", example = "256,256")
private String inputSize; // input_size
@Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256")
private String cropSize; // crop_size
@Schema(description = "총 학습 에폭 수", example = "200")
private Integer epochCnt; // epoch_cnt
@Schema(description = "배치 크기(Per GPU)", example = "16")
private Integer batchSize; // batch_size
@Schema(description = "Drop Path 비율", example = "0.3")
private Double dropPathRate; // drop_path_rate
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
private Integer frozenStages; // frozen_stages
@Schema(description = "Neck 결합 정책", example = "abs_diff")
private String neckPolicy; // neck_policy
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
private String decoderChannels; // decoder_channels
@Schema(description = "클래스별 가중치", example = "1,10")
private String classWeight; // class_weight
@Schema(description = "레이어 깊이", example = "24")
private Integer numLayers; // num_layers
@Schema(description = "학습률", example = "0.00006")
private Double learningRate; // learning_rate
@Schema(description = "Weight Decay", example = "0.05")
private Double weightDecay; // weight_decay
@Schema(description = "Layer Decay Rate", example = "0.9")
private Double layerDecayRate; // layer_decay_rate
@Schema(description = "DDP unused params 탐색 여부", example = "true")
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
private Integer ignoreIndex; // ignore_index
@Schema(description = "Train dataloader workers", example = "16")
private Integer trainNumWorkers; // train_num_workers
@Schema(description = "Val dataloader workers", example = "8")
private Integer valNumWorkers; // val_num_workers
@Schema(description = "Test dataloader workers", example = "8")
private Integer testNumWorkers; // test_num_workers
@Schema(description = "Train shuffle 여부", example = "true")
private Boolean trainShuffle; // train_shuffle
@Schema(description = "Train persistent workers 여부", example = "true")
private Boolean trainPersistent; // train_persistent
@Schema(description = "Val persistent workers 여부", example = "true")
private Boolean valPersistent; // val_persistent
@Schema(description = "평가 지표 목록", example = "mFscore,mIoU")
private String metrics; // metrics
@Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore")
private String saveBest; // save_best
@Schema(description = "Best 모델 선정 규칙", example = "greater")
private String saveBestRule; // save_best_rule
@Schema(description = "검증 수행 주기(Epoch)", example = "10")
private Integer valInterval; // val_interval
@Schema(description = "로그 기록 주기(Iteration)", example = "400")
private Integer logInterval; // log_interval
@Schema(description = "시각화 저장 주기(Epoch)", example = "1")
private Integer visInterval; // vis_interval
@Schema(description = "회전 적용 확률", example = "0.5")
private Double rotProb; // rot_prob
@Schema(description = "반전 적용 확률", example = "0.5")
private Double flipProb; // flip_prob
@Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20")
private String rotDegree; // rot_degree
@Schema(description = "채널 교환 확률", example = "0.5")
private Double exchangeProb; // exchange_prob
@Schema(description = "밝기 변화량", example = "10")
private Integer brightnessDelta; // brightness_delta
@Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2")
private String contrastRange; // contrast_range
@Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2")
private String saturationRange; // saturation_range
@Schema(description = "색조 변화량", example = "10")
private Integer hueDelta; // hue_delta
@Schema(description = "사용 GPU 개수", example = "4")
private Integer gpuCnt; // gpu_cnt
@Schema(description = "사용 GPU ID 목록", example = "0,1,2,3")
private String gpuIds; // gpu_ids
@Schema(description = "분산학습 마스터 포트", example = "1122")
private Integer masterPort; // master_port
@Schema(description = "메모", example = "하이퍼파라미터 신규등록")
private String memo; // memo
@Schema(description = "삭제 여부(Y/N)", example = "N")
private String delYn; // del_yn
@Schema(description = "생성 일시")
private LocalDate createdDttm; // created_dttm
}
@Schema(name = "TrainStartReq", description = "학습 시작 요청")
@Getter
@Setter

View File

@@ -52,7 +52,8 @@ public class ModelTrainService {
}
// 3. 하이퍼파라미터 목록
List<ModelMngDto.HyperParamInfo> hyperParams = hyperParamCoreService.findAllActiveHyperParams();
List<ModelMngDto.HyperParamInfo> hyperParams =
null; // hyperParamCoreService.findAllActiveHyperParams();
// 4. 데이터셋 목록
List<ModelMngDto.DatasetInfo> datasets = datasetCoreService.findAllActiveDatasetsForTraining();
@@ -65,26 +66,6 @@ public class ModelTrainService {
.build();
}
/**
* 하이퍼파라미터 등록
*
* @param createReq 등록 요청
* @return 생성된 버전명
*/
@Transactional
public String createHyperParam(ModelMngDto.HyperParamCreateReq createReq) {
// 신규 버전 추가 시 baseHyperVer가 없으면 H1으로 설정
if (createReq.getBaseHyperVer() == null || createReq.getBaseHyperVer().isEmpty()) {
String firstVersion = hyperParamCoreService.getFirstHyperParamVersion();
createReq.setBaseHyperVer(firstVersion);
log.info("baseHyperVer가 없어 첫 번째 버전으로 설정: {}", firstVersion);
}
String newVersion = hyperParamCoreService.createHyperParam(createReq);
log.info("하이퍼파라미터 등록 완료: {}", newVersion);
return newVersion;
}
/**
* 하이퍼파라미터 단건 조회
*
@@ -245,10 +226,10 @@ public class ModelTrainService {
.lastEpoch(entity.getLastCheckpointEpoch())
.totalEpoch(entity.getEpochCnt())
.checkpointPath(entity.getCheckpointPath())
.failedAt(
entity.getStopDttm() != null
? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault())
: null)
// .failedAt(
// entity.getStopDttm() != null
// ? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault())
// : null)
.build();
}