Files
kamco-train-api/src/main/java/com/kamco/cd/training/train/TrainApiController.java
2026-02-13 10:38:24 +09:00

217 lines
8.8 KiB
Java

package com.kamco.cd.training.train;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.train.service.DataSetCountersService;
import com.kamco.cd.training.train.service.TestJobService;
import com.kamco.cd.training.train.service.TrainJobService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
/**
 * REST controller exposing the training and test job control endpoints under {@code /api/train}.
 *
 * <p>Every endpoint takes a model UUID as a path variable. All endpoints except {@link
 * #createTmpFile(UUID)} first translate that UUID into a model id through {@link
 * TrainJobService#getModelIdByUuid} and then delegate to one of the injected services. The
 * constructor that receives the three services is generated by Lombok's {@code
 * RequiredArgsConstructor}.
 */
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
@RequiredArgsConstructor
@RestController
@RequestMapping("/api/train")
public class TrainApiController {

  private final TrainJobService trainJobService;
  private final TestJobService testJobService;
  private final DataSetCountersService dataSetCountersService;

  /**
   * Starts a training run for the model identified by {@code uuid}.
   *
   * <p>The model id is resolved first, then {@code createTmpFile(uuid)} is invoked, and finally the
   * model id is passed to {@code enqueue}.
   *
   * @param uuid model UUID taken from the request path
   * @return a response wrapping the literal string {@code "ok"}
   */
  @Operation(summary = "학습 실행", description = "학습 실행 API")
  @ApiResponses(
      value = {
        @ApiResponse(
            responseCode = "200",
            description = "실행 성공",
            content =
                @Content(
                    mediaType = "application/json",
                    schema = @Schema(implementation = String.class))),
        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
      })
  @PostMapping("/run/{uuid}")
  public ApiResponseDto<String> run(
      @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
          @PathVariable
          UUID uuid) {
    Long modelId = trainJobService.getModelIdByUuid(uuid);
    trainJobService.createTmpFile(uuid);
    trainJobService.enqueue(modelId);
    return ApiResponseDto.ok("ok");
  }

  /**
   * Restarts training for the model identified by {@code uuid}.
   *
   * @param uuid model UUID taken from the request path
   * @return a response wrapping the literal string {@code "ok"}
   */
  @Operation(summary = "학습 재실행", description = "학습 재실행 API")
  @ApiResponses(
      value = {
        @ApiResponse(
            responseCode = "200",
            description = "재실행 성공",
            content =
                @Content(
                    mediaType = "application/json",
                    schema = @Schema(implementation = String.class))),
        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
      })
  @PostMapping("/restart/{uuid}")
  public ApiResponseDto<String> restart(
      @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
          @PathVariable
          UUID uuid) {
    Long modelId = trainJobService.getModelIdByUuid(uuid);
    // The value returned by restart() is not part of the response, so it is not kept.
    trainJobService.restart(modelId);
    return ApiResponseDto.ok("ok");
  }

  /**
   * Resumes training for the model identified by {@code uuid}.
   *
   * @param uuid model UUID taken from the request path
   * @return a response wrapping the literal string {@code "ok"}
   */
  @Operation(summary = "학습 이어하기", description = "학습 이어하기 API")
  @ApiResponses(
      value = {
        @ApiResponse(
            responseCode = "200",
            description = "이어하기 성공",
            content =
                @Content(
                    mediaType = "application/json",
                    schema = @Schema(implementation = String.class))),
        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
      })
  @PostMapping("/resume/{uuid}")
  public ApiResponseDto<String> resume(
      @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
          @PathVariable
          UUID uuid) {
    Long modelId = trainJobService.getModelIdByUuid(uuid);
    // The value returned by resume() is not part of the response, so it is not kept.
    trainJobService.resume(modelId);
    return ApiResponseDto.ok("ok");
  }

  /**
   * Cancels training for the model identified by {@code uuid}.
   *
   * @param uuid model UUID taken from the request path
   * @return a response wrapping the literal string {@code "ok"}
   */
  @Operation(summary = "학습 취소", description = "학습 취소 API")
  @ApiResponses(
      value = {
        @ApiResponse(
            responseCode = "200",
            description = "취소 성공",
            content =
                @Content(
                    mediaType = "application/json",
                    schema = @Schema(implementation = String.class))),
        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
      })
  @PostMapping("/cancel/{uuid}")
  public ApiResponseDto<String> cancel(
      @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
          @PathVariable
          UUID uuid) {
    Long modelId = trainJobService.getModelIdByUuid(uuid);
    trainJobService.cancel(modelId);
    return ApiResponseDto.ok("ok");
  }

  /**
   * Enqueues a test job for the model identified by {@code uuid}.
   *
   * <p>This is an overload of {@link #run(UUID)} mapped to a different path; it delegates to
   * {@link TestJobService} rather than {@link TrainJobService}.
   *
   * @param epoch epoch number taken from the request path (described as the "best" epoch in the
   *     API documentation)
   * @param uuid model UUID taken from the request path
   * @return a response wrapping the literal string {@code "ok"}
   */
  @Operation(summary = "test 실행", description = "test 실행 API")
  @ApiResponses(
      value = {
        @ApiResponse(
            responseCode = "200",
            description = "test 성공",
            content =
                @Content(
                    mediaType = "application/json",
                    schema = @Schema(implementation = String.class))),
        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
      })
  @PostMapping("/test/run/{epoch}/{uuid}")
  public ApiResponseDto<String> run(
      @Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
      @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
          @PathVariable
          UUID uuid) {
    Long modelId = trainJobService.getModelIdByUuid(uuid);
    testJobService.enqueue(modelId, uuid, epoch);
    return ApiResponseDto.ok("ok");
  }

  /**
   * Cancels the test job for the model identified by {@code uuid}.
   *
   * @param uuid model UUID taken from the request path
   * @return a response wrapping the literal string {@code "ok"}
   */
  @Operation(summary = "test 학습 취소", description = "학습 취소 API")
  @ApiResponses(
      value = {
        @ApiResponse(
            responseCode = "200",
            description = "취소 성공",
            content =
                @Content(
                    mediaType = "application/json",
                    schema = @Schema(implementation = String.class))),
        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
      })
  @PostMapping("/test/cancel/{uuid}")
  public ApiResponseDto<String> cancelTest(
      @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
          @PathVariable
          UUID uuid) {
    Long modelId = trainJobService.getModelIdByUuid(uuid);
    testJobService.cancel(modelId);
    return ApiResponseDto.ok("ok");
  }

  /**
   * Invokes {@code TrainJobService.createTmpFile} for {@code uuid} and returns its result.
   *
   * <p>Unlike the other endpoints, this one passes the UUID straight to the service without
   * resolving a model id first.
   *
   * <p>NOTE(review): the documented 200 schema is {@code String} while the method is typed {@code
   * ApiResponseDto<UUID>} - confirm which one the API documentation should show.
   *
   * @param uuid model UUID taken from the request path
   * @return a response wrapping whatever {@code createTmpFile} returns
   */
  @Operation(summary = "데이터셋 tmp 파일생성", description = "데이터셋 tmp 파일생성 API")
  @ApiResponses(
      value = {
        @ApiResponse(
            responseCode = "200",
            description = "데이터셋 tmp 파일생성 성공",
            content =
                @Content(
                    mediaType = "application/json",
                    schema = @Schema(implementation = String.class))),
        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
      })
  @PostMapping("/create-tmp/{uuid}")
  public ApiResponseDto<UUID> createTmpFile(
      @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
          @PathVariable
          UUID uuid) {
    return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
  }

  /**
   * Invokes {@code DataSetCountersService.getCount} for the model identified by {@code uuid}.
   *
   * <p>The response carries no payload; per the operation description, the outcome is meant to be
   * checked in the server log.
   *
   * @param uuid model UUID taken from the request path
   * @return a response whose payload is {@code null}
   */
  @Operation(summary = "getCount", description = "getCount 서버 로그확인")
  @ApiResponses(
      value = {
        @ApiResponse(
            responseCode = "200",
            description = "getCount 성공",
            content =
                @Content(
                    mediaType = "application/json",
                    schema = @Schema(implementation = String.class))),
        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
      })
  @PostMapping("/counts/{uuid}")
  public ApiResponseDto<Void> getCount(
      @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
          @PathVariable
          UUID uuid) {
    Long modelId = trainJobService.getModelIdByUuid(uuid);
    dataSetCountersService.getCount(modelId);
    return ApiResponseDto.ok(null);
  }
}