package com.kamco.cd.training.train; import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.train.service.DataSetCountersService; import com.kamco.cd.training.train.service.TestJobService; import com.kamco.cd.training.train.service.TrainJobService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponses; import io.swagger.v3.oas.annotations.tags.Tag; import java.util.UUID; import lombok.RequiredArgsConstructor; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API") @RequiredArgsConstructor @RestController @RequestMapping("/api/train") public class TrainApiController { private final TrainJobService trainJobService; private final TestJobService testJobService; private final DataSetCountersService dataSetCountersService; @Operation(summary = "학습 실행", description = "학습 실행 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "실행 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/run/{uuid}") public ApiResponseDto run( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); trainJobService.createTmpFile(uuid); trainJobService.enqueue(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "학습 재실행", description = "학습 재실행 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "재실행 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/restart/{uuid}") public ApiResponseDto restart( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); Long jobId = trainJobService.restart(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "학습 이어하기", description = "학습 이어하기 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "이어하기 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/resume/{uuid}") public ApiResponseDto resume( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); Long jobId = trainJobService.resume(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "학습 취소", description = "학습 취소 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "취소 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/cancel/{uuid}") public ApiResponseDto cancel( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); trainJobService.cancel(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "test 실행", description = "test 실행 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "test 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/test/run/{epoch}/{uuid}") public ApiResponseDto run( @Parameter(description = "best 에폭", example = "1") @PathVariable int epoch, @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); testJobService.enqueue(modelId, uuid, epoch); return ApiResponseDto.ok("ok"); } @Operation(summary = "test 학습 취소", description = "학습 취소 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "취소 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/test/cancel/{uuid}") public ApiResponseDto cancelTest( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); testJobService.cancel(modelId); return ApiResponseDto.ok("ok"); } @Operation(summary = "데이터셋 tmp 파일생성", description = "데이터셋 tmp 파일생성 API") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "데이터셋 tmp 파일생성 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/create-tmp/{uuid}") public ApiResponseDto createTmpFile( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { return ApiResponseDto.ok(trainJobService.createTmpFile(uuid)); } @Operation(summary = "getCount", description = "getCount 서버 로그확인") @ApiResponses( value = { @ApiResponse( responseCode = "200", description = "데이터셋 tmp 파일생성 성공", content = @Content( mediaType = "application/json", schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/counts/{uuid}") public ApiResponseDto getCount( @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); dataSetCountersService.getCount(modelId); return ApiResponseDto.ok(null); } }