217 lines
8.8 KiB
Java
217 lines
8.8 KiB
Java
package com.kamco.cd.training.train;
|
|
|
|
import com.kamco.cd.training.config.api.ApiResponseDto;
|
|
import com.kamco.cd.training.train.service.DataSetCountersService;
|
|
import com.kamco.cd.training.train.service.TestJobService;
|
|
import com.kamco.cd.training.train.service.TrainJobService;
|
|
import io.swagger.v3.oas.annotations.Operation;
|
|
import io.swagger.v3.oas.annotations.Parameter;
|
|
import io.swagger.v3.oas.annotations.media.Content;
|
|
import io.swagger.v3.oas.annotations.media.Schema;
|
|
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
|
import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
|
import java.util.UUID;
|
|
import lombok.RequiredArgsConstructor;
|
|
import org.springframework.web.bind.annotation.PathVariable;
|
|
import org.springframework.web.bind.annotation.PostMapping;
|
|
import org.springframework.web.bind.annotation.RequestMapping;
|
|
import org.springframework.web.bind.annotation.RestController;
|
|
|
|
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
|
|
@RequiredArgsConstructor
|
|
@RestController
|
|
@RequestMapping("/api/train")
|
|
public class TrainApiController {
|
|
|
|
private final TrainJobService trainJobService;
|
|
private final TestJobService testJobService;
|
|
private final DataSetCountersService dataSetCountersService;
|
|
|
|
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "실행 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/run/{uuid}")
|
|
public ApiResponseDto<String> run(
|
|
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
|
@PathVariable
|
|
UUID uuid) {
|
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
|
trainJobService.createTmpFile(uuid);
|
|
trainJobService.enqueue(modelId);
|
|
|
|
return ApiResponseDto.ok("ok");
|
|
}
|
|
|
|
@Operation(summary = "학습 재실행", description = "학습 재실행 API")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "재실행 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/restart/{uuid}")
|
|
public ApiResponseDto<String> restart(
|
|
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
|
@PathVariable
|
|
UUID uuid) {
|
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
|
Long jobId = trainJobService.restart(modelId);
|
|
return ApiResponseDto.ok("ok");
|
|
}
|
|
|
|
@Operation(summary = "학습 이어하기", description = "학습 이어하기 API")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "이어하기 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/resume/{uuid}")
|
|
public ApiResponseDto<String> resume(
|
|
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
|
@PathVariable
|
|
UUID uuid) {
|
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
|
Long jobId = trainJobService.resume(modelId);
|
|
return ApiResponseDto.ok("ok");
|
|
}
|
|
|
|
@Operation(summary = "학습 취소", description = "학습 취소 API")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "취소 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/cancel/{uuid}")
|
|
public ApiResponseDto<String> cancel(
|
|
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
|
@PathVariable
|
|
UUID uuid) {
|
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
|
trainJobService.cancel(modelId);
|
|
return ApiResponseDto.ok("ok");
|
|
}
|
|
|
|
@Operation(summary = "test 실행", description = "test 실행 API")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "test 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/test/run/{epoch}/{uuid}")
|
|
public ApiResponseDto<String> run(
|
|
@Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
|
|
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
|
@PathVariable
|
|
UUID uuid) {
|
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
|
testJobService.enqueue(modelId, uuid, epoch);
|
|
return ApiResponseDto.ok("ok");
|
|
}
|
|
|
|
@Operation(summary = "test 학습 취소", description = "학습 취소 API")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "취소 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/test/cancel/{uuid}")
|
|
public ApiResponseDto<String> cancelTest(
|
|
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
|
@PathVariable
|
|
UUID uuid) {
|
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
|
testJobService.cancel(modelId);
|
|
return ApiResponseDto.ok("ok");
|
|
}
|
|
|
|
@Operation(summary = "데이터셋 tmp 파일생성", description = "데이터셋 tmp 파일생성 API")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "데이터셋 tmp 파일생성 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/create-tmp/{uuid}")
|
|
public ApiResponseDto<UUID> createTmpFile(
|
|
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
|
@PathVariable
|
|
UUID uuid) {
|
|
|
|
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
|
|
}
|
|
|
|
@Operation(summary = "getCount", description = "getCount 서버 로그확인")
|
|
@ApiResponses(
|
|
value = {
|
|
@ApiResponse(
|
|
responseCode = "200",
|
|
description = "데이터셋 tmp 파일생성 성공",
|
|
content =
|
|
@Content(
|
|
mediaType = "application/json",
|
|
schema = @Schema(implementation = String.class))),
|
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
|
})
|
|
@PostMapping("/counts/{uuid}")
|
|
public ApiResponseDto<Void> getCount(
|
|
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
|
@PathVariable
|
|
UUID uuid) {
|
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
|
dataSetCountersService.getCount(modelId);
|
|
return ApiResponseDto.ok(null);
|
|
}
|
|
}
|