package com.kamco.cd.training.train; import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.train.dto.TrainingMetricsDto; import com.kamco.cd.training.train.dto.TrainingProgressDto; import com.kamco.cd.training.train.service.DataSetCountersService; import com.kamco.cd.training.train.service.TestJobService; import com.kamco.cd.training.train.service.TrainJobService; import com.kamco.cd.training.train.service.TrainingMetricsService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponses; import io.swagger.v3.oas.annotations.tags.Tag; import java.util.UUID; import lombok.RequiredArgsConstructor; import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API") @RequiredArgsConstructor @RestController @RequestMapping("/api/train") public class TrainApiController { private final TrainJobService trainJobService; private final TestJobService testJobService; private final DataSetCountersService dataSetCountersService; private final TrainingMetricsService trainingMetricsService; @Operation(summary = "학습 실행", description = "학습 실행 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "실행 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/run/{uuid}") public ApiResponseDto run( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); trainJobService.enqueue(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "학습 재실행", description = "학습 재실행 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "재실행 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/restart/{uuid}") public ApiResponseDto restart( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); Long jobId = trainJobService.restart(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "학습 이어하기", description = "학습 이어하기 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "이어하기 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/resume/{uuid}") public ApiResponseDto resume( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); Long jobId = trainJobService.resume(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "학습 취소", description = "학습 취소 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "취소 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/cancel/{uuid}") public ApiResponseDto cancel( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); trainJobService.cancel(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "test 실행", description = "test 실행 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "test 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/test/run/{epoch}/{uuid}") public ApiResponseDto run( @Parameter(description = "best 에폭", example = "1") @PathVariable int epoch, @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); testJobService.enqueue(modelId, uuid, epoch); return ApiResponseDto.ok("ok"); } @Operation(summary = "test 학습 취소", description = "학습 취소 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "취소 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/test/cancel/{uuid}") public ApiResponseDto cancelTest( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); testJobService.cancel(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "데이터셋 tmp 파일생성", description = "데이터셋 tmp 파일생성 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "데이터셋 tmp 파일생성 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/create-tmp/{uuid}") public ApiResponseDto createTmpFile( @Parameter(description = "model uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { return ApiResponseDto.ok(trainJobService.createTmpFile(uuid)); } @Operation(summary = "getCount", description = "getCount 서버 로그확인") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "데이터셋 tmp 파일생성 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @GetMapping(path = "/counts/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE) public ApiResponseDto getCount( @Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); return ApiResponseDto.ok(dataSetCountersService.getCount(modelId)); } @Operation(summary = "학습 상태 확인", description = "학습 상태 확인") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "학습 상태 변경", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping(path = "/status/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE) public ApiResponseDto status( @Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); trainJobService.status(uuid, modelId); return ApiResponseDto.ok("ok"); } @Operation( summary = "하이퍼파라미터 기반 학습 메트릭 조회", description = "하이퍼파라미터 UUID로 해당 파라미터를 사용하는 모델의 학습 메트릭을 조회합니다. " + "val.csv와 train.csv를 우선 사용하며, 없을 경우 processing.log를 파싱합니다.") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "조회 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = TrainingMetricsDto.Response.class))), @ApiResponse( responseCode = "404", description = "하이퍼파라미터 또는 모델을 찾을 수 없음", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @GetMapping( path = "/metrics/hyper-param/{hyperParamUuid}", produces = MediaType.APPLICATION_JSON_VALUE) public ApiResponseDto getMetricsByHyperParam( @Parameter(description = "하이퍼파라미터 UUID", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10") @PathVariable UUID hyperParamUuid) { TrainingMetricsDto.Response response = trainingMetricsService.getTrainingMetricsByHyperParam(hyperParamUuid); return ApiResponseDto.ok(response); } @Operation( summary = "모델 기반 학습 메트릭 조회", description = "모델 UUID로 해당 모델의 학습 메트릭을 조회합니다. " + "val.csv와 train.csv를 우선 사용하며, 없을 경우 processing.log를 파싱합니다.") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "조회 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = TrainingMetricsDto.Response.class))), @ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @GetMapping(path = "/metrics/model/{modelUuid}", produces = MediaType.APPLICATION_JSON_VALUE) public ApiResponseDto getMetricsByModel( @Parameter(description = "모델 UUID", example = "b34a2d18-11e6-4b1b-a156-cd314bec45bb") @PathVariable UUID modelUuid) { TrainingMetricsDto.Response response = trainingMetricsService.getTrainingMetricsByModelUuid(modelUuid); return ApiResponseDto.ok(response); } @Operation( summary = "학습 진행률 조회", description = "UUID로 학습 진행률을 실시간 조회합니다. 기존 DB 구조를 활용하여 진행률을 계산합니다.") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "조회 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = TrainingProgressDto.class))), @ApiResponse(responseCode = "404", description = "모델을 찾을 수 없음", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @GetMapping("/progress/{uuid}") public ApiResponseDto getTrainingProgress( @Parameter(description = "모델 UUID", required = true) @PathVariable UUID uuid) { TrainingProgressDto progress = trainJobService.getTrainingProgress(uuid); return ApiResponseDto.ok(progress); } }