Merge pull request 'feat/training_260202' (#111) from feat/training_260202 into develop
Reviewed-on: #111
This commit was merged in pull request #111.
This commit is contained in:
@@ -6,6 +6,8 @@ import com.jcraft.jsch.ChannelExec;
|
||||
import com.jcraft.jsch.ChannelSftp;
|
||||
import com.jcraft.jsch.JSch;
|
||||
import com.jcraft.jsch.Session;
|
||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto.ApiResponseCode;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
@@ -39,6 +41,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.geotools.coverage.grid.GridCoverage2D;
|
||||
import org.geotools.gce.geotiff.GeoTiffReader;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.util.FileSystemUtils;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
@@ -757,6 +760,11 @@ public class FIleChecker {
|
||||
zipEntry = zis.getNextEntry();
|
||||
}
|
||||
zis.closeEntry();
|
||||
} catch (IOException e) {
|
||||
throw new CustomApiException(
|
||||
ApiResponseCode.INTERNAL_SERVER_ERROR.getId(),
|
||||
HttpStatus.INTERNAL_SERVER_ERROR,
|
||||
"압축 해제 중 오류가 발생했습니다: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,10 +26,12 @@ import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -56,6 +58,7 @@ public class DatasetService {
|
||||
private String datasetDir;
|
||||
|
||||
private static final List<String> LABEL_DIRS = List.of("label-json", "label", "input1", "input2");
|
||||
private static final List<String> REQUIRED_DIRS = Arrays.asList("train", "val", "test");
|
||||
|
||||
/**
|
||||
* 데이터셋 목록 조회
|
||||
@@ -164,44 +167,6 @@ public class DatasetService {
|
||||
}
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
@Transactional
|
||||
public ResponseObj insertDatasetTo86(@Valid AddReq addReq) {
|
||||
|
||||
Long datasetUid = null; // master id 값, 등록하면서 가져올 예정
|
||||
|
||||
// 압축 해제
|
||||
FIleChecker.unzipOn86Server(
|
||||
addReq.getFilePath() + addReq.getFileName(),
|
||||
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
|
||||
|
||||
// 해제한 폴더 읽어서 데이터 저장
|
||||
List<Map<String, Object>> list =
|
||||
getUnzipDatasetFilesTo86(
|
||||
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "train");
|
||||
|
||||
int idx = 0;
|
||||
for (Map<String, Object> map : list) {
|
||||
datasetUid =
|
||||
this.insertTrainTestData(map, addReq, idx, datasetUid, "train"); // train 데이터 insert
|
||||
idx++;
|
||||
}
|
||||
|
||||
List<Map<String, Object>> testList =
|
||||
getUnzipDatasetFilesTo86(
|
||||
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "test");
|
||||
|
||||
int testIdx = 0;
|
||||
for (Map<String, Object> test : testList) {
|
||||
datasetUid =
|
||||
this.insertTrainTestData(test, addReq, testIdx, datasetUid, "test"); // test 데이터 insert
|
||||
testIdx++;
|
||||
}
|
||||
|
||||
datasetCoreService.updateDatasetUploadStatus(datasetUid);
|
||||
return new ResponseObj(ApiResponseCode.OK, "업로드 성공하였습니다.");
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public ResponseObj insertDataset(@Valid AddReq addReq) {
|
||||
|
||||
@@ -218,6 +183,11 @@ public class DatasetService {
|
||||
// 압축 해제
|
||||
FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath());
|
||||
|
||||
// 압축 해제한 폴더 하위에 train,val,test 폴더 모두 존재하는지 확인
|
||||
validateTrainValTestDirs(addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
|
||||
|
||||
// TODO : 압축 해제한 폴더의 갯수 맞는지 log 찍기
|
||||
|
||||
// 해제한 폴더 읽어서 데이터 저장
|
||||
List<Map<String, Object>> list =
|
||||
getUnzipDatasetFiles(
|
||||
@@ -367,7 +337,10 @@ public class DatasetService {
|
||||
Path dir = root.resolve(dirName);
|
||||
|
||||
if (!Files.isDirectory(dir)) {
|
||||
throw new IllegalStateException("폴더가 존재하지 않습니다 : " + dir);
|
||||
throw new CustomApiException(
|
||||
ApiResponseCode.NOT_FOUND_DATA.getId(),
|
||||
HttpStatus.CONFLICT,
|
||||
"폴더가 존재하지 않습니다. 업로드 된 파일을 확인하세요. : " + dir);
|
||||
}
|
||||
|
||||
try (Stream<Path> stream = Files.list(dir)) {
|
||||
@@ -421,62 +394,6 @@ public class DatasetService {
|
||||
return datasetCoreService.getFilePathByUUIDPathType(uuid, pathType);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
private List<Map<String, Object>> getUnzipDatasetFilesTo86(String unzipRootPath, String subDir) {
|
||||
|
||||
// String root = Paths.get(unzipRootPath)
|
||||
// .resolve(subDir)
|
||||
// .toString();
|
||||
//
|
||||
String root = normalizeLinuxPath(unzipRootPath + "/" + subDir);
|
||||
|
||||
Map<String, Map<String, Object>> grouped = new HashMap<>();
|
||||
|
||||
for (String dirName : LABEL_DIRS) {
|
||||
|
||||
String remoteDir = root + "/" + dirName;
|
||||
|
||||
// 1. 86 서버에서 해당 디렉토리의 파일 목록 조회
|
||||
List<String> files = listFilesOn86Server(remoteDir);
|
||||
|
||||
if (files.isEmpty()) {
|
||||
throw new IllegalStateException("폴더가 존재하지 않거나 파일이 없습니다 : " + remoteDir);
|
||||
}
|
||||
|
||||
for (String fullPath : files) {
|
||||
|
||||
String fileName = Paths.get(fullPath).getFileName().toString();
|
||||
String baseName = getBaseName(fileName);
|
||||
|
||||
Map<String, Object> data = grouped.computeIfAbsent(baseName, k -> new HashMap<>());
|
||||
|
||||
data.put("baseName", baseName);
|
||||
|
||||
if ("label-json".equals(dirName)) {
|
||||
|
||||
// 2. json 내용도 86 서버에서 읽어서 가져와야 함
|
||||
String json = readRemoteFileAsString(fullPath);
|
||||
|
||||
data.put("label-json", parseJson(json));
|
||||
data.put("geojson_path", fullPath);
|
||||
|
||||
} else {
|
||||
|
||||
data.put(dirName, fullPath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new ArrayList<>(grouped.values());
|
||||
}
|
||||
|
||||
private List<String> listFilesOn86Server(String remoteDir) {
|
||||
|
||||
String command = "find " + escape(remoteDir) + " -maxdepth 1 -type f";
|
||||
|
||||
return FIleChecker.execCommandAndReadLines(command);
|
||||
}
|
||||
|
||||
private String readRemoteFileAsString(String remoteFilePath) {
|
||||
|
||||
String command = "cat " + escape(remoteFilePath);
|
||||
@@ -528,4 +445,32 @@ public class DatasetService {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/** unzipRootDir: 압축 해제된 폴더 경로 (ex: /data/xxx/myzipname) */
|
||||
public static void validateTrainValTestDirs(String unzipRootDir) {
|
||||
Path root = Paths.get(unzipRootDir);
|
||||
|
||||
// 루트 폴더 자체 존재 확인
|
||||
if (!Files.exists(root) || !Files.isDirectory(root)) {
|
||||
throw new CustomApiException(
|
||||
ApiResponseCode.NOT_FOUND_DATA.getId(),
|
||||
HttpStatus.CONFLICT,
|
||||
"압축 해제 폴더가 존재하지 않습니다: " + unzipRootDir);
|
||||
}
|
||||
|
||||
// 필요한 폴더들 존재/디렉토리 여부 확인
|
||||
List<String> missing =
|
||||
REQUIRED_DIRS.stream()
|
||||
.filter(d -> !Files.isDirectory(root.resolve(d)))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (!missing.isEmpty()) {
|
||||
throw new CustomApiException(
|
||||
ApiResponseCode.NOT_FOUND_DATA.getId(),
|
||||
HttpStatus.CONFLICT,
|
||||
"데이터 폴더 구조가 올바르지 않습니다. 누락된 폴더: "
|
||||
+ String.join(", ", missing)
|
||||
+ " (필수: train, val, test)");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +198,10 @@ public class ModelTrainMngCoreService {
|
||||
ModelConfigEntity entity = new ModelConfigEntity();
|
||||
modelMasterEntity.setId(modelId);
|
||||
entity.setModel(modelMasterEntity);
|
||||
entity.setEpochCount(req.getEpochCnt());
|
||||
entity.setEpochCount(
|
||||
req.getEpochCnt() < 10
|
||||
? 10
|
||||
: req.getEpochCnt()); // 에폭이 10 이하이면 10으로 고정하기. 10 이상 에폭으로 해야 best 에폭 파일이 생성되어 내려옴
|
||||
entity.setTrainPercent(req.getTrainingCnt());
|
||||
entity.setValidationPercent(req.getValidationCnt());
|
||||
entity.setTestPercent(req.getTestCnt());
|
||||
|
||||
@@ -70,7 +70,7 @@ public class ModelTrainMetricsJobService {
|
||||
|
||||
for (CSVRecord record : parser) {
|
||||
|
||||
int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기
|
||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
||||
long iteration = Long.parseLong(record.get("Iteration"));
|
||||
double Loss = Double.parseDouble(record.get("Loss"));
|
||||
double LR = Double.parseDouble(record.get("LR"));
|
||||
|
||||
Reference in New Issue
Block a user