하이퍼파라미터 기능 추가

This commit is contained in:
2026-02-03 14:31:53 +09:00
parent e2757d3ca0
commit 3a8d6e3ef0
18 changed files with 946 additions and 688 deletions

0
gradlew vendored Normal file → Executable file
View File

View File

@@ -482,37 +482,40 @@ public class GlobalExceptionHandler {
@ExceptionHandler(CustomApiException.class)
public ApiResponseDto<String> handleCustomApiException(
CustomApiException e, HttpServletRequest request) {
log.warn("[CustomApiException] resource : {}", e.getMessage());
this.errorLog(request, e);
String codeName = e.getCodeName();
HttpStatus status = e.getStatus();
// String message = e.getMessage() == null ? ApiResponseCode.getMessage(codeName) :
// e.getMessage();
//
// ApiResponseCode apiCode = ApiResponseCode.getCode(codeName);
//
// ErrorLogEntity errorLog =
// saveErrorLogData(
// request, apiCode, status, ErrorLogDto.LogErrorLevel.WARNING, e.getStackTrace());
//
// ApiResponseDto<String> body =
// ApiResponseDto.createException(apiCode, message, status, errorLog.getId());
// enum에 없는 codeName이면 IllegalArgumentException 방지
ApiResponseDto.ApiResponseCode apiCode = null;
try {
apiCode = ApiResponseDto.ApiResponseCode.getCode(codeName);
} catch (Exception ignore) {
apiCode = null;
}
// 메시지 우선순위: 예외 message > enum message > fallback
String message;
if (e.getMessage() != null && !e.getMessage().isBlank()) {
message = e.getMessage();
} else if (apiCode != null) {
message = apiCode.getText(); // enum 기본 메시지
} else {
message = "서버에 문제가 발생 하였습니다."; // 최후 fallback
}
// 로그/응답 코드: enum 없으면 기본코드로
ApiResponseDto.ApiResponseCode safeCode =
(apiCode != null) ? apiCode : ApiResponseDto.ApiResponseCode.INTERNAL_SERVER_ERROR;
ErrorLogEntity errorLog =
saveErrorLogData(
request,
ApiResponseCode.getCode(codeName),
HttpStatus.valueOf(status.value()),
ErrorLogDto.LogErrorLevel.WARNING,
e.getStackTrace());
request, safeCode, status, ErrorLogDto.LogErrorLevel.WARNING, e.getStackTrace());
return ApiResponseDto.createException(
ApiResponseCode.getCode(codeName),
ApiResponseCode.getMessage(codeName),
HttpStatus.valueOf(status.value()),
errorLog.getId());
// return new ResponseEntity<>(body, status);
return ApiResponseDto.createException(safeCode, message, status, errorLog.getId());
}
private void errorLog(HttpServletRequest request, Throwable e) {

View File

@@ -189,7 +189,12 @@ public class ApiResponseDto<T> {
}
public static ApiResponseCode getCode(String name) {
if (name == null || name.isBlank()) return null;
try {
return ApiResponseCode.valueOf(name.toUpperCase());
} catch (IllegalArgumentException ex) {
return null;
}
}
public static String getMessage(String name) {

View File

@@ -0,0 +1,141 @@
package com.kamco.cd.training.hyperparam;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
import com.kamco.cd.training.hyperparam.service.HyperParamService;
import com.kamco.cd.training.model.dto.ModelMngDto;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import java.time.LocalDate;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@Tag(name = "하이퍼파라미터 관리", description = "하이퍼파라미터 관리 API")
@RestController
@RequiredArgsConstructor
@RequestMapping("/api/hyper-param")
public class HyperParamApiController {
private final HyperParamService hyperParamService;
@Operation(summary = "하이퍼파라미터 등록", description = "파라미터를 신규 저장")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "등록 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping
public ApiResponseDto<String> createHyperParam(
@Valid @RequestBody HyperParamDto.HyperParamCreateReq createReq) {
String newVersion = hyperParamService.createHyperParam(createReq);
return ApiResponseDto.ok(newVersion);
}
@Operation(summary = "하이퍼파라미터 수정", description = "파라미터를 수정")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "등록 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PutMapping("/{uuid}")
public ApiResponseDto<String> updateHyperParam(
@PathVariable UUID uuid, @Valid @RequestBody HyperParamDto.HyperParamCreateReq createReq) {
return ApiResponseDto.ok(hyperParamService.updateHyperParam(uuid, createReq));
}
@Operation(summary = "하이퍼파라미터 목록 조회", description = "하이퍼파라미터 목록 조회")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/list")
public ApiResponseDto<Page<List>> getHyperParam(
@Parameter(description = "구분") @RequestParam(required = false) String type,
@Parameter(description = "시작일") @RequestParam(required = false) LocalDate startDate,
@Parameter(description = "종료일") @RequestParam(required = false) LocalDate endDate,
@Parameter(description = "버전명") @RequestParam(required = false) String hyperVer,
@Parameter(
description = "정렬",
example = "createdDttm desc",
schema =
@Schema(
allowableValues = {
"createdDttm,desc",
"lastUsedDttm,desc",
"totalUseCnt,desc"
}))
@RequestParam(required = false)
String sort,
@Parameter(description = "페이지 번호 (0부터 시작)", example = "0") @RequestParam(defaultValue = "0")
int page,
@Parameter(description = "페이지 크기", example = "20") @RequestParam(defaultValue = "20")
int size) {
HyperParamDto.SearchReq searchReq = new HyperParamDto.SearchReq();
searchReq.setType(type);
searchReq.setStartDate(startDate);
searchReq.setEndDate(endDate);
searchReq.setHyperVer(hyperVer);
searchReq.setSort(sort);
searchReq.setPage(page);
searchReq.setSize(size);
Page<List> list = hyperParamService.getHyperParamList(searchReq);
return ApiResponseDto.ok(list);
}
@Operation(summary = "하이퍼파라미터 삭제", description = "하이퍼파라미터 삭제")
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "400", description = "H1은 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteHyperParam(
@Parameter(description = "하이퍼파라미터 uuid", example = "7966dd64-004a-4596-89ef-001664bc4de2")
@PathVariable
UUID uuid) {
hyperParamService.deleteHyperParam(uuid);
return ApiResponseDto.ok(null);
}
}

View File

@@ -0,0 +1,303 @@
package com.kamco.cd.training.hyperparam.dto;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDate;
import java.time.ZonedDateTime;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
public class HyperParamDto {
@Schema(name = "Basic", description = "하이퍼파라미터 조회")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class Basic {
private Long id;
private UUID uuid;
private String hyperVer;
// -------------------------
// Important
// -------------------------
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// -------------------------
// Model Architecture
// -------------------------
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// -------------------------
// Loss & Optimization
// -------------------------
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// -------------------------
// Data
// -------------------------
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// -------------------------
// Evaluation
// -------------------------
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// -------------------------
// Augmentation
// -------------------------
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// -------------------------
// Hardware
// -------------------------
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// -------------------------
// Common
// -------------------------
private String memo;
private String delYn;
@JsonFormatDttm private ZonedDateTime createdDttm;
private Long createdUid;
@JsonFormatDttm private ZonedDateTime updatedDttm;
private Long updatedUid;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Long m1UseCnt;
private Long m2UseCnt;
private Long m3UseCnt;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class List {
private UUID uuid;
private String hyperVer;
@JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Long m1UseCnt;
private Long m2UseCnt;
private Long m3UseCnt;
private Long totalCnt;
}
@Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class HyperParamCreateReq {
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone
@Schema(description = "입력 이미지 크기(H,W)", example = "256,256")
private String inputSize; // input_size
@Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256")
private String cropSize; // crop_size
@Schema(description = "총 학습 에폭 수", example = "200")
private Integer epochCnt; // epoch_cnt
@Schema(description = "배치 크기(Per GPU)", example = "16")
private Integer batchSize; // batch_size
@Schema(description = "Drop Path 비율", example = "0.3")
private Double dropPathRate; // drop_path_rate
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
private Integer frozenStages; // frozen_stages
@Schema(description = "Neck 결합 정책", example = "abs_diff")
private String neckPolicy; // neck_policy
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
private String decoderChannels; // decoder_channels
@Schema(description = "클래스별 가중치", example = "1,10")
private String classWeight; // class_weight
@Schema(description = "레이어 깊이", example = "24")
private Integer numLayers; // num_layers
@Schema(description = "학습률", example = "0.00006")
private Double learningRate; // learning_rate
@Schema(description = "Weight Decay", example = "0.05")
private Double weightDecay; // weight_decay
@Schema(description = "Layer Decay Rate", example = "0.9")
private Double layerDecayRate; // layer_decay_rate
@Schema(description = "DDP unused params 탐색 여부", example = "true")
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
private Integer ignoreIndex; // ignore_index
@Schema(description = "Train dataloader workers", example = "16")
private Integer trainNumWorkers; // train_num_workers
@Schema(description = "Val dataloader workers", example = "8")
private Integer valNumWorkers; // val_num_workers
@Schema(description = "Test dataloader workers", example = "8")
private Integer testNumWorkers; // test_num_workers
@Schema(description = "Train shuffle 여부", example = "true")
private Boolean trainShuffle; // train_shuffle
@Schema(description = "Train persistent workers 여부", example = "true")
private Boolean trainPersistent; // train_persistent
@Schema(description = "Val persistent workers 여부", example = "true")
private Boolean valPersistent; // val_persistent
@Schema(description = "평가 지표 목록", example = "mFscore,mIoU")
private String metrics; // metrics
@Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore")
private String saveBest; // save_best
@Schema(description = "Best 모델 선정 규칙", example = "greater")
private String saveBestRule; // save_best_rule
@Schema(description = "검증 수행 주기(Epoch)", example = "10")
private Integer valInterval; // val_interval
@Schema(description = "로그 기록 주기(Iteration)", example = "400")
private Integer logInterval; // log_interval
@Schema(description = "시각화 저장 주기(Epoch)", example = "1")
private Integer visInterval; // vis_interval
@Schema(description = "회전 적용 확률", example = "0.5")
private Double rotProb; // rot_prob
@Schema(description = "반전 적용 확률", example = "0.5")
private Double flipProb; // flip_prob
@Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20")
private String rotDegree; // rot_degree
@Schema(description = "채널 교환 확률", example = "0.5")
private Double exchangeProb; // exchange_prob
@Schema(description = "밝기 변화량", example = "10")
private Integer brightnessDelta; // brightness_delta
@Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2")
private String contrastRange; // contrast_range
@Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2")
private String saturationRange; // saturation_range
@Schema(description = "색조 변화량", example = "10")
private Integer hueDelta; // hue_delta
@Schema(description = "사용 GPU 개수", example = "4")
private Integer gpuCnt; // gpu_cnt
@Schema(description = "사용 GPU ID 목록", example = "0,1,2,3")
private String gpuIds; // gpu_ids
@Schema(description = "분산학습 마스터 포트", example = "1122")
private Integer masterPort; // master_port
@Schema(description = "메모", example = "하이퍼파라미터 신규등록")
private String memo; // memo
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class SearchReq {
private String type;
private LocalDate startDate;
private LocalDate endDate;
private String hyperVer;
// 페이징 파라미터
private int page = 0;
private int size = 20;
private String sort;
public Pageable toPageable() {
if (sort != null && !sort.isEmpty()) {
String[] sortParams = sort.split(",");
String property = sortParams[0];
Sort.Direction direction =
sortParams.length > 1 ? Sort.Direction.fromString(sortParams[1]) : Sort.Direction.ASC;
return PageRequest.of(page, size, Sort.by(direction, property));
}
return PageRequest.of(page, size);
}
}
@CodeExpose
@Getter
@AllArgsConstructor
public enum HyperType implements EnumType {
CREATE_DATE("생성일"),
LAST_USED_DATE("최근 사용일");
private final String desc;
@Override
public String getId() {
return name();
}
@Override
public String getText() {
return desc;
}
}
}

View File

@@ -0,0 +1,59 @@
package com.kamco.cd.training.hyperparam.service;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@Transactional(readOnly = true)
@RequiredArgsConstructor
public class HyperParamService {
private final HyperParamCoreService hyperParamCoreService;
/**
* 하이퍼 파라미터 목록 조회
*
* @param req
* @return 목록
*/
public Page<List> getHyperParamList(HyperParamDto.SearchReq req) {
return hyperParamCoreService.findByHyperVerList(req);
}
/**
* 하이퍼파라미터 등록
*
* @param createReq 등록 요청
* @return 생성된 버전명
*/
@Transactional
public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) {
return hyperParamCoreService.createHyperParam(createReq);
}
/**
* 하이퍼파라미터 수정
*
* @param createReq
* @return
*/
@Transactional
public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) {
return hyperParamCoreService.updateHyperParam(uuid, createReq);
}
/**
* 하이퍼파라미터 삭제
*
* @param uuid
*/
public void deleteHyperParam(UUID uuid) {
hyperParamCoreService.deleteHyperParam(uuid);
}
}

View File

@@ -11,10 +11,10 @@ import lombok.ToString;
@ToString(exclude = "password")
public class SignInRequest {
@Schema(description = "사용자 ID", example = "1234567")
@Schema(description = "사용자 ID", example = "123456")
private String username;
@Schema(description = "비밀번호", example = "Admin2!@#")
@Schema(description = "비밀번호", example = "qwe123!@#")
@JsonProperty(access = JsonProperty.Access.WRITE_ONLY)
private String password;
}

View File

@@ -112,26 +112,6 @@ public class ModelMngApiController {
return ApiResponseDto.ok(modelTrainService.getFormConfig());
}
@Operation(summary = "하이퍼파라미터 등록", description = "Step 1 에서 파라미터를 수정하여 신규 버전으로 저장합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "등록 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/hyper-params")
public ApiResponseDto<String> createHyperParam(
@Valid @RequestBody ModelMngDto.HyperParamCreateReq createReq) {
String newVersion = modelTrainService.createHyperParam(createReq);
return ApiResponseDto.ok(newVersion);
}
@Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다")
@ApiResponses(
value = {
@@ -151,22 +131,6 @@ public class ModelMngApiController {
return ApiResponseDto.ok(modelTrainService.getHyperParam(hyperVer));
}
@Operation(summary = "하이퍼파라미터 삭제", description = "특정 버전의 하이퍼파라미터를 삭제합니다 (H1은 삭제 불가)")
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "400", description = "H1은 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@DeleteMapping("/hyper-params/{hyperVer}")
public ApiResponseDto<Void> deleteHyperParam(
@Parameter(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518") @PathVariable
String hyperVer) {
modelTrainService.deleteHyperParam(hyperVer);
return ApiResponseDto.ok(null);
}
@Operation(summary = "학습 시작", description = "모든 설정(Step 1~3)을 마치고 최종적으로 학습 프로세스를 시작합니다")
@ApiResponses(
value = {

View File

@@ -1,218 +0,0 @@
package com.kamco.cd.training.model.dto;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import java.time.ZonedDateTime;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
public class HyperParamDto {
@Schema(name = "HyperParam Basic", description = "하이퍼파라미터 기본 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class Basic {
private String hyperVer;
// Important
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// Architecture
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// Optimization
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// Data
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// Evaluation
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// Hardware
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// Augmentation
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// Legacy (deprecated)
private Double dropoutRatio;
private Integer cnnFilterCnt;
// Common
private String memo;
@JsonFormatDttm private ZonedDateTime createdDttm;
public Basic(ModelHyperParamEntity entity) {
this.hyperVer = entity.getHyperVer();
// Important
this.backbone = entity.getBackbone();
this.inputSize = entity.getInputSize();
this.cropSize = entity.getCropSize();
this.epochCnt = entity.getEpochCnt();
this.batchSize = entity.getBatchSize();
// Architecture
this.dropPathRate = entity.getDropPathRate();
this.frozenStages = entity.getFrozenStages();
this.neckPolicy = entity.getNeckPolicy();
this.decoderChannels = entity.getDecoderChannels();
this.classWeight = entity.getClassWeight();
this.numLayers = entity.getNumLayers();
// Optimization
this.learningRate = entity.getLearningRate();
this.weightDecay = entity.getWeightDecay();
this.layerDecayRate = entity.getLayerDecayRate();
this.ddpFindUnusedParams = entity.getDdpFindUnusedParams();
this.ignoreIndex = entity.getIgnoreIndex();
// Data
this.trainNumWorkers = entity.getTrainNumWorkers();
this.valNumWorkers = entity.getValNumWorkers();
this.testNumWorkers = entity.getTestNumWorkers();
this.trainShuffle = entity.getTrainShuffle();
this.trainPersistent = entity.getTrainPersistent();
this.valPersistent = entity.getValPersistent();
// Evaluation
this.metrics = entity.getMetrics();
this.saveBest = entity.getSaveBest();
this.saveBestRule = entity.getSaveBestRule();
this.valInterval = entity.getValInterval();
this.logInterval = entity.getLogInterval();
this.visInterval = entity.getVisInterval();
// Hardware
this.gpuCnt = entity.getGpuCnt();
this.gpuIds = entity.getGpuIds();
this.masterPort = entity.getMasterPort();
// Augmentation
this.rotProb = entity.getRotProb();
this.flipProb = entity.getFlipProb();
this.rotDegree = entity.getRotDegree();
this.exchangeProb = entity.getExchangeProb();
this.brightnessDelta = entity.getBrightnessDelta();
this.contrastRange = entity.getContrastRange();
this.saturationRange = entity.getSaturationRange();
this.hueDelta = entity.getHueDelta();
// Legacy
this.cnnFilterCnt = entity.getCnnFilterCnt();
// Common
this.memo = entity.getMemo();
this.createdDttm = entity.getCreatedDttm();
}
}
@Schema(name = "HyperParam AddReq", description = "하이퍼파라미터 등록 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class AddReq {
@NotBlank(message = "버전명은 필수입니다")
private String hyperVer;
// Important
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// Architecture
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// Optimization
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// Data
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// Evaluation
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// Hardware
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// Augmentation
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// Legacy (deprecated)
private Double dropoutRatio;
private Integer cnnFilterCnt;
// Common
private String memo;
}
}

View File

@@ -5,7 +5,6 @@ import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import java.time.LocalDate;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;
@@ -303,155 +302,6 @@ public class ModelMngDto {
@JsonFormatDttm private ZonedDateTime createdDttm;
}
@Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class HyperParamCreateReq {
@Schema(description = "새로운 하이파라미터 버전", example = "")
private String newHyperVer;
@Schema(description = "불러온 하이파라미터 버전", example = "")
private String baseHyperVer;
@Schema(description = "하이퍼파라미터 PK", example = "1")
private Long hyperParamId; // hyper_param_id
@Schema(description = "하이퍼파라미터 UUID", example = "3fa85f64-5717-4562-b3fc-2c963f66afa6")
private UUID uuid; // uuid (또는 hyper_param_uuid 컬럼이면 이름 맞춰 주세요)
@Schema(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518")
private String hyperVer; // hyper_ver
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone
@Schema(description = "입력 이미지 크기(H,W)", example = "256,256")
private String inputSize; // input_size
@Schema(description = "크롭 크기(H,W 또는 단일값)", example = "256,256")
private String cropSize; // crop_size
@Schema(description = "총 학습 에폭 수", example = "200")
private Integer epochCnt; // epoch_cnt
@Schema(description = "배치 크기(Per GPU)", example = "16")
private Integer batchSize; // batch_size
@Schema(description = "Drop Path 비율", example = "0.3")
private Double dropPathRate; // drop_path_rate
@Schema(description = "Freeze 단계(-1:None)", example = "-1")
private Integer frozenStages; // frozen_stages
@Schema(description = "Neck 결합 정책", example = "abs_diff")
private String neckPolicy; // neck_policy
@Schema(description = "디코더 채널 구성", example = "512,256,128,64")
private String decoderChannels; // decoder_channels
@Schema(description = "클래스별 가중치", example = "1,10")
private String classWeight; // class_weight
@Schema(description = "레이어 깊이", example = "24")
private Integer numLayers; // num_layers
@Schema(description = "학습률", example = "0.00006")
private Double learningRate; // learning_rate
@Schema(description = "Weight Decay", example = "0.05")
private Double weightDecay; // weight_decay
@Schema(description = "Layer Decay Rate", example = "0.9")
private Double layerDecayRate; // layer_decay_rate
@Schema(description = "DDP unused params 탐색 여부", example = "true")
private Boolean ddpFindUnusedParams; // ddp_find_unused_params
@Schema(description = "Loss 계산 제외 인덱스", example = "255")
private Integer ignoreIndex; // ignore_index
@Schema(description = "Train dataloader workers", example = "16")
private Integer trainNumWorkers; // train_num_workers
@Schema(description = "Val dataloader workers", example = "8")
private Integer valNumWorkers; // val_num_workers
@Schema(description = "Test dataloader workers", example = "8")
private Integer testNumWorkers; // test_num_workers
@Schema(description = "Train shuffle 여부", example = "true")
private Boolean trainShuffle; // train_shuffle
@Schema(description = "Train persistent workers 여부", example = "true")
private Boolean trainPersistent; // train_persistent
@Schema(description = "Val persistent workers 여부", example = "true")
private Boolean valPersistent; // val_persistent
@Schema(description = "평가 지표 목록", example = "mFscore,mIoU")
private String metrics; // metrics
@Schema(description = "Best 모델 선정 기준 지표", example = "changed_fscore")
private String saveBest; // save_best
@Schema(description = "Best 모델 선정 규칙", example = "greater")
private String saveBestRule; // save_best_rule
@Schema(description = "검증 수행 주기(Epoch)", example = "10")
private Integer valInterval; // val_interval
@Schema(description = "로그 기록 주기(Iteration)", example = "400")
private Integer logInterval; // log_interval
@Schema(description = "시각화 저장 주기(Epoch)", example = "1")
private Integer visInterval; // vis_interval
@Schema(description = "회전 적용 확률", example = "0.5")
private Double rotProb; // rot_prob
@Schema(description = "반전 적용 확률", example = "0.5")
private Double flipProb; // flip_prob
@Schema(description = "회전 각도 범위(Min,Max)", example = "-20,20")
private String rotDegree; // rot_degree
@Schema(description = "채널 교환 확률", example = "0.5")
private Double exchangeProb; // exchange_prob
@Schema(description = "밝기 변화량", example = "10")
private Integer brightnessDelta; // brightness_delta
@Schema(description = "대비 범위(Min,Max)", example = "0.8,1.2")
private String contrastRange; // contrast_range
@Schema(description = "채도 범위(Min,Max)", example = "0.8,1.2")
private String saturationRange; // saturation_range
@Schema(description = "색조 변화량", example = "10")
private Integer hueDelta; // hue_delta
@Schema(description = "사용 GPU 개수", example = "4")
private Integer gpuCnt; // gpu_cnt
@Schema(description = "사용 GPU ID 목록", example = "0,1,2,3")
private String gpuIds; // gpu_ids
@Schema(description = "분산학습 마스터 포트", example = "1122")
private Integer masterPort; // master_port
@Schema(description = "메모", example = "하이퍼파라미터 신규등록")
private String memo; // memo
@Schema(description = "삭제 여부(Y/N)", example = "N")
private String delYn; // del_yn
@Schema(description = "생성 일시")
private LocalDate createdDttm; // created_dttm
}
@Schema(name = "TrainStartReq", description = "학습 시작 요청")
@Getter
@Setter

View File

@@ -52,7 +52,8 @@ public class ModelTrainService {
}
// 3. 하이퍼파라미터 목록
List<ModelMngDto.HyperParamInfo> hyperParams = hyperParamCoreService.findAllActiveHyperParams();
List<ModelMngDto.HyperParamInfo> hyperParams =
null; // hyperParamCoreService.findAllActiveHyperParams();
// 4. 데이터셋 목록
List<ModelMngDto.DatasetInfo> datasets = datasetCoreService.findAllActiveDatasetsForTraining();
@@ -65,26 +66,6 @@ public class ModelTrainService {
.build();
}
/**
* 하이퍼파라미터 등록
*
* @param createReq 등록 요청
* @return 생성된 버전명
*/
@Transactional
public String createHyperParam(ModelMngDto.HyperParamCreateReq createReq) {
// 신규 버전 추가 시 baseHyperVer가 없으면 H1으로 설정
if (createReq.getBaseHyperVer() == null || createReq.getBaseHyperVer().isEmpty()) {
String firstVersion = hyperParamCoreService.getFirstHyperParamVersion();
createReq.setBaseHyperVer(firstVersion);
log.info("baseHyperVer가 없어 첫 번째 버전으로 설정: {}", firstVersion);
}
String newVersion = hyperParamCoreService.createHyperParam(createReq);
log.info("하이퍼파라미터 등록 완료: {}", newVersion);
return newVersion;
}
/**
* 하이퍼파라미터 단건 조회
*
@@ -245,10 +226,10 @@ public class ModelTrainService {
.lastEpoch(entity.getLastCheckpointEpoch())
.totalEpoch(entity.getEpochCnt())
.checkpointPath(entity.getCheckpointPath())
.failedAt(
entity.getStopDttm() != null
? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault())
: null)
// .failedAt(
// entity.getStopDttm() != null
// ? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault())
// : null)
.build();
}

View File

@@ -1,31 +1,24 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.NotFoundException;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.model.dto.ModelMngDto;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.kamco.cd.training.postgres.repository.model.ModelHyperParamRepository;
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
@Service
@RequiredArgsConstructor
public class HyperParamCoreService {
private final ModelHyperParamRepository hyperParamRepository;
/**
* 하이퍼파라미터 전체 조회 (삭제되지 않은 것만)
*
* @return 하이퍼파라미터 목록
*/
public List<ModelMngDto.HyperParamInfo> findAllActiveHyperParams() {
List<ModelHyperParamEntity> entities =
hyperParamRepository.findByDelYnOrderByCreatedDttmDesc("N");
return entities.stream().map(this::mapToHyperParamInfo).toList();
}
private final HyperParamRepository hyperParamRepository;
private final UserUtil userUtil;
private ModelMngDto.HyperParamInfo mapToHyperParamInfo(ModelHyperParamEntity entity) {
return ModelMngDto.HyperParamInfo.builder()
@@ -76,8 +69,6 @@ public class HyperParamCoreService {
.contrastRange(entity.getContrastRange())
.saturationRange(entity.getSaturationRange())
.hueDelta(entity.getHueDelta())
// Legacy
.cnnFilterCnt(entity.getCnnFilterCnt())
// Common
.memo(entity.getMemo())
.createdDttm(entity.getCreatedDttm())
@@ -90,168 +81,104 @@ public class HyperParamCoreService {
* @param createReq 등록 요청
* @return 등록된 버전명
*/
public String createHyperParam(ModelMngDto.HyperParamCreateReq createReq) {
// 중복 체크
if (hyperParamRepository.existsById(createReq.getNewHyperVer())) {
throw new BadRequestException("이미 존재하는 버전입니다: " + createReq.getNewHyperVer());
public String createHyperParam(HyperParamDto.HyperParamCreateReq createReq) {
String firstVersion = getFirstHyperParamVersion();
ModelHyperParamEntity entity = new ModelHyperParamEntity();
entity.setHyperVer(firstVersion);
applyHyperParam(entity, createReq);
// user
entity.setCreatedUid(userUtil.getId());
ModelHyperParamEntity resultEntity = hyperParamRepository.save(entity);
return resultEntity.getHyperVer();
}
// 기준 버전 조회
ModelHyperParamEntity baseEntity =
/**
* 하이퍼파라미터 수정
*
* @param uuid uuid
* @param createReq 등록 요청
* @return ver
*/
public String updateHyperParam(UUID uuid, HyperParamDto.HyperParamCreateReq createReq) {
ModelHyperParamEntity entity =
hyperParamRepository
.findById(createReq.getBaseHyperVer())
.orElseThrow(
() -> new NotFoundException("기준 버전을 찾을 수 없습니다: " + createReq.getBaseHyperVer()));
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// 신규 엔티티 생성 (기준 값 복사 후 변경된 값만 적용)
ModelHyperParamEntity entity = new ModelHyperParamEntity();
entity.setHyperVer(createReq.getNewHyperVer());
applyHyperParam(entity, createReq);
// user
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
return entity.getHyperVer();
}
private void applyHyperParam(
ModelHyperParamEntity entity, HyperParamDto.HyperParamCreateReq createReq) {
// Important
entity.setBackbone(
createReq.getBackbone() != null ? createReq.getBackbone() : baseEntity.getBackbone());
entity.setInputSize(
createReq.getInputSize() != null ? createReq.getInputSize() : baseEntity.getInputSize());
entity.setCropSize(
createReq.getCropSize() != null ? createReq.getCropSize() : baseEntity.getCropSize());
entity.setEpochCnt(
createReq.getEpochCnt() != null ? createReq.getEpochCnt() : baseEntity.getEpochCnt());
entity.setBatchSize(
createReq.getBatchSize() != null ? createReq.getBatchSize() : baseEntity.getBatchSize());
// Architecture
entity.setDropPathRate(
createReq.getDropPathRate() != null
? createReq.getDropPathRate()
: baseEntity.getDropPathRate());
entity.setFrozenStages(
createReq.getFrozenStages() != null
? createReq.getFrozenStages()
: baseEntity.getFrozenStages());
entity.setNeckPolicy(
createReq.getNeckPolicy() != null ? createReq.getNeckPolicy() : baseEntity.getNeckPolicy());
entity.setDecoderChannels(
createReq.getDecoderChannels() != null
? createReq.getDecoderChannels()
: baseEntity.getDecoderChannels());
entity.setClassWeight(
createReq.getClassWeight() != null
? createReq.getClassWeight()
: baseEntity.getClassWeight());
entity.setNumLayers(
createReq.getNumLayers() != null ? createReq.getNumLayers() : baseEntity.getNumLayers());
// Optimization
entity.setLearningRate(
createReq.getLearningRate() != null
? createReq.getLearningRate()
: baseEntity.getLearningRate());
entity.setWeightDecay(
createReq.getWeightDecay() != null
? createReq.getWeightDecay()
: baseEntity.getWeightDecay());
entity.setLayerDecayRate(
createReq.getLayerDecayRate() != null
? createReq.getLayerDecayRate()
: baseEntity.getLayerDecayRate());
entity.setDdpFindUnusedParams(
createReq.getDdpFindUnusedParams() != null
? createReq.getDdpFindUnusedParams()
: baseEntity.getDdpFindUnusedParams());
entity.setIgnoreIndex(
createReq.getIgnoreIndex() != null
? createReq.getIgnoreIndex()
: baseEntity.getIgnoreIndex());
entity.setBackbone(createReq.getBackbone());
entity.setInputSize(createReq.getInputSize());
entity.setCropSize(createReq.getCropSize());
entity.setBatchSize(createReq.getBatchSize());
// Data
entity.setTrainNumWorkers(
createReq.getTrainNumWorkers() != null
? createReq.getTrainNumWorkers()
: baseEntity.getTrainNumWorkers());
entity.setValNumWorkers(
createReq.getValNumWorkers() != null
? createReq.getValNumWorkers()
: baseEntity.getValNumWorkers());
entity.setTestNumWorkers(
createReq.getTestNumWorkers() != null
? createReq.getTestNumWorkers()
: baseEntity.getTestNumWorkers());
entity.setTrainShuffle(
createReq.getTrainShuffle() != null
? createReq.getTrainShuffle()
: baseEntity.getTrainShuffle());
entity.setTrainPersistent(
createReq.getTrainPersistent() != null
? createReq.getTrainPersistent()
: baseEntity.getTrainPersistent());
entity.setValPersistent(
createReq.getValPersistent() != null
? createReq.getValPersistent()
: baseEntity.getValPersistent());
entity.setTrainNumWorkers(createReq.getTrainNumWorkers());
entity.setValNumWorkers(createReq.getValNumWorkers());
entity.setTestNumWorkers(createReq.getTestNumWorkers());
entity.setTrainShuffle(createReq.getTrainShuffle());
entity.setTrainPersistent(createReq.getTrainPersistent());
entity.setValPersistent(createReq.getValPersistent());
// Model Architecture
entity.setDropPathRate(createReq.getDropPathRate());
entity.setFrozenStages(createReq.getFrozenStages());
entity.setNeckPolicy(createReq.getNeckPolicy());
entity.setClassWeight(createReq.getClassWeight());
entity.setDecoderChannels(createReq.getDecoderChannels());
// Loss & Optimization
entity.setLearningRate(createReq.getLearningRate());
entity.setWeightDecay(createReq.getWeightDecay());
entity.setLayerDecayRate(createReq.getLayerDecayRate());
entity.setDdpFindUnusedParams(createReq.getDdpFindUnusedParams());
entity.setIgnoreIndex(createReq.getIgnoreIndex());
entity.setNumLayers(createReq.getNumLayers());
// Evaluation
entity.setMetrics(
createReq.getMetrics() != null ? createReq.getMetrics() : baseEntity.getMetrics());
entity.setSaveBest(
createReq.getSaveBest() != null ? createReq.getSaveBest() : baseEntity.getSaveBest());
entity.setSaveBestRule(
createReq.getSaveBestRule() != null
? createReq.getSaveBestRule()
: baseEntity.getSaveBestRule());
entity.setValInterval(
createReq.getValInterval() != null
? createReq.getValInterval()
: baseEntity.getValInterval());
entity.setLogInterval(
createReq.getLogInterval() != null
? createReq.getLogInterval()
: baseEntity.getLogInterval());
entity.setVisInterval(
createReq.getVisInterval() != null
? createReq.getVisInterval()
: baseEntity.getVisInterval());
entity.setMetrics(createReq.getMetrics());
entity.setSaveBest(createReq.getSaveBest());
entity.setSaveBestRule(createReq.getSaveBestRule());
entity.setValInterval(createReq.getValInterval());
entity.setLogInterval(createReq.getLogInterval());
entity.setVisInterval(createReq.getVisInterval());
// Hardware
entity.setGpuCnt(
createReq.getGpuCnt() != null ? createReq.getGpuCnt() : baseEntity.getGpuCnt());
entity.setGpuIds(
createReq.getGpuIds() != null ? createReq.getGpuIds() : baseEntity.getGpuIds());
entity.setMasterPort(
createReq.getMasterPort() != null ? createReq.getMasterPort() : baseEntity.getMasterPort());
// Augmentation
entity.setRotProb(
createReq.getRotProb() != null ? createReq.getRotProb() : baseEntity.getRotProb());
entity.setFlipProb(
createReq.getFlipProb() != null ? createReq.getFlipProb() : baseEntity.getFlipProb());
entity.setRotDegree(
createReq.getRotDegree() != null ? createReq.getRotDegree() : baseEntity.getRotDegree());
entity.setExchangeProb(
createReq.getExchangeProb() != null
? createReq.getExchangeProb()
: baseEntity.getExchangeProb());
entity.setBrightnessDelta(
createReq.getBrightnessDelta() != null
? createReq.getBrightnessDelta()
: baseEntity.getBrightnessDelta());
entity.setContrastRange(
createReq.getContrastRange() != null
? createReq.getContrastRange()
: baseEntity.getContrastRange());
entity.setSaturationRange(
createReq.getSaturationRange() != null
? createReq.getSaturationRange()
: baseEntity.getSaturationRange());
entity.setHueDelta(
createReq.getHueDelta() != null ? createReq.getHueDelta() : baseEntity.getHueDelta());
// Common
// memo
entity.setMemo(createReq.getMemo());
entity.setDelYn("N");
entity.setCreatedDttm(ZonedDateTime.now());
}
ModelHyperParamEntity saved = hyperParamRepository.save(entity);
return saved.getHyperVer();
/**
* 하이퍼파라미터 삭제
*
* @param uuid
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getHyperVer().equals("HPs_0001")) {
throw new CustomApiException("CONFLICT", HttpStatus.CONFLICT, "HPs_0001 버전은 삭제할수 없습니다.");
}
entity.setDelYn("Y");
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
}
/**
@@ -261,27 +188,16 @@ public class HyperParamCoreService {
* @return 하이퍼파라미터 정보
*/
public ModelMngDto.HyperParamInfo findByHyperVer(String hyperVer) {
ModelHyperParamEntity entity =
hyperParamRepository
.findById(hyperVer)
.orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
// ModelHyperParamEntity entity =
// hyperParamRepository
// .findById(hyperVer)
// .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
//
// if ("Y".equals(entity.getDelYn())) {
// throw new NotFoundException("삭제된 하이퍼파라미터입니다: " + hyperVer);
// }
if ("Y".equals(entity.getDelYn())) {
throw new NotFoundException("삭제된 하이퍼파라미터입니다: " + hyperVer);
}
return mapToHyperParamInfo(entity);
}
/**
* 하이퍼파라미터 수정 (기존 버전은 수정 불가)
*
* @param hyperVer 하이퍼파라미터 버전
* @param updateReq 수정 요청
*/
public void updateHyperParam(String hyperVer, ModelMngDto.HyperParamCreateReq updateReq) {
// 기존 버전은 수정 불가
throw new BadRequestException("기존 버전은 수정할 수 없습니다. 신규 버전을 생성해주세요.");
return mapToHyperParamInfo(null);
}
/**
@@ -294,33 +210,47 @@ public class HyperParamCoreService {
if ("H1".equals(hyperVer)) {
throw new BadRequestException("H1은 디폴트 하이퍼파라미터 버전이므로 삭제할 수 없습니다.");
}
ModelHyperParamEntity entity =
hyperParamRepository
.findById(hyperVer)
.orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
if ("Y".equals(entity.getDelYn())) {
throw new BadRequestException("이미 삭제된 하이퍼파라미터입니다: " + hyperVer);
}
// 논리 삭제 처리
entity.setDelYn("Y");
hyperParamRepository.save(entity);
//
// ModelHyperParamEntity entity =
// hyperParamRepository
// .findById(hyperVer)
// .orElseThrow(() -> new NotFoundException("하이퍼파라미터를 찾을 수 없습니다: " + hyperVer));
//
// if ("Y".equals(entity.getDelYn())) {
// throw new BadRequestException("이미 삭제된 하이퍼파라미터입니다: " + hyperVer);
// }
//
// // 논리 삭제 처리
// entity.setDelYn("Y");
// hyperParamRepository.save(entity);
}
/**
* 첫 번째 하이퍼파라미터 버전 조회 (H1 확인용)
* 하이퍼파라미터 목록 조회
*
* @return 첫 번째 하이퍼파라미터 버전
* @param req
* @return
*/
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
return hyperParamRepository.findByHyperVerList(req);
}
/**
* 하이퍼파라미터 버전 조회
*
* @return ver
*/
public String getFirstHyperParamVersion() {
List<ModelHyperParamEntity> entities =
hyperParamRepository.findByDelYnOrderByCreatedDttmDesc("N");
if (entities.isEmpty()) {
throw new NotFoundException("하이퍼파라미터가 존재하지 않습니다.");
return hyperParamRepository
.findHyperParamVer()
.map(ModelHyperParamEntity::getHyperVer)
.map(this::increase)
.orElse("HPs_0001");
}
// 가장 오래된 것이 H1이므로 리스트의 마지막 요소 반환
return entities.get(entities.size() - 1).getHyperVer();
private String increase(String hyperVer) {
String prefix = "HPs_";
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
return prefix + String.format("%04d", num + 1);
}
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import jakarta.persistence.*;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
@@ -28,7 +29,7 @@ public class ModelHyperParamEntity {
@NotNull
@UuidGenerator
@Column(name = "uuid", nullable = false, updatable = false)
private UUID uuid;
private UUID uuid = UUID.randomUUID();
@Size(max = 50)
@NotNull
@@ -259,8 +260,7 @@ public class ModelHyperParamEntity {
// -------------------------
/** Default: 4 */
@NotNull
@Column(name = "gpu_cnt", nullable = false)
@Column(name = "gpu_cnt")
private Integer gpuCnt;
/** Default: 0,1,2,3 */
@@ -289,9 +289,86 @@ public class ModelHyperParamEntity {
@Column(name = "created_dttm", nullable = false)
private ZonedDateTime createdDttm = ZonedDateTime.now();
@Column(name = "cnn_filter_cnt")
private Integer cnnFilterCnt;
@NotNull
@Column(name = "created_uid", nullable = false)
private Long createdUid;
@ColumnDefault("CURRENT_TIMESTAMP")
@Column(name = "updated_dttm")
private ZonedDateTime updatedDttm;
@Column(name = "updated_uid")
private Long updatedUid;
@ColumnDefault("CURRENT_TIMESTAMP")
@Column(name = "last_used_dttm")
private ZonedDateTime lastUsedDttm;
@Column(name = "m1_use_cnt")
private Long m1UseCnt = 0L;
@Column(name = "m2_use_cnt")
private Long m2UseCnt = 0L;
@Column(name = "m3_use_cnt")
private Long m3UseCnt = 0L;
@OneToMany(mappedBy = "hyperParams", fetch = FetchType.LAZY)
private Set<ModelTrainMasterEntity> trainMasters = new LinkedHashSet<>();
public HyperParamDto.Basic toDto() {
return new HyperParamDto.Basic(
this.id,
this.uuid,
this.hyperVer,
this.backbone,
this.inputSize,
this.cropSize,
this.epochCnt,
this.batchSize,
this.dropPathRate,
this.frozenStages,
this.neckPolicy,
this.decoderChannels,
this.classWeight,
this.numLayers,
this.learningRate,
this.weightDecay,
this.layerDecayRate,
this.ddpFindUnusedParams,
this.ignoreIndex,
this.trainNumWorkers,
this.valNumWorkers,
this.testNumWorkers,
this.trainShuffle,
this.trainPersistent,
this.valPersistent,
this.metrics,
this.saveBest,
this.saveBestRule,
this.valInterval,
this.logInterval,
this.visInterval,
this.rotProb,
this.flipProb,
this.rotDegree,
this.exchangeProb,
this.brightnessDelta,
this.contrastRange,
this.saturationRange,
this.hueDelta,
this.gpuCnt,
this.gpuIds,
this.masterPort,
this.memo,
this.delYn,
this.createdDttm,
this.createdUid,
this.updatedDttm,
this.updatedUid,
this.lastUsedDttm,
this.m1UseCnt,
this.m2UseCnt,
this.m3UseCnt);
}
}

View File

@@ -12,7 +12,6 @@ import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.Instant;
import java.time.ZonedDateTime;
import java.util.UUID;
import lombok.Getter;
@@ -112,7 +111,7 @@ public class ModelTrainMasterEntity {
private Integer progressRate;
@Column(name = "stop_dttm")
private Instant stopDttm;
private ZonedDateTime stopDttm;
@Column(name = "confirmed_best_epoch")
private Integer confirmedBestEpoch;
@@ -125,7 +124,7 @@ public class ModelTrainMasterEntity {
private String errorMsg;
@Column(name = "step2_start_dttm")
private Instant step2StartDttm;
private ZonedDateTime step2StartDttm;
@Size(max = 1000)
@Column(name = "train_log_path", length = 1000)

View File

@@ -0,0 +1,9 @@
package com.kamco.cd.training.postgres.repository.hyperparam;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface HyperParamRepository
extends JpaRepository<ModelHyperParamEntity, Long>, HyperParamRepositoryCustom {}

View File

@@ -0,0 +1,21 @@
package com.kamco.cd.training.postgres.repository.hyperparam;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.Optional;
import java.util.UUID;
import org.springframework.data.domain.Page;
public interface HyperParamRepositoryCustom {
/**
* 마지막 버전 조회
*
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamVer();
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req);
}

View File

@@ -0,0 +1,148 @@
package com.kamco.cd.training.postgres.repository.hyperparam;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.NumberExpression;
import com.querydsl.jpa.impl.JPAQuery;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
private final JPAQueryFactory queryFactory;
@Override
public Optional<ModelHyperParamEntity> findHyperParamVer() {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.eq("N"))
.orderBy(modelHyperParamEntity.hyperVer.desc())
.limit(1)
.fetchOne());
}
@Override
public Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid) {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.eq("N").and(modelHyperParamEntity.uuid.eq(uuid)))
.fetchOne());
}
@Override
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
builder.and(modelHyperParamEntity.delYn.eq("N"));
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
// 버전
builder.and(modelHyperParamEntity.hyperVer.contains(req.getHyperVer()));
}
if (req.getStartDate() != null && req.getEndDate() != null) {
ZoneId zoneId = ZoneId.systemDefault();
ZonedDateTime start = req.getStartDate().atStartOfDay(zoneId);
ZonedDateTime end = req.getEndDate().atTime(23, 59, 59).atZone(zoneId);
if (HyperType.CREATE_DATE.getId().equals(req.getType())) {
// 생성일
builder.and(modelHyperParamEntity.createdDttm.between(start, end));
} else if (HyperType.LAST_USED_DATE.getId().equals(req.getType())) {
// 최종 사용일
builder.and(modelHyperParamEntity.lastUsedDttm.between(start, end));
}
}
NumberExpression<Long> totalUseCnt =
modelHyperParamEntity
.m1UseCnt
.coalesce(0L)
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
JPAQuery<HyperParamDto.List> query =
queryFactory
.select(
Projections.constructor(
HyperParamDto.List.class,
modelHyperParamEntity.uuid,
modelHyperParamEntity.hyperVer,
modelHyperParamEntity.createdDttm,
modelHyperParamEntity.lastUsedDttm,
modelHyperParamEntity.m1UseCnt,
modelHyperParamEntity.m2UseCnt,
modelHyperParamEntity.m3UseCnt,
totalUseCnt.as("totalUseCnt")))
.from(modelHyperParamEntity)
.where(builder);
Sort.Order sortOrder = pageable.getSort().stream().findFirst().orElse(null);
if (sortOrder == null) {
// 기본값
query.orderBy(modelHyperParamEntity.createdDttm.desc());
} else {
String property = sortOrder.getProperty();
boolean asc = sortOrder.isAscending();
switch (property) {
case "createdDttm" ->
query.orderBy(
asc
? modelHyperParamEntity.createdDttm.asc()
: modelHyperParamEntity.createdDttm.desc());
case "lastUsedDttm" ->
query.orderBy(
asc
? modelHyperParamEntity.lastUsedDttm.asc()
: modelHyperParamEntity.lastUsedDttm.desc());
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
}
}
List<HyperParamDto.List> content =
query.offset(pageable.getOffset()).limit(pageable.getPageSize()).fetch();
Long total =
queryFactory
.select(modelHyperParamEntity.count())
.from(modelHyperParamEntity)
.where(builder)
.fetchOne();
long totalCount = (total != null) ? total : 0L;
return new PageImpl<>(content, pageable, totalCount);
}
}

View File

@@ -1,14 +0,0 @@
package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.List;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface ModelHyperParamRepository extends JpaRepository<ModelHyperParamEntity, String> {
List<ModelHyperParamEntity> findByDelYnOrderByCreatedDttmDesc(String delYn);
List<ModelHyperParamEntity> findByDelYnOrderByCreatedDttmAsc(String delYn);
}