feat/training_260202 #7

Merged
teddy merged 2 commits from feat/training_260202 into develop 2026-02-03 15:09:29 +09:00
6 changed files with 213 additions and 128 deletions
Showing only changes of commit 335e9d33d6 - Show all commits

View File

@@ -83,7 +83,7 @@ public class HyperParamApiController {
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))),
schema = @Schema(implementation = Page.class))),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@@ -133,9 +133,8 @@ public class HyperParamApiController {
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "400", description = "H1 삭제 불가", content = @Content),
@ApiResponse(responseCode = "409", description = "HPs_0001 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteHyperParam(
@@ -145,4 +144,43 @@ public class HyperParamApiController {
hyperParamService.deleteHyperParam(uuid);
return ApiResponseDto.ok(null);
}
@Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/{uuid}")
public ApiResponseDto<HyperParamDto.Basic> getHyperParam(
@Parameter(description = "하이퍼파라미터 uuid", example = "9c91a20c-71e7-4e5f-a860-9626d2b2059c")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(hyperParamService.getHyperParam(uuid));
}
@Operation(summary = "하이퍼파라미터 최적화 값 조회", description = "하이퍼파라미터 최적화 값 조회 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Page.class))),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/init")
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam() {
return ApiResponseDto.ok(hyperParamService.getInitHyperParam());
}
}

View File

@@ -23,9 +23,8 @@ public class HyperParamDto {
@NoArgsConstructor
@AllArgsConstructor
public static class Basic {
private Long id;
private UUID uuid;
private String hyperVer;
// -------------------------
// Important
@@ -33,28 +32,8 @@ public class HyperParamDto {
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// -------------------------
// Model Architecture
// -------------------------
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// -------------------------
// Loss & Optimization
// -------------------------
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// -------------------------
// Data
// -------------------------
@@ -65,6 +44,25 @@ public class HyperParamDto {
private Boolean trainPersistent;
private Boolean valPersistent;
// -------------------------
// Model Architecture
// -------------------------
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
// -------------------------
// Loss & Optimization
// -------------------------
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
private Integer numLayers;
// -------------------------
// Evaluation
// -------------------------
@@ -79,36 +77,25 @@ public class HyperParamDto {
// Augmentation
// -------------------------
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double flipProb;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// -------------------------
// Memo
// -------------------------
private String memo;
// -------------------------
// Hardware
// -------------------------
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// -------------------------
// Common
// -------------------------
private String memo;
private String delYn;
@JsonFormatDttm private ZonedDateTime createdDttm;
private Long createdUid;
@JsonFormatDttm private ZonedDateTime updatedDttm;
private Long updatedUid;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Long m1UseCnt;
private Long m2UseCnt;
private Long m3UseCnt;
}
@Getter
@@ -133,54 +120,24 @@ public class HyperParamDto {
@AllArgsConstructor
public static class HyperParamCreateReq {
// -------------------------
// Important
// -------------------------
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone
@Schema(description = "입력 이미지 크기(H,W)", example = "256,256")
@Schema(description = "입력 이미지 크기(H,W)", example = "512,512")
private String inputSize; // input_size
@Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256")
private String cropSize; // crop_size
@Schema(description = "총 학습 에폭 수", example = "200")
private Integer epochCnt; // epoch_cnt
@Schema(description = "배치 크기(Per GPU)", example = "16")
private Integer batchSize; // batch_size
@Schema(description = "Drop Path 비율", example = "0.3")
private Double dropPathRate; // drop_path_rate
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
private Integer frozenStages; // frozen_stages
@Schema(description = "Neck 결합 정책", example = "abs_diff")
private String neckPolicy; // neck_policy
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
private String decoderChannels; // decoder_channels
@Schema(description = "클래스별 가중치", example = "1,10")
private String classWeight; // class_weight
@Schema(description = "레이어 깊이", example = "24")
private Integer numLayers; // num_layers
@Schema(description = "학습률", example = "0.00006")
private Double learningRate; // learning_rate
@Schema(description = "Weight Decay", example = "0.05")
private Double weightDecay; // weight_decay
@Schema(description = "Layer Decay Rate", example = "0.9")
private Double layerDecayRate; // layer_decay_rate
@Schema(description = "DDP unused params 탐색 여부", example = "true")
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
private Integer ignoreIndex; // ignore_index
// -------------------------
// Data
// -------------------------
@Schema(description = "Train dataloader workers", example = "16")
private Integer trainNumWorkers; // train_num_workers
@@ -199,13 +156,55 @@ public class HyperParamDto {
@Schema(description = "Val persistent workers 여부", example = "true")
private Boolean valPersistent; // val_persistent
// -------------------------
// Model Architecture
// -------------------------
@Schema(description = "Drop Path 비율", example = "0.3")
private Double dropPathRate; // drop_path_rate
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
private Integer frozenStages; // frozen_stages
@Schema(description = "Neck 결합 정책", example = "abs_diff")
private String neckPolicy; // neck_policy
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
private String decoderChannels; // decoder_channels
@Schema(description = "클래스별 가중치", example = "1,10")
private String classWeight; // class_weight
// -------------------------
// Loss & Optimization
// -------------------------
@Schema(description = "학습률", example = "0.00006")
private Double learningRate; // learning_rate
@Schema(description = "Weight Decay", example = "0.05")
private Double weightDecay; // weight_decay
@Schema(description = "Layer Decay Rate", example = "0.9")
private Double layerDecayRate; // layer_decay_rate
@Schema(description = "DDP unused params 탐색 여부", example = "true")
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
private Integer ignoreIndex; // ignore_index
@Schema(description = "레이어 깊이", example = "24")
private Integer numLayers; // num_layers
// -------------------------
// Evaluation
// -------------------------
@Schema(description = "평가 지표 목록", example = "mFscore,mIoU")
private String metrics; // metrics
@Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore")
private String saveBest; // save_best
@Schema(description = "Best 모델 선정 규칙", example = "greater")
@Schema(description = "Best 모델 선정 규칙", example = "less")
private String saveBestRule; // save_best_rule
@Schema(description = "검증 수행 주기(Epoch)", example = "10")
@@ -217,15 +216,18 @@ public class HyperParamDto {
@Schema(description = "시각화 저장 주기(Epoch)", example = "1")
private Integer visInterval; // vis_interval
// -------------------------
// Augmentation
// -------------------------
@Schema(description = "회전 적용 확률", example = "0.5")
private Double rotProb; // rot_prob
@Schema(description = "반전 적용 확률", example = "0.5")
private Double flipProb; // flip_prob
@Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20")
private String rotDegree; // rot_degree
@Schema(description = "반전 적용 확률", example = "0.5")
private Double flipProb; // flip_prob
@Schema(description = "채널 교환 확률", example = "0.5")
private Double exchangeProb; // exchange_prob
@@ -241,6 +243,9 @@ public class HyperParamDto {
@Schema(description = "색조 변화량", example = "10")
private Integer hueDelta; // hue_delta
// -------------------------
// Hardware
// -------------------------
@Schema(description = "사용 GPU 개수", example = "4")
private Integer gpuCnt; // gpu_cnt
@@ -250,6 +255,9 @@ public class HyperParamDto {
@Schema(description = "분산학습 마스터 포트", example = "1122")
private Integer masterPort; // master_port
// -------------------------
// Memo
// -------------------------
@Schema(description = "메모", example = "하이퍼파라미터 신규등록")
private String memo; // memo
}

View File

@@ -56,4 +56,19 @@ public class HyperParamService {
public void deleteHyperParam(UUID uuid) {
hyperParamCoreService.deleteHyperParam(uuid);
}
/** 하이퍼파라미터 최적화 설정값 조회 */
public HyperParamDto.Basic getInitHyperParam() {
return hyperParamCoreService.getInitHyperParam();
}
/**
* 하이퍼파라미터 단건 조회
*
* @param uuid
* @return
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
return hyperParamCoreService.getHyperParam(uuid);
}
}

View File

@@ -112,25 +112,6 @@ public class ModelMngApiController {
return ApiResponseDto.ok(modelTrainService.getFormConfig());
}
@Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/hyper-params/{hyperVer}")
public ApiResponseDto<ModelMngDto.HyperParamInfo> getHyperParam(
@Parameter(description = "하이퍼파라미터 버전", example = "H1") @PathVariable String hyperVer) {
return ApiResponseDto.ok(modelTrainService.getHyperParam(hyperVer));
}
@Operation(summary = "학습 시작", description = "모든 설정(Step 1~3)을 마치고 최종적으로 학습 프로세스를 시작합니다")
@ApiResponses(
value = {

View File

@@ -181,6 +181,29 @@ public class HyperParamCoreService {
entity.setUpdatedDttm(ZonedDateTime.now());
}
/**
* 하이퍼파라미터 최적화 설정값 조회
*
* @return
*/
public HyperParamDto.Basic getInitHyperParam() {
ModelHyperParamEntity entity = new ModelHyperParamEntity();
return entity.toDto();
}
/**
* 하이퍼파라미터 단건 조회
*
* @return
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 단건 조회
*

View File

@@ -318,57 +318,77 @@ public class ModelHyperParamEntity {
public HyperParamDto.Basic toDto() {
return new HyperParamDto.Basic(
this.id,
this.uuid,
this.hyperVer,
// -------------------------
// Important
// -------------------------
this.backbone,
this.inputSize,
this.cropSize,
this.epochCnt,
this.batchSize,
this.dropPathRate,
this.frozenStages,
this.neckPolicy,
this.decoderChannels,
this.classWeight,
this.numLayers,
this.learningRate,
this.weightDecay,
this.layerDecayRate,
this.ddpFindUnusedParams,
this.ignoreIndex,
// -------------------------
// Data
// -------------------------
this.trainNumWorkers,
this.valNumWorkers,
this.testNumWorkers,
this.trainShuffle,
this.trainPersistent,
this.valPersistent,
// -------------------------
// Model Architecture
// -------------------------
this.dropPathRate,
this.frozenStages,
this.neckPolicy,
this.decoderChannels,
this.classWeight,
// -------------------------
// Loss & Optimization
// -------------------------
this.learningRate,
this.weightDecay,
this.layerDecayRate,
this.ddpFindUnusedParams,
this.ignoreIndex,
this.numLayers,
// -------------------------
// Evaluation
// -------------------------
this.metrics,
this.saveBest,
this.saveBestRule,
this.valInterval,
this.logInterval,
this.visInterval,
// -------------------------
// Augmentation
// -------------------------
this.rotProb,
this.flipProb,
this.rotDegree,
this.flipProb,
this.exchangeProb,
this.brightnessDelta,
this.contrastRange,
this.saturationRange,
this.hueDelta,
// -------------------------
// Memo
// -------------------------
this.memo,
// -------------------------
// Hardware
// -------------------------
this.gpuCnt,
this.gpuIds,
this.masterPort,
this.memo,
this.delYn,
this.createdDttm,
this.createdUid,
this.updatedDttm,
this.updatedUid,
this.lastUsedDttm,
this.m1UseCnt,
this.m2UseCnt,
this.m3UseCnt);
this.masterPort);
}
}