파일 count 기능 추가 #87
@@ -41,6 +41,7 @@ public class ModelTrainMngDto {
|
||||
private String trainType;
|
||||
private String modelNo;
|
||||
private Long currentAttemptId;
|
||||
private String requestPath;
|
||||
|
||||
public String getStatusName() {
|
||||
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.service.TmpDatasetService;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
@@ -127,6 +127,7 @@ public class ModelMasterEntity {
|
||||
this.statusCd,
|
||||
this.trainType,
|
||||
this.modelNo,
|
||||
this.currentAttemptId);
|
||||
this.currentAttemptId,
|
||||
this.requestPath);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.train;
|
||||
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.train.service.DataSetCountersService;
|
||||
import com.kamco.cd.training.train.service.TestJobService;
|
||||
import com.kamco.cd.training.train.service.TrainJobService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
@@ -25,6 +26,7 @@ public class TrainApiController {
|
||||
|
||||
private final TrainJobService trainJobService;
|
||||
private final TestJobService testJobService;
|
||||
private final DataSetCountersService dataSetCountersService;
|
||||
|
||||
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
||||
@ApiResponses(
|
||||
@@ -45,7 +47,9 @@ public class TrainApiController {
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
trainJobService.createTmpFile(uuid);
|
||||
trainJobService.enqueue(modelId);
|
||||
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@@ -186,4 +190,27 @@ public class TrainApiController {
|
||||
|
||||
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
|
||||
}
|
||||
|
||||
@Operation(summary = "getCount", description = "getCount 서버 로그확인")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "데이터셋 tmp 파일생성 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/counts/{uuid}")
|
||||
public ApiResponseDto<Void> getCount(
|
||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
dataSetCountersService.getCount(modelId);
|
||||
return ApiResponseDto.ok(null);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Log4j2
|
||||
@RequiredArgsConstructor
|
||||
public class DataSetCountersService {
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
@Value("${train.docker.basePath}")
|
||||
private String trainBaseDir;
|
||||
|
||||
public void getCount(Long modelId) {
|
||||
ModelTrainMngDto.Basic basic = modelTrainMngCoreService.findModelById(modelId);
|
||||
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
|
||||
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
|
||||
|
||||
try {
|
||||
// request 폴더
|
||||
for (String uid : uids) {
|
||||
Path path = Path.of(requestDir, uid);
|
||||
DatasetCounters counters = countTmpAfterBuild(path);
|
||||
counters.prints(uid, "REQUEST");
|
||||
}
|
||||
|
||||
// tmp
|
||||
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
|
||||
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
|
||||
counters2.prints(basic.getRequestPath(), "TMP");
|
||||
} catch (IOException e) {
|
||||
log.error(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private int countTif(Path dir) throws IOException {
|
||||
if (!Files.isDirectory(dir)) return 0;
|
||||
|
||||
try (var stream = Files.walk(dir)) {
|
||||
return (int)
|
||||
stream.filter(Files::isRegularFile).filter(p -> p.toString().endsWith(".tif")).count();
|
||||
}
|
||||
}
|
||||
|
||||
public DatasetCounters countTmpAfterBuild(Path path) throws IOException {
|
||||
|
||||
// input1
|
||||
int in1Train = countTif(path.resolve("train/input1"));
|
||||
int in1Val = countTif(path.resolve("val/input1"));
|
||||
int in1Test = countTif(path.resolve("test/input1"));
|
||||
|
||||
// input2
|
||||
int in2Train = countTif(path.resolve("train/input2"));
|
||||
int in2Val = countTif(path.resolve("val/input2"));
|
||||
int in2Test = countTif(path.resolve("test/input2"));
|
||||
|
||||
List<DatasetCounter> input1List = new ArrayList<>();
|
||||
List<DatasetCounter> input2List = new ArrayList<>();
|
||||
|
||||
input1List.add(new DatasetCounter(in1Train, in1Test, in1Val));
|
||||
input2List.add(new DatasetCounter(in2Train, in2Test, in2Val));
|
||||
|
||||
return new DatasetCounters(input1List, input2List);
|
||||
}
|
||||
|
||||
@Getter
|
||||
public static class DatasetCounter {
|
||||
private int inputTrain = 0;
|
||||
private int inputTest = 0;
|
||||
private int inputVal = 0;
|
||||
|
||||
public DatasetCounter(int inputTrain, int inputTest, int inputVal) {
|
||||
this.inputTrain = inputTrain;
|
||||
this.inputTest = inputTest;
|
||||
this.inputVal = inputVal;
|
||||
}
|
||||
}
|
||||
|
||||
@Getter
|
||||
public static class DatasetCounters {
|
||||
private List<DatasetCounter> input1 = new ArrayList<>();
|
||||
private List<DatasetCounter> input2 = new ArrayList<>();
|
||||
|
||||
public DatasetCounters(List<DatasetCounter> input1, List<DatasetCounter> input2) {
|
||||
this.input1 = input1;
|
||||
this.input2 = input2;
|
||||
}
|
||||
|
||||
public void prints(String uuid, String type) {
|
||||
int train = 0, test = 0, val = 0;
|
||||
int train2 = 0, test2 = 0, val2 = 0;
|
||||
|
||||
for (DatasetCounter datasetCounter : input1) {
|
||||
train += datasetCounter.inputTrain;
|
||||
test += datasetCounter.inputTest;
|
||||
val += datasetCounter.inputVal;
|
||||
}
|
||||
|
||||
for (DatasetCounter datasetCounter : input2) {
|
||||
train2 += datasetCounter.inputTrain;
|
||||
test2 += datasetCounter.inputTest;
|
||||
val2 += datasetCounter.inputVal;
|
||||
}
|
||||
|
||||
log.info("======== UUID FOLDER COUNT {} : {}", type, uuid);
|
||||
log.info("input 1 = train : {} | val : {} | test : {} ", train, val, test);
|
||||
log.info("input 2 = train : {} | val : {} | test : {} ", train2, val2, test2);
|
||||
log.info(
|
||||
"*total* = train : {} | val : {} | test : {} ",
|
||||
train + train2,
|
||||
val + val2,
|
||||
test + test2);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
@@ -21,8 +20,8 @@ public class TestJobService {
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final DataSetCountersService dataSetCounters;
|
||||
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||
@@ -30,6 +29,9 @@ public class TestJobService {
|
||||
// 마스터 확인
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
// 폴더 카운트
|
||||
dataSetCounters.getCount(modelId);
|
||||
|
||||
// best epoch 업데이트
|
||||
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.model.service;
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.*;
|
||||
@@ -73,7 +73,6 @@ public class TmpDatasetService {
|
||||
for (String part : List.of("input1", "input2", "label", "label-json")) {
|
||||
|
||||
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||
|
||||
if (!Files.isDirectory(srcDir)) {
|
||||
log.warn("SKIP (not directory): {}", srcDir);
|
||||
noDir++;
|
||||
@@ -3,7 +3,6 @@ package com.kamco.cd.training.train.service;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.service.TmpDatasetService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
@@ -34,13 +33,13 @@ public class TrainJobService {
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final TmpDatasetService tmpDatasetService;
|
||||
private final DataSetCountersService dataSetCounters;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
public Long getModelIdByUuid(UUID uuid) {
|
||||
createTmpFile(uuid);
|
||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||
}
|
||||
|
||||
@@ -48,6 +47,9 @@ public class TrainJobService {
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId) {
|
||||
|
||||
// 폴더 카운트
|
||||
dataSetCounters.getCount(modelId);
|
||||
|
||||
// 마스터 존재 확인(없으면 예외)
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
@@ -235,7 +237,6 @@ public class TrainJobService {
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return modelUuid;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user