diff --git a/src/main/java/com/kamco/cd/training/train/TrainApiController.java b/src/main/java/com/kamco/cd/training/train/TrainApiController.java index f6e8396..149ebdf 100644 --- a/src/main/java/com/kamco/cd/training/train/TrainApiController.java +++ b/src/main/java/com/kamco/cd/training/train/TrainApiController.java @@ -13,6 +13,8 @@ import io.swagger.v3.oas.annotations.responses.ApiResponses; import io.swagger.v3.oas.annotations.tags.Tag; import java.util.UUID; import lombok.RequiredArgsConstructor; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMapping; @@ -204,13 +206,12 @@ public class TrainApiController { @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) - @PostMapping("/counts/{uuid}") - public ApiResponseDto getCount( + @GetMapping(path = "/counts/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE) + public ApiResponseDto getCount( @Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49") @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); - dataSetCountersService.getCount(modelId); - return ApiResponseDto.ok(null); + return ApiResponseDto.ok(dataSetCountersService.getCount(modelId)); } } diff --git a/src/main/java/com/kamco/cd/training/train/service/DataSetCountersService.java b/src/main/java/com/kamco/cd/training/train/service/DataSetCountersService.java index 7454b1a..82eeaaa 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DataSetCountersService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DataSetCountersService.java @@ -25,26 +25,32 @@ public class DataSetCountersService { @Value("${train.docker.basePath}") private String trainBaseDir; - public void getCount(Long modelId) { + public String getCount(Long modelId) { ModelTrainMngDto.Basic basic = modelTrainMngCoreService.findModelById(modelId); List datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId); List uids = modelTrainMngCoreService.findDatasetUid(datasetIds); + StringBuilder allLogs = new StringBuilder(); + try { // request 폴더 for (String uid : uids) { Path path = Path.of(requestDir, uid); DatasetCounters counters = countTmpAfterBuild(path); - counters.prints(uid, "REQUEST"); + allLogs.append(counters.prints(uid, "REQUEST")).append(System.lineSeparator()); } // tmp Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath()); DatasetCounters counters2 = countTmpAfterBuild(tmpPath); - counters2.prints(basic.getRequestPath(), "TMP"); + allLogs + .append(counters2.prints(basic.getRequestPath(), "TMP")) + .append(System.lineSeparator()); } catch (IOException e) { log.error(e.getMessage()); } + + return allLogs.toString(); } private int countTif(Path dir) throws IOException { @@ -114,7 +120,7 @@ public class DataSetCountersService { this.input2 = input2; } - public void prints(String uuid, String type) { + public String prints(String uuid, String type) { int train = 0, test = 0, val = 0; int train2 = 0, test2 = 0, val2 = 0; @@ -138,6 +144,23 @@ public class DataSetCountersService { train + train2, val + val2, test + test2); + + return String.format( + "======== UUID FOLDER COUNT %s : %s%n" + + "input 1 = train : %s | val : %s | test : %s%n" + + "input 2 = train : %s | val : %s | test : %s%n" + + "*total* = train : %s | val : %s | test : %s", + type, + uuid, + train, + val, + test, + train2, + val2, + test2, + train + train2, + val + val2, + test + test2); } } } diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java index 9a3115a..0500a31 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java @@ -223,9 +223,9 @@ public class TrainJobService { UUID tmpUuid = UUID.randomUUID(); String raw = tmpUuid.toString().toUpperCase().replace("-", ""); + // MODELID 가져오기 Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid); List datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId); - List uids = modelTrainMngCoreService.findDatasetUid(datasetIds); try { @@ -233,6 +233,7 @@ public class TrainJobService { String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids); ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq(); updateReq.setRequestPath(pathUid); + // 학습모델을 수정한다. modelTrainMngCoreService.updateModelMaster(modelId, updateReq); } catch (IOException e) { throw new RuntimeException(e);