diff --git a/src/main/java/com/kamco/cd/training/common/utils/FIleChecker.java b/src/main/java/com/kamco/cd/training/common/utils/FIleChecker.java index b9ba07b..8ec3be2 100644 --- a/src/main/java/com/kamco/cd/training/common/utils/FIleChecker.java +++ b/src/main/java/com/kamco/cd/training/common/utils/FIleChecker.java @@ -6,6 +6,8 @@ import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.ChannelSftp; import com.jcraft.jsch.JSch; import com.jcraft.jsch.Session; +import com.kamco.cd.training.common.exception.CustomApiException; +import com.kamco.cd.training.config.api.ApiResponseDto.ApiResponseCode; import io.swagger.v3.oas.annotations.media.Schema; import java.io.BufferedReader; import java.io.File; @@ -39,6 +41,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FilenameUtils; import org.geotools.coverage.grid.GridCoverage2D; import org.geotools.gce.geotiff.GeoTiffReader; +import org.springframework.http.HttpStatus; import org.springframework.util.FileSystemUtils; import org.springframework.web.multipart.MultipartFile; @@ -757,6 +760,11 @@ public class FIleChecker { zipEntry = zis.getNextEntry(); } zis.closeEntry(); + } catch (IOException e) { + throw new CustomApiException( + ApiResponseCode.INTERNAL_SERVER_ERROR.getId(), + HttpStatus.INTERNAL_SERVER_ERROR, + "압축 해제 중 오류가 발생했습니다: " + e.getMessage()); } } diff --git a/src/main/java/com/kamco/cd/training/dataset/service/DatasetService.java b/src/main/java/com/kamco/cd/training/dataset/service/DatasetService.java index 04f08c6..d9d56f2 100644 --- a/src/main/java/com/kamco/cd/training/dataset/service/DatasetService.java +++ b/src/main/java/com/kamco/cd/training/dataset/service/DatasetService.java @@ -26,10 +26,12 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.stream.Collectors; import java.util.stream.Stream; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -56,6 +58,7 @@ public class DatasetService { private String datasetDir; private static final List LABEL_DIRS = List.of("label-json", "label", "input1", "input2"); + private static final List REQUIRED_DIRS = Arrays.asList("train", "val", "test"); /** * 데이터셋 목록 조회 @@ -164,44 +167,6 @@ public class DatasetService { } } - @Deprecated - @Transactional - public ResponseObj insertDatasetTo86(@Valid AddReq addReq) { - - Long datasetUid = null; // master id 값, 등록하면서 가져올 예정 - - // 압축 해제 - FIleChecker.unzipOn86Server( - addReq.getFilePath() + addReq.getFileName(), - addReq.getFilePath() + addReq.getFileName().replace(".zip", "")); - - // 해제한 폴더 읽어서 데이터 저장 - List> list = - getUnzipDatasetFilesTo86( - addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "train"); - - int idx = 0; - for (Map map : list) { - datasetUid = - this.insertTrainTestData(map, addReq, idx, datasetUid, "train"); // train 데이터 insert - idx++; - } - - List> testList = - getUnzipDatasetFilesTo86( - addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "test"); - - int testIdx = 0; - for (Map test : testList) { - datasetUid = - this.insertTrainTestData(test, addReq, testIdx, datasetUid, "test"); // test 데이터 insert - testIdx++; - } - - datasetCoreService.updateDatasetUploadStatus(datasetUid); - return new ResponseObj(ApiResponseCode.OK, "업로드 성공하였습니다."); - } - @Transactional public ResponseObj insertDataset(@Valid AddReq addReq) { @@ -218,6 +183,11 @@ public class DatasetService { // 압축 해제 FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath()); + // 압축 해제한 폴더 하위에 train,val,test 폴더 모두 존재하는지 확인 + validateTrainValTestDirs(addReq.getFilePath() + addReq.getFileName().replace(".zip", "")); + + // TODO : 압축 해제한 폴더의 갯수 맞는지 log 찍기 + // 해제한 폴더 읽어서 데이터 저장 List> list = getUnzipDatasetFiles( @@ -367,7 +337,10 @@ public class DatasetService { Path dir = root.resolve(dirName); if (!Files.isDirectory(dir)) { - throw new IllegalStateException("폴더가 존재하지 않습니다 : " + dir); + throw new CustomApiException( + ApiResponseCode.NOT_FOUND_DATA.getId(), + HttpStatus.CONFLICT, + "폴더가 존재하지 않습니다. 업로드 된 파일을 확인하세요. : " + dir); } try (Stream stream = Files.list(dir)) { @@ -421,62 +394,6 @@ public class DatasetService { return datasetCoreService.getFilePathByUUIDPathType(uuid, pathType); } - @Deprecated - private List> getUnzipDatasetFilesTo86(String unzipRootPath, String subDir) { - - // String root = Paths.get(unzipRootPath) - // .resolve(subDir) - // .toString(); - // - String root = normalizeLinuxPath(unzipRootPath + "/" + subDir); - - Map> grouped = new HashMap<>(); - - for (String dirName : LABEL_DIRS) { - - String remoteDir = root + "/" + dirName; - - // 1. 86 서버에서 해당 디렉토리의 파일 목록 조회 - List files = listFilesOn86Server(remoteDir); - - if (files.isEmpty()) { - throw new IllegalStateException("폴더가 존재하지 않거나 파일이 없습니다 : " + remoteDir); - } - - for (String fullPath : files) { - - String fileName = Paths.get(fullPath).getFileName().toString(); - String baseName = getBaseName(fileName); - - Map data = grouped.computeIfAbsent(baseName, k -> new HashMap<>()); - - data.put("baseName", baseName); - - if ("label-json".equals(dirName)) { - - // 2. json 내용도 86 서버에서 읽어서 가져와야 함 - String json = readRemoteFileAsString(fullPath); - - data.put("label-json", parseJson(json)); - data.put("geojson_path", fullPath); - - } else { - - data.put(dirName, fullPath); - } - } - } - - return new ArrayList<>(grouped.values()); - } - - private List listFilesOn86Server(String remoteDir) { - - String command = "find " + escape(remoteDir) + " -maxdepth 1 -type f"; - - return FIleChecker.execCommandAndReadLines(command); - } - private String readRemoteFileAsString(String remoteFilePath) { String command = "cat " + escape(remoteFilePath); @@ -528,4 +445,32 @@ public class DatasetService { throw new RuntimeException(e); } } + + /** unzipRootDir: 압축 해제된 폴더 경로 (ex: /data/xxx/myzipname) */ + public static void validateTrainValTestDirs(String unzipRootDir) { + Path root = Paths.get(unzipRootDir); + + // 루트 폴더 자체 존재 확인 + if (!Files.exists(root) || !Files.isDirectory(root)) { + throw new CustomApiException( + ApiResponseCode.NOT_FOUND_DATA.getId(), + HttpStatus.CONFLICT, + "압축 해제 폴더가 존재하지 않습니다: " + unzipRootDir); + } + + // 필요한 폴더들 존재/디렉토리 여부 확인 + List missing = + REQUIRED_DIRS.stream() + .filter(d -> !Files.isDirectory(root.resolve(d))) + .collect(Collectors.toList()); + + if (!missing.isEmpty()) { + throw new CustomApiException( + ApiResponseCode.NOT_FOUND_DATA.getId(), + HttpStatus.CONFLICT, + "데이터 폴더 구조가 올바르지 않습니다. 누락된 폴더: " + + String.join(", ", missing) + + " (필수: train, val, test)"); + } + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index f9395a7..25f7906 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -198,7 +198,10 @@ public class ModelTrainMngCoreService { ModelConfigEntity entity = new ModelConfigEntity(); modelMasterEntity.setId(modelId); entity.setModel(modelMasterEntity); - entity.setEpochCount(req.getEpochCnt()); + entity.setEpochCount( + req.getEpochCnt() < 10 + ? 10 + : req.getEpochCnt()); // 에폭이 10 이하이면 10으로 고정하기. 10 이상 에폭으로 해야 best 에폭 파일이 생성되어 내려옴 entity.setTrainPercent(req.getTrainingCnt()); entity.setValidationPercent(req.getValidationCnt()); entity.setTestPercent(req.getTestCnt()); diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java index 4487973..933b9f1 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java @@ -70,7 +70,7 @@ public class ModelTrainMetricsJobService { for (CSVRecord record : parser) { - int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기 + int epoch = Integer.parseInt(record.get("Epoch")); long iteration = Long.parseLong(record.get("Iteration")); double Loss = Double.parseDouble(record.get("Loss")); double LR = Double.parseDouble(record.get("LR"));