// Listing metadata: 287 lines, 14 KiB, Java
package com.kamco.cd.training.model;
|
|
|
|
import com.kamco.cd.training.config.api.ApiResponseDto;
|
|
import com.kamco.cd.training.model.dto.ModelMngDto;
|
|
import com.kamco.cd.training.model.dto.ModelMngDto.Basic;
|
|
import com.kamco.cd.training.model.service.ModelMngService;
|
|
import com.kamco.cd.training.model.service.ModelTrainService;
|
|
import io.swagger.v3.oas.annotations.Operation;
|
|
import io.swagger.v3.oas.annotations.Parameter;
|
|
import io.swagger.v3.oas.annotations.media.Content;
|
|
import io.swagger.v3.oas.annotations.media.Schema;
|
|
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
|
import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
|
import jakarta.validation.Valid;
|
|
import java.util.List;
|
|
import lombok.RequiredArgsConstructor;
|
|
import org.springframework.data.domain.Page;
|
|
import org.springframework.web.bind.annotation.DeleteMapping;
|
|
import org.springframework.web.bind.annotation.GetMapping;
|
|
import org.springframework.web.bind.annotation.PathVariable;
|
|
import org.springframework.web.bind.annotation.PostMapping;
|
|
import org.springframework.web.bind.annotation.RequestBody;
|
|
import org.springframework.web.bind.annotation.RequestMapping;
|
|
import org.springframework.web.bind.annotation.RequestParam;
|
|
import org.springframework.web.bind.annotation.RestController;
|
|
|
|
@RestController
|
|
@RequiredArgsConstructor
|
|
@Tag(name = "모델관리", description = "어드민 홈 > 모델학습관리 > 모델관리 > 목록")
|
|
@RequestMapping("/api/models")
|
|
public class ModelMngApiController {
|
|
private final ModelMngService modelMngService;
|
|
private final ModelTrainService modelTrainService;
|
|
|
|
@Operation(summary = "학습 모델 목록 조회", description = "학습 모델 목록을 조회합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "검색 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = Page.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@GetMapping
|
|
public ApiResponseDto<Page<Basic>> findByModels(
|
|
@Parameter(description = "상태 코드") @RequestParam(required = false) String status,
|
|
@Parameter(description = "페이지 번호") @RequestParam(defaultValue = "0") int page,
|
|
@Parameter(description = "페이지 크기") @RequestParam(defaultValue = "20") int size) {
|
|
ModelMngDto.SearchReq searchReq = new ModelMngDto.SearchReq(status, page, size);
|
|
return ApiResponseDto.ok(modelMngService.findByModels(searchReq));
|
|
}
|
|
|
|
@Operation(summary = "학습 모델 상세 조회", description = "학습 모델의 상세 정보를 UUID로 조회합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "조회 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = ModelMngDto.Detail.class))),
|
|
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@GetMapping("/{uuid}")
|
|
public ApiResponseDto<ModelMngDto.Detail> getModelDetail(
|
|
@Parameter(description = "모델 UUID", example = "b7e99739-6736-45f9-a224-8161ecddf287")
|
|
@PathVariable
|
|
String uuid) {
|
|
return ApiResponseDto.ok(modelMngService.getModelDetailByUuid(uuid));
|
|
}
|
|
|
|
// ==================== 학습 모델학습관리 API (5종) ====================
|
|
|
|
@Operation(summary = "학습 모델 통합 조회", description = "학습 관리 화면에서 학습 이력 리스트와 현재 상태를 조회합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "조회 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = List.class))),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@GetMapping("/train")
|
|
public ApiResponseDto<List<ModelMngDto.TrainListRes>> getTrainModelList() {
|
|
return ApiResponseDto.ok(modelTrainService.getTrainModelList());
|
|
}
|
|
|
|
@Operation(summary = "학습 설정 통합 조회", description = "학습 실행 팝업 구성에 필요한 모든 데이터를 한 번에 반환합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "조회 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = ModelMngDto.FormConfigRes.class))),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@GetMapping("/train/form-config")
|
|
public ApiResponseDto<ModelMngDto.FormConfigRes> getFormConfig() {
|
|
return ApiResponseDto.ok(modelTrainService.getFormConfig());
|
|
}
|
|
|
|
@Operation(summary = "하이퍼파라미터 등록", description = "Step 1 에서 파라미터를 수정하여 신규 버전으로 저장합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "등록 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/hyper-params")
|
|
public ApiResponseDto<String> createHyperParam(
|
|
@Valid @RequestBody ModelMngDto.HyperParamCreateReq createReq) {
|
|
String newVersion = modelTrainService.createHyperParam(createReq);
|
|
return ApiResponseDto.ok(newVersion);
|
|
}
|
|
|
|
@Operation(summary = "하이퍼파라미터 단건 조회", description = "특정 버전의 하이퍼파라미터 상세 정보를 조회합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "조회 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = ModelMngDto.HyperParamInfo.class))),
|
|
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@GetMapping("/hyper-params/{hyperVer}")
|
|
public ApiResponseDto<ModelMngDto.HyperParamInfo> getHyperParam(
|
|
@Parameter(description = "하이퍼파라미터 버전", example = "H1") @PathVariable String hyperVer) {
|
|
return ApiResponseDto.ok(modelTrainService.getHyperParam(hyperVer));
|
|
}
|
|
|
|
@Operation(summary = "하이퍼파라미터 삭제", description = "특정 버전의 하이퍼파라미터를 삭제합니다 (H1은 삭제 불가)")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
|
|
@ApiResponse(responseCode = "400", description = "H1은 삭제 불가", content = @Content),
|
|
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@DeleteMapping("/hyper-params/{hyperVer}")
|
|
public ApiResponseDto<Void> deleteHyperParam(
|
|
@Parameter(description = "하이퍼파라미터 버전", example = "V3.99.251221.120518") @PathVariable
|
|
String hyperVer) {
|
|
modelTrainService.deleteHyperParam(hyperVer);
|
|
return ApiResponseDto.ok(null);
|
|
}
|
|
|
|
@Operation(summary = "학습 시작", description = "모든 설정(Step 1~3)을 마치고 최종적으로 학습 프로세스를 시작합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "학습 시작 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = ModelMngDto.TrainStartRes.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/train")
|
|
public ApiResponseDto<ModelMngDto.TrainStartRes> startTraining(
|
|
@Valid @RequestBody ModelMngDto.TrainStartReq trainReq) {
|
|
return ApiResponseDto.ok(modelTrainService.startTraining(trainReq));
|
|
}
|
|
|
|
@Operation(summary = "학습 모델 삭제", description = "목록에서 특정 학습 모델을 삭제합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
|
|
@ApiResponse(responseCode = "400", description = "진행 중인 모델은 삭제 불가", content = @Content),
|
|
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@DeleteMapping("/train/{uuid}")
|
|
public ApiResponseDto<Void> deleteTrainModel(
|
|
@Parameter(description = "모델 UUID") @PathVariable String uuid) {
|
|
modelTrainService.deleteTrainModel(uuid);
|
|
return ApiResponseDto.ok(null);
|
|
}
|
|
|
|
// ==================== Resume Training (학습 재시작) ====================
|
|
|
|
@Operation(summary = "학습 재시작 정보 조회", description = "중단된 학습의 재시작 가능 여부와 Checkpoint 정보를 조회합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "조회 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = ModelMngDto.ResumeInfo.class))),
|
|
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@GetMapping("/train/{uuid}/resume-info")
|
|
public ApiResponseDto<ModelMngDto.ResumeInfo> getResumeInfo(
|
|
@Parameter(description = "모델 UUID") @PathVariable String uuid) {
|
|
return ApiResponseDto.ok(modelTrainService.getResumeInfo(uuid));
|
|
}
|
|
|
|
@Operation(summary = "학습 재시작", description = "중단된 지점(Checkpoint)부터 학습을 재개합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "재시작 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = ModelMngDto.ResumeResponse.class))),
|
|
@ApiResponse(responseCode = "400", description = "재시작 불가능한 상태", content = @Content),
|
|
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/train/{uuid}/resume")
|
|
public ApiResponseDto<ModelMngDto.ResumeResponse> resumeTraining(
|
|
@Parameter(description = "모델 UUID") @PathVariable String uuid,
|
|
@Valid @RequestBody ModelMngDto.ResumeRequest resumeReq) {
|
|
return ApiResponseDto.ok(modelTrainService.resumeTraining(uuid, resumeReq));
|
|
}
|
|
|
|
// ==================== Best Epoch Setting (Best Epoch 설정) ====================
|
|
|
|
@Operation(summary = "Best Epoch 설정", description = "사용자가 직접 Best Epoch를 선택하여 설정합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "설정 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = ModelMngDto.BestEpochResponse.class))),
|
|
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/train/{uuid}/best-epoch")
|
|
public ApiResponseDto<ModelMngDto.BestEpochResponse> setBestEpoch(
|
|
@Parameter(description = "모델 UUID") @PathVariable String uuid,
|
|
@Valid @RequestBody ModelMngDto.BestEpochRequest bestEpochReq) {
|
|
return ApiResponseDto.ok(modelTrainService.setBestEpoch(uuid, bestEpochReq));
|
|
}
|
|
|
|
@Operation(summary = "Epoch별 성능 지표 조회", description = "학습된 모델의 Epoch별 성능 지표를 조회합니다")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "조회 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = List.class))),
|
|
@ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@GetMapping("/train/{uuid}/epoch-metrics")
|
|
public ApiResponseDto<List<ModelMngDto.EpochMetric>> getEpochMetrics(
|
|
@Parameter(description = "모델 UUID") @PathVariable String uuid) {
|
|
return ApiResponseDto.ok(modelTrainService.getEpochMetrics(uuid));
|
|
}
|
|
}
|