Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202

This commit is contained in:
2026-02-13 10:50:35 +09:00
8 changed files with 183 additions and 8 deletions

View File

@@ -41,6 +41,7 @@ public class ModelTrainMngDto {
private String trainType;
private String modelNo;
private Long currentAttemptId;
private String requestPath;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;

View File

@@ -12,6 +12,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.service.TmpDatasetService;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;

View File

@@ -127,6 +127,7 @@ public class ModelMasterEntity {
this.statusCd,
this.trainType,
this.modelNo,
this.currentAttemptId);
this.currentAttemptId,
this.requestPath);
}
}

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.train;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.train.service.DataSetCountersService;
import com.kamco.cd.training.train.service.TestJobService;
import com.kamco.cd.training.train.service.TrainJobService;
import io.swagger.v3.oas.annotations.Operation;
@@ -25,6 +26,7 @@ public class TrainApiController {
private final TrainJobService trainJobService;
private final TestJobService testJobService;
private final DataSetCountersService dataSetCountersService;
@Operation(summary = "학습 실행", description = "학습 실행 API")
@ApiResponses(
@@ -45,7 +47,9 @@ public class TrainApiController {
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.createTmpFile(uuid);
trainJobService.enqueue(modelId);
return ApiResponseDto.ok("ok");
}
@@ -186,4 +190,27 @@ public class TrainApiController {
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
}
@Operation(summary = "getCount", description = "getCount 서버 로그확인")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "데이터셋 tmp 파일생성 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/counts/{uuid}")
public ApiResponseDto<Void> getCount(
@Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
dataSetCountersService.getCount(modelId);
return ApiResponseDto.ok(null);
}
}

View File

@@ -0,0 +1,143 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Service
@Log4j2
@RequiredArgsConstructor
public class DataSetCountersService {
private final ModelTrainMngCoreService modelTrainMngCoreService;
@Value("${train.docker.requestDir}")
private String requestDir;
@Value("${train.docker.basePath}")
private String trainBaseDir;
public void getCount(Long modelId) {
ModelTrainMngDto.Basic basic = modelTrainMngCoreService.findModelById(modelId);
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
try {
// request 폴더
for (String uid : uids) {
Path path = Path.of(requestDir, uid);
DatasetCounters counters = countTmpAfterBuild(path);
counters.prints(uid, "REQUEST");
}
// tmp
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
counters2.prints(basic.getRequestPath(), "TMP");
} catch (IOException e) {
log.error(e.getMessage());
}
}
private int countTif(Path dir) throws IOException {
if (!Files.isDirectory(dir)) return 0;
try (var stream = Files.walk(dir)) {
return (int)
stream.filter(Files::isRegularFile).filter(p -> p.toString().endsWith(".tif")).count();
}
/*
대소문자 및 geojson 필요시
* try (var stream = Files.walk(dir)) {
return (int)
stream
.filter(Files::isRegularFile)
.filter(p -> {
String name = p.getFileName().toString().toLowerCase();
return name.endsWith(".tif") || name.endsWith(".geojson");
})
.count();
}
* */
}
public DatasetCounters countTmpAfterBuild(Path path) throws IOException {
// input1
int in1Train = countTif(path.resolve("train/input1"));
int in1Val = countTif(path.resolve("val/input1"));
int in1Test = countTif(path.resolve("test/input1"));
// input2
int in2Train = countTif(path.resolve("train/input2"));
int in2Val = countTif(path.resolve("val/input2"));
int in2Test = countTif(path.resolve("test/input2"));
List<DatasetCounter> input1List = new ArrayList<>();
List<DatasetCounter> input2List = new ArrayList<>();
input1List.add(new DatasetCounter(in1Train, in1Test, in1Val));
input2List.add(new DatasetCounter(in2Train, in2Test, in2Val));
return new DatasetCounters(input1List, input2List);
}
@Getter
public static class DatasetCounter {
private int inputTrain = 0;
private int inputTest = 0;
private int inputVal = 0;
public DatasetCounter(int inputTrain, int inputTest, int inputVal) {
this.inputTrain = inputTrain;
this.inputTest = inputTest;
this.inputVal = inputVal;
}
}
@Getter
public static class DatasetCounters {
private List<DatasetCounter> input1 = new ArrayList<>();
private List<DatasetCounter> input2 = new ArrayList<>();
public DatasetCounters(List<DatasetCounter> input1, List<DatasetCounter> input2) {
this.input1 = input1;
this.input2 = input2;
}
public void prints(String uuid, String type) {
int train = 0, test = 0, val = 0;
int train2 = 0, test2 = 0, val2 = 0;
for (DatasetCounter datasetCounter : input1) {
train += datasetCounter.inputTrain;
test += datasetCounter.inputTest;
val += datasetCounter.inputVal;
}
for (DatasetCounter datasetCounter : input2) {
train2 += datasetCounter.inputTrain;
test2 += datasetCounter.inputTest;
val2 += datasetCounter.inputVal;
}
log.info("======== UUID FOLDER COUNT {} : {}", type, uuid);
log.info("input 1 = train : {} | val : {} | test : {} ", train, val, test);
log.info("input 2 = train : {} | val : {} | test : {} ", train2, val2, test2);
log.info(
"*total* = train : {} | val : {} | test : {} ",
train + train2,
val + val2,
test + test2);
}
}
}

View File

@@ -1,6 +1,5 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
@@ -21,8 +20,8 @@ public class TestJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher;
private final DataSetCountersService dataSetCounters;
@Transactional
public Long enqueue(Long modelId, UUID uuid, int epoch) {
@@ -30,6 +29,9 @@ public class TestJobService {
// 마스터 확인
modelTrainMngCoreService.findModelById(modelId);
// 폴더 카운트
dataSetCounters.getCount(modelId);
// best epoch 업데이트
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.model.service;
package com.kamco.cd.training.train.service;
import java.io.IOException;
import java.nio.file.*;
@@ -73,7 +73,6 @@ public class TmpDatasetService {
for (String part : List.of("input1", "input2", "label", "label-json")) {
Path srcDir = srcRoot.resolve(type).resolve(part);
if (!Files.isDirectory(srcDir)) {
log.warn("SKIP (not directory): {}", srcDir);
noDir++;

View File

@@ -3,7 +3,6 @@ package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.service.TmpDatasetService;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
@@ -34,13 +33,13 @@ public class TrainJobService {
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher;
private final TmpDatasetService tmpDatasetService;
private final DataSetCountersService dataSetCounters;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
public Long getModelIdByUuid(UUID uuid) {
createTmpFile(uuid);
return modelTrainMngCoreService.findModelIdByUuid(uuid);
}
@@ -48,6 +47,9 @@ public class TrainJobService {
@Transactional
public Long enqueue(Long modelId) {
// 폴더 카운트
dataSetCounters.getCount(modelId);
// 마스터 존재 확인(없으면 예외)
modelTrainMngCoreService.findModelById(modelId);
@@ -235,7 +237,6 @@ public class TrainJobService {
} catch (IOException e) {
throw new RuntimeException(e);
}
return modelUuid;
}
}