Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202
This commit is contained in:
@@ -41,6 +41,7 @@ public class ModelTrainMngDto {
|
|||||||
private String trainType;
|
private String trainType;
|
||||||
private String modelNo;
|
private String modelNo;
|
||||||
private Long currentAttemptId;
|
private Long currentAttemptId;
|
||||||
|
private String requestPath;
|
||||||
|
|
||||||
public String getStatusName() {
|
public String getStatusName() {
|
||||||
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
|||||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
||||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
|
import com.kamco.cd.training.train.service.TmpDatasetService;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ public class ModelMasterEntity {
|
|||||||
this.statusCd,
|
this.statusCd,
|
||||||
this.trainType,
|
this.trainType,
|
||||||
this.modelNo,
|
this.modelNo,
|
||||||
this.currentAttemptId);
|
this.currentAttemptId,
|
||||||
|
this.requestPath);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.kamco.cd.training.train;
|
package com.kamco.cd.training.train;
|
||||||
|
|
||||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||||
|
import com.kamco.cd.training.train.service.DataSetCountersService;
|
||||||
import com.kamco.cd.training.train.service.TestJobService;
|
import com.kamco.cd.training.train.service.TestJobService;
|
||||||
import com.kamco.cd.training.train.service.TrainJobService;
|
import com.kamco.cd.training.train.service.TrainJobService;
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
@@ -25,6 +26,7 @@ public class TrainApiController {
|
|||||||
|
|
||||||
private final TrainJobService trainJobService;
|
private final TrainJobService trainJobService;
|
||||||
private final TestJobService testJobService;
|
private final TestJobService testJobService;
|
||||||
|
private final DataSetCountersService dataSetCountersService;
|
||||||
|
|
||||||
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
||||||
@ApiResponses(
|
@ApiResponses(
|
||||||
@@ -45,7 +47,9 @@ public class TrainApiController {
|
|||||||
@PathVariable
|
@PathVariable
|
||||||
UUID uuid) {
|
UUID uuid) {
|
||||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
|
trainJobService.createTmpFile(uuid);
|
||||||
trainJobService.enqueue(modelId);
|
trainJobService.enqueue(modelId);
|
||||||
|
|
||||||
return ApiResponseDto.ok("ok");
|
return ApiResponseDto.ok("ok");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,4 +190,27 @@ public class TrainApiController {
|
|||||||
|
|
||||||
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
|
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "getCount", description = "getCount 서버 로그확인")
|
||||||
|
@ApiResponses(
|
||||||
|
value = {
|
||||||
|
@ApiResponse(
|
||||||
|
responseCode = "200",
|
||||||
|
description = "데이터셋 tmp 파일생성 성공",
|
||||||
|
content =
|
||||||
|
@Content(
|
||||||
|
mediaType = "application/json",
|
||||||
|
schema = @Schema(implementation = String.class))),
|
||||||
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||||
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
|
})
|
||||||
|
@PostMapping("/counts/{uuid}")
|
||||||
|
public ApiResponseDto<Void> getCount(
|
||||||
|
@Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49")
|
||||||
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
|
dataSetCountersService.getCount(modelId);
|
||||||
|
return ApiResponseDto.ok(null);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,143 @@
|
|||||||
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.log4j.Log4j2;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@Log4j2
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class DataSetCountersService {
|
||||||
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
|
|
||||||
|
@Value("${train.docker.requestDir}")
|
||||||
|
private String requestDir;
|
||||||
|
|
||||||
|
@Value("${train.docker.basePath}")
|
||||||
|
private String trainBaseDir;
|
||||||
|
|
||||||
|
public void getCount(Long modelId) {
|
||||||
|
ModelTrainMngDto.Basic basic = modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
|
||||||
|
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
|
||||||
|
|
||||||
|
try {
|
||||||
|
// request 폴더
|
||||||
|
for (String uid : uids) {
|
||||||
|
Path path = Path.of(requestDir, uid);
|
||||||
|
DatasetCounters counters = countTmpAfterBuild(path);
|
||||||
|
counters.prints(uid, "REQUEST");
|
||||||
|
}
|
||||||
|
|
||||||
|
// tmp
|
||||||
|
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
|
||||||
|
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
|
||||||
|
counters2.prints(basic.getRequestPath(), "TMP");
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error(e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private int countTif(Path dir) throws IOException {
|
||||||
|
if (!Files.isDirectory(dir)) return 0;
|
||||||
|
|
||||||
|
try (var stream = Files.walk(dir)) {
|
||||||
|
return (int)
|
||||||
|
stream.filter(Files::isRegularFile).filter(p -> p.toString().endsWith(".tif")).count();
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
대소문자 및 geojson 필요시
|
||||||
|
* try (var stream = Files.walk(dir)) {
|
||||||
|
return (int)
|
||||||
|
stream
|
||||||
|
.filter(Files::isRegularFile)
|
||||||
|
.filter(p -> {
|
||||||
|
String name = p.getFileName().toString().toLowerCase();
|
||||||
|
return name.endsWith(".tif") || name.endsWith(".geojson");
|
||||||
|
})
|
||||||
|
.count();
|
||||||
|
}
|
||||||
|
* */
|
||||||
|
}
|
||||||
|
|
||||||
|
public DatasetCounters countTmpAfterBuild(Path path) throws IOException {
|
||||||
|
|
||||||
|
// input1
|
||||||
|
int in1Train = countTif(path.resolve("train/input1"));
|
||||||
|
int in1Val = countTif(path.resolve("val/input1"));
|
||||||
|
int in1Test = countTif(path.resolve("test/input1"));
|
||||||
|
|
||||||
|
// input2
|
||||||
|
int in2Train = countTif(path.resolve("train/input2"));
|
||||||
|
int in2Val = countTif(path.resolve("val/input2"));
|
||||||
|
int in2Test = countTif(path.resolve("test/input2"));
|
||||||
|
|
||||||
|
List<DatasetCounter> input1List = new ArrayList<>();
|
||||||
|
List<DatasetCounter> input2List = new ArrayList<>();
|
||||||
|
|
||||||
|
input1List.add(new DatasetCounter(in1Train, in1Test, in1Val));
|
||||||
|
input2List.add(new DatasetCounter(in2Train, in2Test, in2Val));
|
||||||
|
|
||||||
|
return new DatasetCounters(input1List, input2List);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
public static class DatasetCounter {
|
||||||
|
private int inputTrain = 0;
|
||||||
|
private int inputTest = 0;
|
||||||
|
private int inputVal = 0;
|
||||||
|
|
||||||
|
public DatasetCounter(int inputTrain, int inputTest, int inputVal) {
|
||||||
|
this.inputTrain = inputTrain;
|
||||||
|
this.inputTest = inputTest;
|
||||||
|
this.inputVal = inputVal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
public static class DatasetCounters {
|
||||||
|
private List<DatasetCounter> input1 = new ArrayList<>();
|
||||||
|
private List<DatasetCounter> input2 = new ArrayList<>();
|
||||||
|
|
||||||
|
public DatasetCounters(List<DatasetCounter> input1, List<DatasetCounter> input2) {
|
||||||
|
this.input1 = input1;
|
||||||
|
this.input2 = input2;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void prints(String uuid, String type) {
|
||||||
|
int train = 0, test = 0, val = 0;
|
||||||
|
int train2 = 0, test2 = 0, val2 = 0;
|
||||||
|
|
||||||
|
for (DatasetCounter datasetCounter : input1) {
|
||||||
|
train += datasetCounter.inputTrain;
|
||||||
|
test += datasetCounter.inputTest;
|
||||||
|
val += datasetCounter.inputVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (DatasetCounter datasetCounter : input2) {
|
||||||
|
train2 += datasetCounter.inputTrain;
|
||||||
|
test2 += datasetCounter.inputTest;
|
||||||
|
val2 += datasetCounter.inputVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("======== UUID FOLDER COUNT {} : {}", type, uuid);
|
||||||
|
log.info("input 1 = train : {} | val : {} | test : {} ", train, val, test);
|
||||||
|
log.info("input 2 = train : {} | val : {} | test : {} ", train2, val2, test2);
|
||||||
|
log.info(
|
||||||
|
"*total* = train : {} | val : {} | test : {} ",
|
||||||
|
train + train2,
|
||||||
|
val + val2,
|
||||||
|
test + test2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.kamco.cd.training.train.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
@@ -21,8 +20,8 @@ public class TestJobService {
|
|||||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
private final DockerTrainService dockerTrainService;
|
private final DockerTrainService dockerTrainService;
|
||||||
private final ObjectMapper objectMapper;
|
|
||||||
private final ApplicationEventPublisher eventPublisher;
|
private final ApplicationEventPublisher eventPublisher;
|
||||||
|
private final DataSetCountersService dataSetCounters;
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||||
@@ -30,6 +29,9 @@ public class TestJobService {
|
|||||||
// 마스터 확인
|
// 마스터 확인
|
||||||
modelTrainMngCoreService.findModelById(modelId);
|
modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
|
||||||
|
// 폴더 카운트
|
||||||
|
dataSetCounters.getCount(modelId);
|
||||||
|
|
||||||
// best epoch 업데이트
|
// best epoch 업데이트
|
||||||
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
|
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.model.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.file.*;
|
import java.nio.file.*;
|
||||||
@@ -73,7 +73,6 @@ public class TmpDatasetService {
|
|||||||
for (String part : List.of("input1", "input2", "label", "label-json")) {
|
for (String part : List.of("input1", "input2", "label", "label-json")) {
|
||||||
|
|
||||||
Path srcDir = srcRoot.resolve(type).resolve(part);
|
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||||
|
|
||||||
if (!Files.isDirectory(srcDir)) {
|
if (!Files.isDirectory(srcDir)) {
|
||||||
log.warn("SKIP (not directory): {}", srcDir);
|
log.warn("SKIP (not directory): {}", srcDir);
|
||||||
noDir++;
|
noDir++;
|
||||||
@@ -3,7 +3,6 @@ package com.kamco.cd.training.train.service;
|
|||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
import com.kamco.cd.training.model.service.TmpDatasetService;
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
@@ -34,13 +33,13 @@ public class TrainJobService {
|
|||||||
private final ObjectMapper objectMapper;
|
private final ObjectMapper objectMapper;
|
||||||
private final ApplicationEventPublisher eventPublisher;
|
private final ApplicationEventPublisher eventPublisher;
|
||||||
private final TmpDatasetService tmpDatasetService;
|
private final TmpDatasetService tmpDatasetService;
|
||||||
|
private final DataSetCountersService dataSetCounters;
|
||||||
|
|
||||||
// 학습 결과가 저장될 호스트 디렉토리
|
// 학습 결과가 저장될 호스트 디렉토리
|
||||||
@Value("${train.docker.responseDir}")
|
@Value("${train.docker.responseDir}")
|
||||||
private String responseDir;
|
private String responseDir;
|
||||||
|
|
||||||
public Long getModelIdByUuid(UUID uuid) {
|
public Long getModelIdByUuid(UUID uuid) {
|
||||||
createTmpFile(uuid);
|
|
||||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,6 +47,9 @@ public class TrainJobService {
|
|||||||
@Transactional
|
@Transactional
|
||||||
public Long enqueue(Long modelId) {
|
public Long enqueue(Long modelId) {
|
||||||
|
|
||||||
|
// 폴더 카운트
|
||||||
|
dataSetCounters.getCount(modelId);
|
||||||
|
|
||||||
// 마스터 존재 확인(없으면 예외)
|
// 마스터 존재 확인(없으면 예외)
|
||||||
modelTrainMngCoreService.findModelById(modelId);
|
modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
|
||||||
@@ -235,7 +237,6 @@ public class TrainJobService {
|
|||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelUuid;
|
return modelUuid;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user