This commit is contained in:
2026-02-02 12:29:00 +09:00
parent f8f5cef6e1
commit 495ef7d86c
175 changed files with 45128 additions and 0 deletions

View File

@@ -0,0 +1,286 @@
package com.kamco.cd.training.model;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.model.dto.ModelMngDto;
import com.kamco.cd.training.model.dto.ModelMngDto.Basic;
import com.kamco.cd.training.model.service.ModelMngService;
import com.kamco.cd.training.model.service.ModelTrainService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequiredArgsConstructor
@Tag(name = "모델관리", description = "모델관리 (학습 모델, 하이퍼파라미터, 메모)")
@RequestMapping("/api/models")
public class ModelMngApiController {
private final ModelMngService modelMngService;
private final ModelTrainService modelTrainService;
@Operation(summary = "학습 모델 목록 조회", description = "학습 모델 목록을 조회합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Page.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping
public ApiResponseDto<Page<Basic>> findByModels(
@Parameter(description = "상태 코드") @RequestParam(required = false) String status,
@Parameter(description = "페이지 번호") @RequestParam(defaultValue = "0") int page,
@Parameter(description = "페이지 크기") @RequestParam(defaultValue = "20") int size) {
ModelMngDto.SearchReq searchReq = new ModelMngDto.SearchReq(status, page, size);
return ApiResponseDto.ok(modelMngService.findByModels(searchReq));
}
@Operation(summary = "학습 모델 상세 조회", description = "학습 모델의 상세 정보를 UUID로 조회합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.Detail.class))),
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/{uuid}")
public ApiResponseDto<ModelMngDto.Detail> getModelDetail(
@Parameter(description = "모델 UUID", example = "b7e99739-6736-45f9-a224-8161ecddf287")
@PathVariable
String uuid) {
return ApiResponseDto.ok(modelMngService.getModelDetailByUuid(uuid));
}
// ==================== 학습 모델학습관리 API (5종) ====================
@Operation(summary = "학습 모델 통합 조회", description = "학습 관리 화면에서 학습 이력 리스트와 현재 상태를 조회합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = List.class))),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/train")
public ApiResponseDto<List<ModelMngDto.TrainListRes>> getTrainModelList() {
return ApiResponseDto.ok(modelTrainService.getTrainModelList());
}
@Operation(summary = "학습 설정 통합 조회", description = "학습 실행 팝업 구성에 필요한 모든 데이터를 한 번에 반환합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.FormConfigRes.class))),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/train/form-config")
public ApiResponseDto<ModelMngDto.FormConfigRes> getFormConfig() {
return ApiResponseDto.ok(modelTrainService.getFormConfig());
}
@Operation(summary = "하이퍼파라미터 등록", description = "Step 1 에서 파라미터를 수정하여 신규 버전으로 저장합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "등록 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/hyper-params")
public ApiResponseDto<String> createHyperParam(
@Valid @RequestBody ModelMngDto.HyperParamCreateReq createReq) {
String newVersion = modelTrainService.createHyperParam(createReq);
return ApiResponseDto.ok(newVersion);
}
@Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/hyper-params/{hyperVer}")
public ApiResponseDto<ModelMngDto.HyperParamInfo> getHyperParam(
@Parameter(description = "하이퍼파라미터 버전", example = "H1") @PathVariable String hyperVer) {
return ApiResponseDto.ok(modelTrainService.getHyperParam(hyperVer));
}
@Operation(summary = "하이퍼파라미터 삭제", description = "특정 버전의 하이퍼파라미터를 삭제합니다 (H1은 삭제 불가)")
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "400", description = "H1은 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@DeleteMapping("/hyper-params/{hyperVer}")
public ApiResponseDto<Void> deleteHyperParam(
@Parameter(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518") @PathVariable
String hyperVer) {
modelTrainService.deleteHyperParam(hyperVer);
return ApiResponseDto.ok(null);
}
@Operation(summary = "학습 시작", description = "모든 설정(Step 1~3)을 마치고 최종적으로 학습 프로세스를 시작합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "학습 시작 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.TrainStartRes.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/train")
public ApiResponseDto<ModelMngDto.TrainStartRes> startTraining(
@Valid @RequestBody ModelMngDto.TrainStartReq trainReq) {
return ApiResponseDto.ok(modelTrainService.startTraining(trainReq));
}
@Operation(summary = "학습 모델 삭제", description = "목록에서 특정 학습 모델을 삭제합니다")
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "400", description = "진행 중인 모델은 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@DeleteMapping("/train/{uuid}")
public ApiResponseDto<Void> deleteTrainModel(
@Parameter(description = "모델 UUID") @PathVariable String uuid) {
modelTrainService.deleteTrainModel(uuid);
return ApiResponseDto.ok(null);
}
// ==================== Resume Training (학습 재시작) ====================
@Operation(summary = "학습 재시작 정보 조회", description = "중단된 학습의 재시작 가능 여부와 Checkpoint 정보를 조회합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.ResumeInfo.class))),
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/train/{uuid}/resume-info")
public ApiResponseDto<ModelMngDto.ResumeInfo> getResumeInfo(
@Parameter(description = "모델 UUID") @PathVariable String uuid) {
return ApiResponseDto.ok(modelTrainService.getResumeInfo(uuid));
}
@Operation(summary = "학습 재시작", description = "중단된 지점(Checkpoint)부터 학습을 재개합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "재시작 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.ResumeResponse.class))),
@ApiResponse(responseCode = "400", description = "재시작 불가능한 상태", content = @Content),
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/train/{uuid}/resume")
public ApiResponseDto<ModelMngDto.ResumeResponse> resumeTraining(
@Parameter(description = "모델 UUID") @PathVariable String uuid,
@Valid @RequestBody ModelMngDto.ResumeRequest resumeReq) {
return ApiResponseDto.ok(modelTrainService.resumeTraining(uuid, resumeReq));
}
// ==================== Best Epoch Setting (Best Epoch 설정) ====================
@Operation(summary = "Best Epoch 설정", description = "사용자가 직접 Best Epoch를 선택하여 설정합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "설정 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelMngDto.BestEpochResponse.class))),
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/train/{uuid}/best-epoch")
public ApiResponseDto<ModelMngDto.BestEpochResponse> setBestEpoch(
@Parameter(description = "모델 UUID") @PathVariable String uuid,
@Valid @RequestBody ModelMngDto.BestEpochRequest bestEpochReq) {
return ApiResponseDto.ok(modelTrainService.setBestEpoch(uuid, bestEpochReq));
}
@Operation(summary = "Epoch별 성능 지표 조회", description = "학습된 모델의 Epoch별 성능 지표를 조회합니다")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = List.class))),
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/train/{uuid}/epoch-metrics")
public ApiResponseDto<List<ModelMngDto.EpochMetric>> getEpochMetrics(
@Parameter(description = "모델 UUID") @PathVariable String uuid) {
return ApiResponseDto.ok(modelTrainService.getEpochMetrics(uuid));
}
}

View File

@@ -0,0 +1,219 @@
package com.kamco.cd.training.model.dto;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import java.time.ZonedDateTime;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
public class HyperParamDto {
@Schema(name = "HyperParam Basic", description = "하이퍼파라미터 기본 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class Basic {
private String hyperVer;
// Important
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// Architecture
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// Optimization
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// Data
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// Evaluation
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// Hardware
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// Augmentation
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// Legacy (deprecated)
private Double dropoutRatio;
private Integer cnnFilterCnt;
// Common
private String memo;
@JsonFormatDttm private ZonedDateTime createdDttm;
public Basic(ModelHyperParamEntity entity) {
this.hyperVer = entity.getHyperVer();
// Important
this.backbone = entity.getBackbone();
this.inputSize = entity.getInputSize();
this.cropSize = entity.getCropSize();
this.epochCnt = entity.getEpochCnt();
this.batchSize = entity.getBatchSize();
// Architecture
this.dropPathRate = entity.getDropPathRate();
this.frozenStages = entity.getFrozenStages();
this.neckPolicy = entity.getNeckPolicy();
this.decoderChannels = entity.getDecoderChannels();
this.classWeight = entity.getClassWeight();
this.numLayers = entity.getNumLayers();
// Optimization
this.learningRate = entity.getLearningRate();
this.weightDecay = entity.getWeightDecay();
this.layerDecayRate = entity.getLayerDecayRate();
this.ddpFindUnusedParams = entity.getDdpFindUnusedParams();
this.ignoreIndex = entity.getIgnoreIndex();
// Data
this.trainNumWorkers = entity.getTrainNumWorkers();
this.valNumWorkers = entity.getValNumWorkers();
this.testNumWorkers = entity.getTestNumWorkers();
this.trainShuffle = entity.getTrainShuffle();
this.trainPersistent = entity.getTrainPersistent();
this.valPersistent = entity.getValPersistent();
// Evaluation
this.metrics = entity.getMetrics();
this.saveBest = entity.getSaveBest();
this.saveBestRule = entity.getSaveBestRule();
this.valInterval = entity.getValInterval();
this.logInterval = entity.getLogInterval();
this.visInterval = entity.getVisInterval();
// Hardware
this.gpuCnt = entity.getGpuCnt();
this.gpuIds = entity.getGpuIds();
this.masterPort = entity.getMasterPort();
// Augmentation
this.rotProb = entity.getRotProb();
this.flipProb = entity.getFlipProb();
this.rotDegree = entity.getRotDegree();
this.exchangeProb = entity.getExchangeProb();
this.brightnessDelta = entity.getBrightnessDelta();
this.contrastRange = entity.getContrastRange();
this.saturationRange = entity.getSaturationRange();
this.hueDelta = entity.getHueDelta();
// Legacy
this.dropoutRatio = entity.getDropoutRatio();
this.cnnFilterCnt = entity.getCnnFilterCnt();
// Common
this.memo = entity.getMemo();
this.createdDttm = entity.getCreatedDttm();
}
}
@Schema(name = "HyperParam AddReq", description = "하이퍼파라미터 등록 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class AddReq {
@NotBlank(message = "버전명은 필수입니다")
private String hyperVer;
// Important
private String backbone;
private String inputSize;
private String cropSize;
private Integer epochCnt;
private Integer batchSize;
// Architecture
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String decoderChannels;
private String classWeight;
private Integer numLayers;
// Optimization
private Double learningRate;
private Double weightDecay;
private Double layerDecayRate;
private Boolean ddpFindUnusedParams;
private Integer ignoreIndex;
// Data
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// Evaluation
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// Hardware
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
// Augmentation
private Double rotProb;
private Double flipProb;
private String rotDegree;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// Legacy (deprecated)
private Double dropoutRatio;
private Integer cnnFilterCnt;
// Common
private String memo;
}
}

View File

@@ -0,0 +1,595 @@
package com.kamco.cd.training.model.dto;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
public class ModelMngDto {
@Schema(name = "모델관리 목록 조회", description = "모델관리 목록 조회")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class Basic {
private Long id;
private String modelNm;
@JsonFormatDttm private ZonedDateTime startDttm;
@JsonFormatDttm private ZonedDateTime trainingEndDttm;
@JsonFormatDttm private ZonedDateTime testEndDttm;
private String durationDttm;
private String processStage;
private String statusCd;
private String status;
}
@Schema(name = "searchReq", description = "모델 관리 목록조회 파라미터")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class SearchReq {
private String status;
// 페이징 파라미터
private int page = 0;
private int size = 20;
public Pageable toPageable() {
return PageRequest.of(page, size);
}
}
@Schema(name = "Detail", description = "모델 상세 정보")
@Getter
@Builder
public static class Detail {
private String uuid;
private String modelVer;
private String hyperVer;
private String epochVer;
private String processStep;
private String statusCd;
private String statusText;
@JsonFormatDttm private ZonedDateTime trainStartDttm;
private Integer epochCnt;
private String datasetRatio;
private Integer bestEpoch;
private Integer confirmedBestEpoch;
@JsonFormatDttm private ZonedDateTime step1EndDttm;
private String step1Duration;
@JsonFormatDttm private ZonedDateTime step2EndDttm;
private String step2Duration;
private Integer progressRate;
@JsonFormatDttm private ZonedDateTime createdDttm;
@JsonFormatDttm private ZonedDateTime updatedDttm;
private String modelPath;
private String errorMsg;
}
@Schema(name = "TrainListRes", description = "학습 모델 목록 응답")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class TrainListRes {
private String uuid;
private String modelVer;
private String status;
private String processStep;
@JsonFormatDttm private ZonedDateTime trainStartDttm;
private Integer progressRate;
private Integer epochCnt;
@JsonFormatDttm private ZonedDateTime step1EndDttm;
private String step1Duration;
@JsonFormatDttm private ZonedDateTime step2EndDttm;
private String step2Duration;
@JsonFormatDttm private ZonedDateTime createdDttm;
private String errorMsg;
private Boolean canResume;
private Integer lastCheckpointEpoch;
}
@Schema(name = "FormConfigRes", description = "학습 설정 통합 조회 응답")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class FormConfigRes {
private Boolean isTrainAvailable;
private List<HyperParamInfo> hyperParams;
private List<DatasetInfo> datasets;
private String runningModelUuid;
}
@Schema(name = "HyperParamInfo", description = "하이퍼파라미터 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class HyperParamInfo {
@Schema(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518")
private String hyperVer;
// Important
@Schema(description = "백본", example = "large")
private String backbone;
@Schema(description = "입력 사이즈", example = "256,256")
private String inputSize;
@Schema(description = "크롭 사이즈", example = "256,256")
private String cropSize;
@Schema(description = "에폭 수", example = "200")
private Integer epochCnt;
@Schema(description = "배치 사이즈", example = "16")
private Integer batchSize;
// Architecture
@Schema(description = "Drop Path Rate", example = "0.3")
private Double dropPathRate;
@Schema(description = "Frozen Stages", example = "-1")
private Integer frozenStages;
@Schema(description = "Neck Policy", example = "abs_diff")
private String neckPolicy;
@Schema(description = "Decoder Channels", example = "512,256,128,64")
private String decoderChannels;
@Schema(description = "Class Weight", example = "1,1")
private String classWeight;
@Schema(description = "레이어 수", example = "24")
private Integer numLayers;
// Optimization
@Schema(description = "Learning Rate", example = "0.00006")
private Double learningRate;
@Schema(description = "Weight Decay", example = "0.05")
private Double weightDecay;
@Schema(description = "Layer Decay Rate", example = "0.9")
private Double layerDecayRate;
@Schema(description = "DDP Unused Params 찾기", example = "true")
private Boolean ddpFindUnusedParams;
@Schema(description = "Ignore Index", example = "255")
private Integer ignoreIndex;
// Data
@Schema(description = "Train Workers", example = "16")
private Integer trainNumWorkers;
@Schema(description = "Val Workers", example = "8")
private Integer valNumWorkers;
@Schema(description = "Test Workers", example = "8")
private Integer testNumWorkers;
@Schema(description = "Train Shuffle", example = "true")
private Boolean trainShuffle;
@Schema(description = "Train Persistent", example = "true")
private Boolean trainPersistent;
@Schema(description = "Val Persistent", example = "true")
private Boolean valPersistent;
// Evaluation
@Schema(description = "Metrics", example = "mFscore,mIoU")
private String metrics;
@Schema(description = "Save Best", example = "changed_fscore")
private String saveBest;
@Schema(description = "Save Best Rule", example = "greater")
private String saveBestRule;
@Schema(description = "Val Interval", example = "10")
private Integer valInterval;
@Schema(description = "Log Interval", example = "400")
private Integer logInterval;
@Schema(description = "Vis Interval", example = "1")
private Integer visInterval;
// Hardware
@Schema(description = "GPU 수", example = "4")
private Integer gpuCnt;
@Schema(description = "GPU IDs", example = "0,1,2,3")
private String gpuIds;
@Schema(description = "Master Port", example = "1122")
private Integer masterPort;
// Augmentation
@Schema(description = "Rotation 확률", example = "0.5")
private Double rotProb;
@Schema(description = "Flip 확률", example = "0.5")
private Double flipProb;
@Schema(description = "Rotation 각도", example = "-20,20")
private String rotDegree;
@Schema(description = "Exchange 확률", example = "0.5")
private Double exchangeProb;
@Schema(description = "Brightness Delta", example = "10")
private Integer brightnessDelta;
@Schema(description = "Contrast Range", example = "0.8,1.2")
private String contrastRange;
@Schema(description = "Saturation Range", example = "0.8,1.2")
private String saturationRange;
@Schema(description = "Hue Delta", example = "10")
private Integer hueDelta;
// Legacy
private Double dropoutRatio;
private Integer cnnFilterCnt;
// Common
@Schema(description = "메모", example = "안녕하세요 캠코담당자 입니다. 하이퍼파라미터 신규등록합니다")
private String memo;
@JsonFormatDttm private ZonedDateTime createdDttm;
}
@Schema(name = "DatasetInfo", description = "데이터셋 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class DatasetInfo {
private Long id;
private String title;
private String groupTitle;
private Long totalItems;
private String totalSize;
private Map<String, Integer> classCounts;
private String memo;
@JsonFormatDttm private ZonedDateTime createdDttm;
}
@Schema(name = "HyperParamCreateReq", description = "하이퍼파라미터 등록 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class HyperParamCreateReq {
// baseHyperVer는 필수 아님 (신규 생성 시 H1으로 자동 설정)
@Schema(description = "기준이 되는 하이퍼파라미터 버전", example = "H3")
private String baseHyperVer;
@NotBlank(message = "신규 버전명은 필수입니다")
@Schema(description = "새로 생성할 하이퍼파라미터 버전명", example = "V3.99.251221.120518")
private String newHyperVer;
// Important - 필수 필드
@NotBlank(message = "Backbone은 필수입니다")
@Schema(example = "large")
private String backbone;
@NotBlank(message = "Input Size는 필수입니다")
@Schema(example = "256,256")
private String inputSize;
@NotBlank(message = "Crop Size는 필수입니다")
@Schema(example = "256,256")
private String cropSize;
@NotNull(message = "Epoch Count는 필수입니다")
@Schema(example = "200")
private Integer epochCnt;
@NotNull(message = "Batch Size는 필수입니다")
@Schema(example = "16")
private Integer batchSize;
// Architecture - 필수 필드
@NotNull(message = "Drop Path Rate는 필수입니다")
@Schema(example = "0.3")
private Double dropPathRate;
@NotNull(message = "Frozen Stages는 필수입니다")
@Schema(example = "-1")
private Integer frozenStages;
@NotBlank(message = "Neck Policy는 필수입니다")
@Schema(example = "abs_diff")
private String neckPolicy;
@NotBlank(message = "Decoder Channels는 필수입니다")
@Schema(example = "512,256,128,64")
private String decoderChannels;
@NotBlank(message = "Class Weight는 필수입니다")
@Schema(example = "1,1")
private String classWeight;
// numLayers는 필수 아님
@Schema(example = "24")
private Integer numLayers;
// Optimization - 필수 필드
@NotNull(message = "Learning Rate는 필수입니다")
@Schema(example = "0.00006")
private Double learningRate;
@NotNull(message = "Weight Decay는 필수입니다")
@Schema(example = "0.05")
private Double weightDecay;
@NotNull(message = "Layer Decay Rate는 필수입니다")
@Schema(example = "0.9")
private Double layerDecayRate;
@NotNull(message = "DDP Find Unused Params는 필수입니다")
@Schema(example = "true")
private Boolean ddpFindUnusedParams;
@NotNull(message = "Ignore Index는 필수입니다")
@Schema(example = "255")
private Integer ignoreIndex;
// Data - 필수 필드
@NotNull(message = "Train Num Workers는 필수입니다")
@Schema(example = "16")
private Integer trainNumWorkers;
@NotNull(message = "Val Num Workers는 필수입니다")
@Schema(example = "8")
private Integer valNumWorkers;
@NotNull(message = "Test Num Workers는 필수입니다")
@Schema(example = "8")
private Integer testNumWorkers;
@NotNull(message = "Train Shuffle는 필수입니다")
@Schema(example = "true")
private Boolean trainShuffle;
@NotNull(message = "Train Persistent는 필수입니다")
@Schema(example = "true")
private Boolean trainPersistent;
@NotNull(message = "Val Persistent는 필수입니다")
@Schema(example = "true")
private Boolean valPersistent;
// Evaluation - 필수 필드
@NotBlank(message = "Metrics는 필수입니다")
@Schema(example = "mFscore,mIoU")
private String metrics;
@NotBlank(message = "Save Best는 필수입니다")
@Schema(example = "changed_fscore")
private String saveBest;
@NotBlank(message = "Save Best Rule은 필수입니다")
@Schema(example = "greater")
private String saveBestRule;
@NotNull(message = "Val Interval은 필수입니다")
@Schema(example = "10")
private Integer valInterval;
@NotNull(message = "Log Interval은 필수입니다")
@Schema(example = "400")
private Integer logInterval;
@NotNull(message = "Vis Interval은 필수입니다")
@Schema(example = "1")
private Integer visInterval;
// Hardware - 필수 아님 (예외 항목)
@Schema(example = "4")
private Integer gpuCnt;
@Schema(example = "0,1,2,3")
private String gpuIds;
@Schema(example = "1122")
private Integer masterPort;
// Augmentation - 필수 필드
@NotNull(message = "Rotation Probability는 필수입니다")
@Schema(example = "0.5")
private Double rotProb;
@NotNull(message = "Flip Probability는 필수입니다")
@Schema(example = "0.5")
private Double flipProb;
@NotBlank(message = "Rotation Degree는 필수입니다")
@Schema(example = "-20,20")
private String rotDegree;
@NotNull(message = "Exchange Probability는 필수입니다")
@Schema(example = "0.5")
private Double exchangeProb;
@NotNull(message = "Brightness Delta는 필수입니다")
@Schema(example = "10")
private Integer brightnessDelta;
@NotBlank(message = "Contrast Range는 필수입니다")
@Schema(example = "0.8,1.2")
private String contrastRange;
@NotBlank(message = "Saturation Range는 필수입니다")
@Schema(example = "0.8,1.2")
private String saturationRange;
@NotNull(message = "Hue Delta는 필수입니다")
@Schema(example = "10")
private Integer hueDelta;
// Legacy - 필수 아님 (예외 항목)
private Double dropoutRatio;
private Integer cnnFilterCnt;
// Common - 필수 아님 (예외 항목)
@Schema(example = "안녕하세요 캠코담당자 입니다. 하이퍼파라미터 신규등록합니다")
private String memo;
}
@Schema(name = "TrainStartReq", description = "학습 시작 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class TrainStartReq {
@NotBlank(message = "하이퍼파라미터 버전은 필수입니다")
@Schema(example = "V3.99.251221.120518")
private String hyperVer;
@NotEmpty(message = "데이터셋은 최소 1개 이상 선택해야 합니다")
private List<Long> datasetIds;
@NotNull(message = "에폭 수는 필수입니다")
@jakarta.validation.constraints.Min(value = 1, message = "에폭 수는 최소 1 이상이어야 합니다")
@jakarta.validation.constraints.Max(value = 200, message = "에폭 수는 최대 200까지 설정 가능합니다")
@Schema(example = "200")
private Integer epoch;
@Schema(example = "7:2:1", description = "데이터 분할 비율 (Training:Validation:Test)")
private String datasetRatio;
}
@Schema(name = "TrainStartRes", description = "학습 시작 응답")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class TrainStartRes {
private String uuid;
private String status;
}
@Schema(name = "ResumeInfo", description = "학습 재시작 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class ResumeInfo {
private Boolean canResume;
private Integer lastEpoch;
private Integer totalEpoch;
private String checkpointPath;
@JsonFormatDttm private ZonedDateTime failedAt;
}
@Schema(name = "ResumeRequest", description = "학습 재시작 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ResumeRequest {
@NotNull(message = "재시작 Epoch는 필수입니다")
private Integer resumeFromEpoch;
private Integer newTotalEpoch;
}
@Schema(name = "ResumeResponse", description = "학습 재시작 응답")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class ResumeResponse {
private String uuid;
private String status;
private Integer resumedFromEpoch;
}
@Schema(name = "BestEpochRequest", description = "Best Epoch 설정 요청")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class BestEpochRequest {
@NotNull(message = "Best Epoch는 필수입니다")
private Integer bestEpoch;
private String reason;
}
@Schema(name = "BestEpochResponse", description = "Best Epoch 설정 응답")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class BestEpochResponse {
private String uuid;
private Integer bestEpoch;
private Integer confirmedBestEpoch;
private Integer previousBestEpoch;
}
@Schema(name = "EpochMetric", description = "Epoch별 성능 지표")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class EpochMetric {
private Integer epoch;
private Double mIoU;
private Double mFscore;
private Double loss;
private Boolean isBest;
}
}

View File

@@ -0,0 +1,61 @@
package com.kamco.cd.training.model.dto;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.ZonedDateTime;
import lombok.Getter;
public class ModelVerDto {
@Schema(name = "modelVer Basic", description = "모델버전 엔티티 기본 정보")
@Getter
public static class Basic {
private final Long id;
private final Long modelUid;
private final String modelCate;
private final String modelVer;
private final String usedState;
private final String modelState;
private final Double qualityProb;
private final String deployState;
private final String modelPath;
@JsonFormatDttm private final ZonedDateTime createdDttm;
private final Long createdUid;
@JsonFormatDttm private final ZonedDateTime updatedDttm;
private final Long updatedUid;
public Basic(
Long id,
Long modelUid,
String modelCate,
String modelVer,
String usedState,
String modelState,
Double qualityProb,
String deployState,
String modelPath,
ZonedDateTime createdDttm,
Long createdUid,
ZonedDateTime updatedDttm,
Long updatedUid) {
this.id = id;
this.modelUid = modelUid;
this.modelCate = modelCate;
this.modelVer = modelVer;
this.usedState = usedState;
this.modelState = modelState;
this.qualityProb = qualityProb;
this.deployState = deployState;
this.modelPath = modelPath;
this.createdDttm = createdDttm;
this.createdUid = createdUid;
this.updatedDttm = updatedDttm;
this.updatedUid = updatedUid;
}
}
}

View File

@@ -0,0 +1,50 @@
package com.kamco.cd.training.model.service;
import com.kamco.cd.training.model.dto.ModelMngDto;
import com.kamco.cd.training.model.dto.ModelMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.ModelMngCoreService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.Page;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
@Slf4j
public class ModelMngService {
private final ModelMngCoreService modelMngCoreService;
/**
* 모델 목록 조회
*
* @param searchReq 검색 조건
* @return 페이징 처리된 모델 목록
*/
public Page<Basic> findByModels(SearchReq searchReq) {
return modelMngCoreService.findByModels(searchReq);
}
/**
* 모델 상세 조회
*
* @param modelUid 모델 UID
* @return 모델 상세 정보
*/
public ModelMngDto.Detail getModelDetail(Long modelUid) {
return modelMngCoreService.getModelDetail(modelUid);
}
/**
* 모델 상세 조회 (UUID 기반)
*
* @param uuid 모델 UUID
* @return 모델 상세 정보
*/
public ModelMngDto.Detail getModelDetailByUuid(String uuid) {
return modelMngCoreService.getModelDetailByUuid(uuid);
}
}

View File

@@ -0,0 +1,393 @@
package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.NotFoundException;
import com.kamco.cd.training.model.dto.ModelMngDto;
import com.kamco.cd.training.postgres.core.DatasetCoreService;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelMngCoreService;
import com.kamco.cd.training.postgres.core.SystemMetricsCoreService;
import com.kamco.cd.training.postgres.entity.ModelTrainMasterEntity;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
@Slf4j
public class ModelTrainService {
private final ModelMngCoreService modelMngCoreService;
private final HyperParamCoreService hyperParamCoreService;
private final DatasetCoreService datasetCoreService;
private final SystemMetricsCoreService systemMetricsCoreService;
/**
* 학습 모델 목록 조회
*
* @return 학습 모델 목록
*/
public List<ModelMngDto.TrainListRes> getTrainModelList() {
return modelMngCoreService.findAllTrainModels();
}
/**
* 학습 설정 통합 조회
*
* @return 학습 설정 폼 데이터
*/
public ModelMngDto.FormConfigRes getFormConfig() {
// 1. 현재 실행 중인 모델 확인
String runningModelUuid = modelMngCoreService.findRunningModelUuid();
boolean isTrainAvailable = (runningModelUuid == null);
// 2. 저장공간 체크 (10GB 미만 시 학습 불가)
if (isTrainAvailable) {
isTrainAvailable = systemMetricsCoreService.isStorageAvailableForTraining();
long availableMB = systemMetricsCoreService.getAvailableStorageMB();
log.info("저장공간 체크 완료: {}MB 사용 가능, 학습 가능 여부: {}", availableMB, isTrainAvailable);
}
// 3. 하이퍼파라미터 목록
List<ModelMngDto.HyperParamInfo> hyperParams = hyperParamCoreService.findAllActiveHyperParams();
// 4. 데이터셋 목록
List<ModelMngDto.DatasetInfo> datasets = datasetCoreService.findAllActiveDatasetsForTraining();
return ModelMngDto.FormConfigRes.builder()
.isTrainAvailable(isTrainAvailable)
.hyperParams(hyperParams)
.datasets(datasets)
.runningModelUuid(runningModelUuid)
.build();
}
/**
* 하이퍼파라미터 등록
*
* @param createReq 등록 요청
* @return 생성된 버전명
*/
@Transactional
public String createHyperParam(ModelMngDto.HyperParamCreateReq createReq) {
// 신규 버전 추가 시 baseHyperVer가 없으면 H1으로 설정
if (createReq.getBaseHyperVer() == null || createReq.getBaseHyperVer().isEmpty()) {
String firstVersion = hyperParamCoreService.getFirstHyperParamVersion();
createReq.setBaseHyperVer(firstVersion);
log.info("baseHyperVer가 없어 첫 번째 버전으로 설정: {}", firstVersion);
}
String newVersion = hyperParamCoreService.createHyperParam(createReq);
log.info("하이퍼파라미터 등록 완료: {}", newVersion);
return newVersion;
}
/**
* 하이퍼파라미터 단건 조회
*
* @param hyperVer 하이퍼파라미터 버전
* @return 하이퍼파라미터 정보
*/
public ModelMngDto.HyperParamInfo getHyperParam(String hyperVer) {
return hyperParamCoreService.findByHyperVer(hyperVer);
}
/**
* 하이퍼파라미터 삭제
*
* @param hyperVer 하이퍼파라미터 버전
*/
@Transactional
public void deleteHyperParam(String hyperVer) {
hyperParamCoreService.deleteHyperParam(hyperVer);
log.info("하이퍼파라미터 삭제 완료: {}", hyperVer);
}
/**
* 학습 시작
*
* @param trainReq 학습 시작 요청
* @return 학습 시작 응답
*/
@Transactional
public ModelMngDto.TrainStartRes startTraining(ModelMngDto.TrainStartReq trainReq) {
// 1. 동시 실행 방지 검증
String runningModelUuid = modelMngCoreService.findRunningModelUuid();
if (runningModelUuid != null) {
throw new BadRequestException(
"이미 실행 중인 학습이 있습니다. 학습은 한 번에 한 개만 실행할 수 있습니다. (실행 중인 모델: " + runningModelUuid + ")");
}
// 2. 저장공간 체크 (10GB 미만 시 학습 불가)
if (!systemMetricsCoreService.isStorageAvailableForTraining()) {
long availableMB = systemMetricsCoreService.getAvailableStorageMB();
long requiredMB = 10 * 1024; // 10GB
throw new BadRequestException(
String.format(
"저장공간이 부족하여 학습을 시작할 수 없습니다. (필요: %dMB, 사용 가능: %dMB)", requiredMB, availableMB));
}
// 3. 데이터셋 상태 검증 (COMPLETED 상태만 학습 가능)
validateDatasetStatus(trainReq.getDatasetIds());
// 4. 데이터 분할 비율 검증 (예: "7:2:1" 형식)
if (trainReq.getDatasetRatio() != null && !trainReq.getDatasetRatio().isEmpty()) {
validateDatasetRatio(trainReq.getDatasetRatio());
}
// 5. 학습 마스터 생성
ModelTrainMasterEntity entity = modelMngCoreService.createTrainMaster(trainReq);
// 5. 데이터셋 매핑 생성
modelMngCoreService.createDatasetMappings(entity.getId(), trainReq.getDatasetIds());
// 6. 실제 UUID 사용
String uuid = entity.getUuid().toString();
log.info(
"학습 시작: uuid={}, hyperVer={}, epoch={}, datasets={}",
uuid,
trainReq.getHyperVer(),
trainReq.getEpoch(),
trainReq.getDatasetIds());
// TODO: 비동기 GPU 학습 프로세스 트리거 로직 추가
return ModelMngDto.TrainStartRes.builder().uuid(uuid).status(entity.getStatusCd()).build();
}
/**
* 데이터셋 상태 검증
*
* @param datasetIds 데이터셋 ID 목록
*/
private void validateDatasetStatus(List<Long> datasetIds) {
for (Long datasetId : datasetIds) {
try {
var dataset = datasetCoreService.getOneById(datasetId);
// COMPLETED 상태가 아닌 데이터셋이 포함되어 있으면 예외 발생
if (dataset.getStatus() == null || !"COMPLETED".equals(dataset.getStatus())) {
throw new BadRequestException(
String.format(
"학습에 사용할 수 없는 데이터셋입니다. (ID: %d, 상태: %s). COMPLETED 상태의 데이터셋만 선택 가능합니다.",
datasetId, dataset.getStatus() != null ? dataset.getStatus() : "NULL"));
}
log.debug("데이터셋 상태 검증 통과: ID={}, Status={}", datasetId, dataset.getStatus());
} catch (NotFoundException e) {
throw new BadRequestException("존재하지 않는 데이터셋입니다. ID: " + datasetId);
}
}
log.info("모든 데이터셋 상태 검증 완료: {} 개", datasetIds.size());
}
/**
* 데이터 분할 비율 검증
*
* @param datasetRatio 데이터셋 비율 (예: "7:2:1")
*/
private void validateDatasetRatio(String datasetRatio) {
try {
String[] parts = datasetRatio.split(":");
if (parts.length != 3) {
throw new BadRequestException("데이터 분할 비율은 'Training:Validation:Test' 형식이어야 합니다 (예: 7:2:1)");
}
int train = Integer.parseInt(parts[0].trim());
int validation = Integer.parseInt(parts[1].trim());
int test = Integer.parseInt(parts[2].trim());
int sum = train + validation + test;
if (sum != 10) {
throw new BadRequestException(
String.format("데이터 분할 비율의 합계는 10이어야 합니다. (현재 합계: %d, 입력값: %s)", sum, datasetRatio));
}
if (train <= 0 || validation < 0 || test < 0) {
throw new BadRequestException("데이터 분할 비율은 모두 0 이상이어야 합니다 (Training은 1 이상)");
}
log.info(
"데이터 분할 비율 검증 완료: Training={}0%, Validation={}0%, Test={}0%", train, validation, test);
} catch (NumberFormatException e) {
throw new BadRequestException("데이터 분할 비율은 숫자로만 구성되어야 합니다: " + datasetRatio);
}
}
/**
* 학습 모델 삭제
*
* @param uuid 모델 UUID
*/
@Transactional
public void deleteTrainModel(String uuid) {
modelMngCoreService.deleteByUuid(uuid);
log.info("학습 모델 삭제 완료: uuid={}", uuid);
}
// ==================== Resume Training (학습 재시작) ====================
/**
* 학습 재시작 정보 조회
*
* @param uuid 모델 UUID
* @return 재시작 정보
*/
public ModelMngDto.ResumeInfo getResumeInfo(String uuid) {
ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
return ModelMngDto.ResumeInfo.builder()
.canResume(entity.getCanResume() != null && entity.getCanResume())
.lastEpoch(entity.getLastCheckpointEpoch())
.totalEpoch(entity.getEpochCnt())
.checkpointPath(entity.getCheckpointPath())
.failedAt(
entity.getStopDttm() != null
? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault())
: null)
.build();
}
/**
* 학습 재시작
*
* @param uuid 모델 UUID
* @param resumeReq 재시작 요청
* @return 재시작 응답
*/
@Transactional
public ModelMngDto.ResumeResponse resumeTraining(
String uuid, ModelMngDto.ResumeRequest resumeReq) {
ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
// 재시작 가능 여부 검증
if (entity.getCanResume() == null || !entity.getCanResume()) {
throw new IllegalStateException("학습 재시작이 불가능한 모델입니다: " + uuid);
}
if (entity.getLastCheckpointEpoch() == null) {
throw new IllegalStateException("Checkpoint가 존재하지 않습니다: " + uuid);
}
// 상태 업데이트
entity.setStatusCd("RUNNING");
entity.setProgressRate(0);
// 총 Epoch 수 변경 (선택사항)
if (resumeReq.getNewTotalEpoch() != null) {
entity.setEpochCnt(resumeReq.getNewTotalEpoch());
}
log.info(
"학습 재시작: uuid={}, resumeFromEpoch={}, totalEpoch={}",
uuid,
resumeReq.getResumeFromEpoch(),
entity.getEpochCnt());
// TODO: 비동기 GPU 학습 재시작 프로세스 트리거 로직 추가
// - Checkpoint 파일 로드
// - 지정된 Epoch부터 학습 재개
return ModelMngDto.ResumeResponse.builder()
.uuid(uuid)
.status(entity.getStatusCd())
.resumedFromEpoch(resumeReq.getResumeFromEpoch())
.build();
}
// ==================== Best Epoch Setting (Best Epoch 설정) ====================
/**
* Best Epoch 설정
*
* @param uuid 모델 UUID
* @param bestEpochReq Best Epoch 요청
* @return Best Epoch 응답
*/
@Transactional
public ModelMngDto.BestEpochResponse setBestEpoch(
String uuid, ModelMngDto.BestEpochRequest bestEpochReq) {
ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
// 1차 학습 완료 상태 검증
if (!"STEP1_COMPLETED".equals(entity.getStatusCd())
&& !"STEP1".equals(entity.getProcessStep())) {
log.warn(
"Best Epoch 설정 시도: 현재 상태={}, processStep={}",
entity.getStatusCd(),
entity.getProcessStep());
}
Integer previousBestEpoch = entity.getConfirmedBestEpoch();
// 사용자가 확정한 Best Epoch 설정
entity.setConfirmedBestEpoch(bestEpochReq.getBestEpoch());
// 2차 학습(Test) 단계로 상태 전이
entity.setProcessStep("STEP2");
entity.setStatusCd("STEP2_RUNNING");
entity.setProgressRate(0);
entity.setUpdatedDttm(java.time.ZonedDateTime.now());
log.info(
"Best Epoch 설정 및 2차 학습 시작: uuid={}, newBestEpoch={}, previousBestEpoch={}, reason={}, newStatus={}",
uuid,
bestEpochReq.getBestEpoch(),
previousBestEpoch,
bestEpochReq.getReason(),
entity.getStatusCd());
// TODO: 비동기 GPU 2차 학습(Test) 프로세스 트리거 로직 추가
// - Best Epoch 모델 로드
// - Test 데이터셋으로 성능 평가 실행
// - 완료 시 STEP2_COMPLETED 상태로 전환
return ModelMngDto.BestEpochResponse.builder()
.uuid(uuid)
.bestEpoch(entity.getBestEpoch()) // 자동 선택된 값
.confirmedBestEpoch(entity.getConfirmedBestEpoch()) // 사용자 확정 값
.previousBestEpoch(previousBestEpoch)
.build();
}
/**
* Epoch별 성능 지표 조회
*
* @param uuid 모델 UUID
* @return Epoch별 성능 지표 목록
*/
public List<ModelMngDto.EpochMetric> getEpochMetrics(String uuid) {
ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
// TODO: 실제 학습 로그 파일이나 DB에서 Epoch별 성능 지표 조회
// 현재는 샘플 데이터 반환
List<ModelMngDto.EpochMetric> metrics = new java.util.ArrayList<>();
if (entity.getEpochCnt() != null && entity.getBestEpoch() != null) {
// 샘플 데이터 생성 (실제로는 학습 로그 파일 파싱 또는 별도 테이블 조회)
for (int i = 1; i <= Math.min(entity.getEpochCnt(), 10); i++) {
int epoch = entity.getBestEpoch() - 5 + i;
if (epoch <= 0 || epoch > entity.getEpochCnt()) {
continue;
}
metrics.add(
ModelMngDto.EpochMetric.builder()
.epoch(epoch)
.mIoU(0.80 + (Math.random() * 0.15)) // 샘플 데이터
.mFscore(0.85 + (Math.random() * 0.10)) // 샘플 데이터
.loss(0.3 - (Math.random() * 0.15)) // 샘플 데이터
.isBest(entity.getBestEpoch() != null && epoch == entity.getBestEpoch())
.build());
}
}
log.info("Epoch별 성능 지표 조회: uuid={}, metricsCount={}", uuid, metrics.size());
return metrics;
}
}