From 11d3afe295fa56b42f375d406ede515af0ad5ac8 Mon Sep 17 00:00:00 2001 From: teddy Date: Fri, 13 Feb 2026 10:38:24 +0900 Subject: [PATCH] =?UTF-8?q?=ED=8C=8C=EC=9D=BC=20count=20=EA=B8=B0=EB=8A=A5?= =?UTF-8?q?=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../training/model/dto/ModelTrainMngDto.java | 1 + .../model/service/ModelTrainMngService.java | 1 + .../postgres/entity/ModelMasterEntity.java | 3 +- .../cd/training/train/TrainApiController.java | 27 ++++ .../train/service/DataSetCountersService.java | 129 ++++++++++++++++++ .../train/service/TestJobService.java | 6 +- .../service/TmpDatasetService.java | 3 +- .../train/service/TrainJobService.java | 7 +- 8 files changed, 169 insertions(+), 8 deletions(-) create mode 100644 src/main/java/com/kamco/cd/training/train/service/DataSetCountersService.java rename src/main/java/com/kamco/cd/training/{model => train}/service/TmpDatasetService.java (99%) diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java index cabd5c6..216fe45 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java @@ -41,6 +41,7 @@ public class ModelTrainMngDto { private String trainType; private String modelNo; private Long currentAttemptId; + private String requestPath; public String getStatusName() { if (this.statusCd == null || this.statusCd.isBlank()) return null; diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java index a5a4496..4d1f621 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java @@ -12,6 +12,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq; import com.kamco.cd.training.postgres.core.HyperParamCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import com.kamco.cd.training.train.service.TmpDatasetService; import java.util.List; import java.util.UUID; import lombok.RequiredArgsConstructor; diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java index cd1124d..31b5b8b 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java @@ -127,6 +127,7 @@ public class ModelMasterEntity { this.statusCd, this.trainType, this.modelNo, - this.currentAttemptId); + this.currentAttemptId, + this.requestPath); } } diff --git a/src/main/java/com/kamco/cd/training/train/TrainApiController.java b/src/main/java/com/kamco/cd/training/train/TrainApiController.java index dd8607b..88783d1 100644 --- a/src/main/java/com/kamco/cd/training/train/TrainApiController.java +++ b/src/main/java/com/kamco/cd/training/train/TrainApiController.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.train; import com.kamco.cd.training.config.api.ApiResponseDto; +import com.kamco.cd.training.train.service.DataSetCountersService; import com.kamco.cd.training.train.service.TestJobService; import com.kamco.cd.training.train.service.TrainJobService; import io.swagger.v3.oas.annotations.Operation; @@ -25,6 +26,7 @@ public class TrainApiController { private final TrainJobService trainJobService; private final TestJobService testJobService; + private final DataSetCountersService dataSetCountersService; @Operation(summary = "학습 실행", description = "학습 실행 API") @ApiResponses( @@ -45,7 +47,9 @@ public class TrainApiController { @PathVariable UUID uuid) { Long modelId = trainJobService.getModelIdByUuid(uuid); + trainJobService.createTmpFile(uuid); trainJobService.enqueue(modelId); + return ApiResponseDto.ok("ok"); } @@ -186,4 +190,27 @@ public class TrainApiController { return ApiResponseDto.ok(trainJobService.createTmpFile(uuid)); } + + @Operation(summary = "getCount", description = "getCount 서버 로그확인") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "데이터셋 tmp 파일생성 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = String.class))), + @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @PostMapping("/counts/{uuid}") + public ApiResponseDto getCount( + @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") + @PathVariable + UUID uuid) { + Long modelId = trainJobService.getModelIdByUuid(uuid); + dataSetCountersService.getCount(modelId); + return ApiResponseDto.ok(null); + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/DataSetCountersService.java b/src/main/java/com/kamco/cd/training/train/service/DataSetCountersService.java new file mode 100644 index 0000000..0c90ee8 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/service/DataSetCountersService.java @@ -0,0 +1,129 @@ +package com.kamco.cd.training.train.service; + +import com.kamco.cd.training.model.dto.ModelTrainMngDto; +import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +@Service +@Log4j2 +@RequiredArgsConstructor +public class DataSetCountersService { + private final ModelTrainMngCoreService modelTrainMngCoreService; + + @Value("${train.docker.requestDir}") + private String requestDir; + + @Value("${train.docker.basePath}") + private String trainBaseDir; + + public void getCount(Long modelId) { + ModelTrainMngDto.Basic basic = modelTrainMngCoreService.findModelById(modelId); + List datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId); + List uids = modelTrainMngCoreService.findDatasetUid(datasetIds); + + try { + // request 폴더 + for (String uid : uids) { + Path path = Path.of(requestDir, uid); + DatasetCounters counters = countTmpAfterBuild(path); + counters.prints(uid, "REQUEST"); + } + + // tmp + Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath()); + DatasetCounters counters2 = countTmpAfterBuild(tmpPath); + counters2.prints(basic.getRequestPath(), "TMP"); + } catch (IOException e) { + log.error(e.getMessage()); + } + } + + private int countTif(Path dir) throws IOException { + if (!Files.isDirectory(dir)) return 0; + + try (var stream = Files.walk(dir)) { + return (int) + stream.filter(Files::isRegularFile).filter(p -> p.toString().endsWith(".tif")).count(); + } + } + + public DatasetCounters countTmpAfterBuild(Path path) throws IOException { + + // input1 + int in1Train = countTif(path.resolve("train/input1")); + int in1Val = countTif(path.resolve("val/input1")); + int in1Test = countTif(path.resolve("test/input1")); + + // input2 + int in2Train = countTif(path.resolve("train/input2")); + int in2Val = countTif(path.resolve("val/input2")); + int in2Test = countTif(path.resolve("test/input2")); + + List input1List = new ArrayList<>(); + List input2List = new ArrayList<>(); + + input1List.add(new DatasetCounter(in1Train, in1Test, in1Val)); + input2List.add(new DatasetCounter(in2Train, in2Test, in2Val)); + + return new DatasetCounters(input1List, input2List); + } + + @Getter + public static class DatasetCounter { + private int inputTrain = 0; + private int inputTest = 0; + private int inputVal = 0; + + public DatasetCounter(int inputTrain, int inputTest, int inputVal) { + this.inputTrain = inputTrain; + this.inputTest = inputTest; + this.inputVal = inputVal; + } + } + + @Getter + public static class DatasetCounters { + private List input1 = new ArrayList<>(); + private List input2 = new ArrayList<>(); + + public DatasetCounters(List input1, List input2) { + this.input1 = input1; + this.input2 = input2; + } + + public void prints(String uuid, String type) { + int train = 0, test = 0, val = 0; + int train2 = 0, test2 = 0, val2 = 0; + + for (DatasetCounter datasetCounter : input1) { + train += datasetCounter.inputTrain; + test += datasetCounter.inputTest; + val += datasetCounter.inputVal; + } + + for (DatasetCounter datasetCounter : input2) { + train2 += datasetCounter.inputTrain; + test2 += datasetCounter.inputTest; + val2 += datasetCounter.inputVal; + } + + log.info("======== UUID FOLDER COUNT {} : {}", type, uuid); + log.info("input 1 = train : {} | val : {} | test : {} ", train, val, test); + log.info("input 2 = train : {} | val : {} | test : {} ", train2, val2, test2); + log.info( + "*total* = train : {} | val : {} | test : {} ", + train + train2, + val + val2, + test + test2); + } + } +} diff --git a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java index c34b188..d541072 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java @@ -1,6 +1,5 @@ package com.kamco.cd.training.train.service; -import com.fasterxml.jackson.databind.ObjectMapper; import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; @@ -21,8 +20,8 @@ public class TestJobService { private final ModelTrainJobCoreService modelTrainJobCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService; private final DockerTrainService dockerTrainService; - private final ObjectMapper objectMapper; private final ApplicationEventPublisher eventPublisher; + private final DataSetCountersService dataSetCounters; @Transactional public Long enqueue(Long modelId, UUID uuid, int epoch) { @@ -30,6 +29,9 @@ public class TestJobService { // 마스터 확인 modelTrainMngCoreService.findModelById(modelId); + // 폴더 카운트 + dataSetCounters.getCount(modelId); + // best epoch 업데이트 modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch); diff --git a/src/main/java/com/kamco/cd/training/model/service/TmpDatasetService.java b/src/main/java/com/kamco/cd/training/train/service/TmpDatasetService.java similarity index 99% rename from src/main/java/com/kamco/cd/training/model/service/TmpDatasetService.java rename to src/main/java/com/kamco/cd/training/train/service/TmpDatasetService.java index e0d96ed..516a403 100644 --- a/src/main/java/com/kamco/cd/training/model/service/TmpDatasetService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TmpDatasetService.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.model.service; +package com.kamco.cd.training.train.service; import java.io.IOException; import java.nio.file.*; @@ -73,7 +73,6 @@ public class TmpDatasetService { for (String part : List.of("input1", "input2", "label", "label-json")) { Path srcDir = srcRoot.resolve(type).resolve(part); - if (!Files.isDirectory(srcDir)) { log.warn("SKIP (not directory): {}", srcDir); noDir++; diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java index df9f876..9a3115a 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java @@ -3,7 +3,6 @@ package com.kamco.cd.training.train.service; import com.fasterxml.jackson.databind.ObjectMapper; import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.model.dto.ModelTrainMngDto; -import com.kamco.cd.training.model.service.TmpDatasetService; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.train.dto.ModelTrainJobDto; @@ -34,13 +33,13 @@ public class TrainJobService { private final ObjectMapper objectMapper; private final ApplicationEventPublisher eventPublisher; private final TmpDatasetService tmpDatasetService; + private final DataSetCountersService dataSetCounters; // 학습 결과가 저장될 호스트 디렉토리 @Value("${train.docker.responseDir}") private String responseDir; public Long getModelIdByUuid(UUID uuid) { - createTmpFile(uuid); return modelTrainMngCoreService.findModelIdByUuid(uuid); } @@ -48,6 +47,9 @@ public class TrainJobService { @Transactional public Long enqueue(Long modelId) { + // 폴더 카운트 + dataSetCounters.getCount(modelId); + // 마스터 존재 확인(없으면 예외) modelTrainMngCoreService.findModelById(modelId); @@ -235,7 +237,6 @@ public class TrainJobService { } catch (IOException e) { throw new RuntimeException(e); } - return modelUuid; } }