diff --git a/src/main/java/com/kamco/cd/training/config/AsyncConfig.java b/src/main/java/com/kamco/cd/training/config/AsyncConfig.java index 197a37c..a75460e 100644 --- a/src/main/java/com/kamco/cd/training/config/AsyncConfig.java +++ b/src/main/java/com/kamco/cd/training/config/AsyncConfig.java @@ -20,4 +20,18 @@ public class AsyncConfig { executor.initialize(); return executor; } + + @Bean("datasetExecutor") + public Executor datasetExecutor() { + + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + + executor.setCorePoolSize(5); + executor.setMaxPoolSize(10); + executor.setQueueCapacity(100); + executor.setThreadNamePrefix("dataset-"); + + executor.initialize(); + return executor; + } } diff --git a/src/main/java/com/kamco/cd/training/dataset/DatasetApiController.java b/src/main/java/com/kamco/cd/training/dataset/DatasetApiController.java index c0db7b3..d5d7260 100644 --- a/src/main/java/com/kamco/cd/training/dataset/DatasetApiController.java +++ b/src/main/java/com/kamco/cd/training/dataset/DatasetApiController.java @@ -1,12 +1,12 @@ package com.kamco.cd.training.dataset; import com.kamco.cd.training.config.api.ApiResponseDto; -import com.kamco.cd.training.config.api.ApiResponseDto.ResponseObj; import com.kamco.cd.training.dataset.dto.DatasetDto; import com.kamco.cd.training.dataset.dto.DatasetDto.AddDeliveriesReq; import com.kamco.cd.training.dataset.dto.DatasetObjDto; import com.kamco.cd.training.dataset.dto.DatasetObjDto.DatasetClass; import com.kamco.cd.training.dataset.dto.DatasetObjDto.DatasetStorage; +import com.kamco.cd.training.dataset.service.DatasetAsyncService; import com.kamco.cd.training.dataset.service.DatasetService; import com.kamco.cd.training.model.dto.FileDto.FoldersDto; import com.kamco.cd.training.model.dto.FileDto.SrchFoldersDto; @@ -37,6 +37,7 @@ import org.springframework.web.bind.annotation.*; public class DatasetApiController { private final DatasetService datasetService; + private final DatasetAsyncService datasetAsyncService; @Operation(summary = "학습데이터 관리 목록 조회", description = "학습데이터 목록을 조회합니다.") @ApiResponses( @@ -262,7 +263,7 @@ public class DatasetApiController { content = @Content( mediaType = "application/json", - schema = @Schema(implementation = Page.class))), + schema = @Schema(implementation = FoldersDto.class))), @ApiResponse(responseCode = "404", description = "조회 오류", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @@ -281,12 +282,13 @@ public class DatasetApiController { content = @Content( mediaType = "application/json", - schema = @Schema(implementation = Page.class))), + schema = @Schema(implementation = String.class))), @ApiResponse(responseCode = "404", description = "조회 오류", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @PostMapping("/deliveries") - public ApiResponseDto insertDeliveriesDataset(@RequestBody AddDeliveriesReq req) { - return ApiResponseDto.createOK(datasetService.insertDeliveriesDataset(req)); + public ApiResponseDto insertDeliveriesDataset(@RequestBody AddDeliveriesReq req) { + datasetAsyncService.insertDeliveriesDatasetAsync(req); + return ApiResponseDto.createOK("ok"); } } diff --git a/src/main/java/com/kamco/cd/training/dataset/service/DatasetAsyncService.java b/src/main/java/com/kamco/cd/training/dataset/service/DatasetAsyncService.java new file mode 100644 index 0000000..379f744 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/dataset/service/DatasetAsyncService.java @@ -0,0 +1,134 @@ +package com.kamco.cd.training.dataset.service; + +import com.kamco.cd.training.common.enums.LearnDataRegister; +import com.kamco.cd.training.dataset.dto.DatasetDto.AddDeliveriesReq; +import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto; +import com.kamco.cd.training.postgres.core.DatasetCoreService; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; + +@Service +@Log4j2 +@RequiredArgsConstructor +public class DatasetAsyncService { + + private final DatasetService datasetService; + private final DatasetCoreService datasetCoreService; + + private static final String LOG_PREFIX = "[납품 데이터셋]"; + + /** + * 납품 데이터셋 등록 비동기 업로드 처리 1) 데이터셋 구조/파일 검증 2) UID 생성 및 마스터 데이터 저장 3) 상태를 UPLOADING으로 변경 4) 실제 + * 데이터(train/val/test) 등록 5) 완료 시 COMPLETED 상태로 변경 6) 실패 시 상태를 UPLOAD_FAILED로 변경 후 데이터 정리(삭제) + * + * @param req + */ + @Async("datasetExecutor") + public void insertDeliveriesDatasetAsync(AddDeliveriesReq req) { + + long startTime = System.currentTimeMillis(); + + log.info("{} 업로드 시작 ==========", LOG_PREFIX); + log.info( + "{} filePath={}, targetYyyy={}, compareYyyy={}, roundNo={}", + LOG_PREFIX, + req.getFilePath(), + req.getTargetYyyy(), + req.getCompareYyyy(), + req.getRoundNo()); + + Long datasetUid = null; + + try { + + // ===== 1. 폴더/파일 검증 ===== + long validateStart = System.currentTimeMillis(); + + // 폴더 구조 검증 + DatasetService.validateTrainValTestDirs(req.getFilePath()); + // 파일 개수 검증 + DatasetService.validateDirFileCount(req.getFilePath()); + + log.info("{} 데이터셋 검증 완료. ({} ms)", LOG_PREFIX, System.currentTimeMillis() - validateStart); + + // ===== 2. UID 생성 ===== + String uid = UUID.randomUUID().toString().replace("-", "").toUpperCase(); + log.info("{} 생성된 UID: {}", LOG_PREFIX, uid); + + // ===== 3. 마스터 데이터 생성 ===== + String title = req.getTitle(); + + if (title == null || title.isBlank()) { + Integer compareYyyy = req.getCompareYyyy(); + Integer targetYyyy = req.getTargetYyyy(); + + if (compareYyyy != null && targetYyyy != null) { + title = compareYyyy + "-" + targetYyyy; + } else { + title = null; + } + } + + DatasetMngRegDto datasetMngRegDto = new DatasetMngRegDto(); + datasetMngRegDto.setUid(uid); + datasetMngRegDto.setDataType("DELIVER"); + datasetMngRegDto.setCompareYyyy(req.getCompareYyyy() == null ? 0 : req.getCompareYyyy()); + datasetMngRegDto.setTargetYyyy(req.getTargetYyyy() == null ? 0 : req.getTargetYyyy()); + datasetMngRegDto.setRoundNo(req.getRoundNo()); + datasetMngRegDto.setTitle(title); + datasetMngRegDto.setMemo(req.getMemo()); + datasetMngRegDto.setDatasetPath(req.getFilePath()); + + // 마스터 저장 + datasetUid = datasetCoreService.insertDatasetMngData(datasetMngRegDto); + + log.info("{} 마스터 저장 완료. datasetUid={}", LOG_PREFIX, datasetUid); + + // ===== 4. 상태 변경 (업로드중) ===== + datasetCoreService.updateDatasetUploadStatus(datasetUid, LearnDataRegister.UPLOADING); + log.info("{} 상태 변경 → UPLOADING. datasetUid={}", LOG_PREFIX, datasetUid); + + // ===== 5. 데이터 등록 ===== + long insertStart = System.currentTimeMillis(); + + // 납품 데이터 obj 등록 + datasetService.insertDeliveriesDataset(req, datasetUid); + + log.info( + "{} 데이터 등록 완료. datasetUid={}, 소요시간={} ms", + LOG_PREFIX, + datasetUid, + System.currentTimeMillis() - insertStart); + + // ===== 6. 상태 변경 (완료) ===== + datasetCoreService.updateDatasetUploadStatus(datasetUid, LearnDataRegister.COMPLETED); + log.info("{} 상태 변경 → COMPLETED. datasetUid={}", LOG_PREFIX, datasetUid); + + log.info( + "{} 업로드 완료. 총 소요시간={} ms ==========", LOG_PREFIX, System.currentTimeMillis() - startTime); + + } catch (Exception e) { + + log.error( + "{} 업로드 실패. datasetUid={}, filePath={}", LOG_PREFIX, datasetUid, req.getFilePath(), e); + + if (datasetUid != null) { + try { + // ===== 실패 처리 ===== + datasetCoreService.updateDatasetUploadStatus(datasetUid, LearnDataRegister.UPLOAD_FAILED); + log.error("{} 상태 변경 → 업로드 실패. datasetUid={}", LOG_PREFIX, datasetUid); + + // 실패 시 데이터 정리 + datasetCoreService.deleteAllDatasetObj(datasetUid); + log.error("{} 데이터 정리 완료. datasetUid={}", LOG_PREFIX, datasetUid); + + } catch (Exception ex) { + log.error("{} 실패 후 정리 작업 중 오류. datasetUid={}", LOG_PREFIX, datasetUid, ex); + } + } + } + } +} diff --git a/src/main/java/com/kamco/cd/training/dataset/service/DatasetBatchService.java b/src/main/java/com/kamco/cd/training/dataset/service/DatasetBatchService.java new file mode 100644 index 0000000..4c396fd --- /dev/null +++ b/src/main/java/com/kamco/cd/training/dataset/service/DatasetBatchService.java @@ -0,0 +1,152 @@ +package com.kamco.cd.training.dataset.service; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.kamco.cd.training.dataset.dto.DatasetObjDto.DatasetObjRegDto; +import com.kamco.cd.training.postgres.core.DatasetCoreService; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@Log4j2 +@RequiredArgsConstructor +public class DatasetBatchService { + + private final DatasetCoreService datasetCoreService; + private final ObjectMapper mapper; + + /** + * 배치 단위 데이터 저장 + * + *

- 전달받은 데이터 목록을 순회하며 개별 insert 처리 - batch 단위로 트랜잭션 관리 + */ + @Transactional + public void saveBatch(List> batch, Long datasetUid, String type) { + for (Map map : batch) { + insertTrainTestData(map, datasetUid, type); + } + } + + /** + * 단일 데이터 처리 및 insert DTO 생성 + * + *

처리 흐름: 1) 경로/JSON 데이터 추출 2) 파일명에서 연도 및 도엽번호 파싱 3) label JSON → feature 단위 분리 4) feature별 DTO + * 생성 후 DB insert + */ + private void insertTrainTestData(Map map, Long datasetUid, String subDir) { + + String comparePath = (String) map.get("input1"); + String targetPath = (String) map.get("input2"); + String labelPath = (String) map.get("label"); + String geojsonPath = (String) map.get("geojson_path"); + Object labelJson = map.get("label-json"); + + // JSON 파싱 + JsonNode json = parseJson(labelJson); + + // 파일명 파싱 + String fileName = Paths.get(comparePath).getFileName().toString(); + String[] fileNameStr = fileName.split("_"); + + if (fileNameStr.length < 4) { + log.error("파일명 파싱 실패: {}", fileName); + throw new IllegalArgumentException("잘못된 파일명 형식: " + fileName); + } + + int compareYyyy = parseInt(fileNameStr[1], "compareYyyy", fileName); + int targetYyyy = parseInt(fileNameStr[2], "targetYyyy", fileName); + String mapSheetNum = fileNameStr[3]; + + // JSON 유효성 체크 + JsonNode featuresNode = json.path("features"); + + if (featuresNode.isMissingNode() || !featuresNode.isArray() || featuresNode.isEmpty()) { + return; // skip + } + + ObjectNode base = mapper.createObjectNode(); + base.put("type", "FeatureCollection"); + + for (JsonNode feature : featuresNode) { + + JsonNode prop = feature.path("properties"); + + String compareClassCd = prop.path("before").asText(null); + String targetClassCd = prop.path("after").asText(null); + + ArrayNode arr = mapper.createArrayNode(); + arr.add(feature); + + ObjectNode root = base.deepCopy(); + root.set("features", arr); + + DatasetObjRegDto objRegDto = + DatasetObjRegDto.builder() + .datasetUid(datasetUid) + .compareYyyy(compareYyyy) + .compareClassCd(compareClassCd) + .targetYyyy(targetYyyy) + .targetClassCd(targetClassCd) + .comparePath(comparePath) + .targetPath(targetPath) + .labelPath(labelPath) + .mapSheetNum(mapSheetNum) + .geojson(root) + .geojsonPath(geojsonPath) + .fileName(fileName) + .build(); + + // 데이터 타입별 insert + insertByType(subDir, objRegDto); + } + } + + /** 데이터 타입별 insert 처리 - type 값에 따라 대상 테이블 분기 - 잘못된 타입 입력 시 예외 발생 */ + private void insertByType(String type, DatasetObjRegDto dto) { + + switch (type) { + case "train" -> datasetCoreService.insertDatasetObj(dto); + case "val" -> datasetCoreService.insertDatasetValObj(dto); + case "test" -> datasetCoreService.insertDatasetTestObj(dto); + default -> throw new IllegalArgumentException("잘못된 타입: " + type); + } + } + + /** + * label_json → JsonNode 변환 + * + *

- JsonNode면 그대로 사용 - 문자열이면 파싱 수행 - 실패 시 로그 후 예외 발생 + */ + private JsonNode parseJson(Object labelJson) { + try { + if (labelJson instanceof JsonNode jn) { + return jn; + } + return mapper.readTree(labelJson.toString()); + } catch (Exception e) { + log.error("label_json parse error: {}", labelJson, e); + throw new RuntimeException("label_json parse error", e); + } + } + + /** + * 문자열 → 정수 변환 + * + *

- 파싱 실패 시 어떤 필드/파일에서 발생했는지 로그 기록 - 잘못된 데이터는 즉시 예외 처리 + */ + private int parseInt(String value, String field, String fileName) { + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + log.error("{} 파싱 실패. fileName={}, value={}", field, fileName, value); + throw new IllegalArgumentException(field + " 파싱 실패: " + fileName); + } + } +} diff --git a/src/main/java/com/kamco/cd/training/dataset/service/DatasetService.java b/src/main/java/com/kamco/cd/training/dataset/service/DatasetService.java index 8370084..af6dd05 100644 --- a/src/main/java/com/kamco/cd/training/dataset/service/DatasetService.java +++ b/src/main/java/com/kamco/cd/training/dataset/service/DatasetService.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.kamco.cd.training.common.enums.LearnDataRegister; import com.kamco.cd.training.common.enums.LearnDataType; import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.common.service.FormatStorage; @@ -22,8 +23,6 @@ import com.kamco.cd.training.dataset.dto.DatasetObjDto.SearchReq; import com.kamco.cd.training.model.dto.FileDto.FoldersDto; import com.kamco.cd.training.model.dto.FileDto.SrchFoldersDto; import com.kamco.cd.training.postgres.core.DatasetCoreService; -import jakarta.persistence.EntityManager; -import jakarta.persistence.PersistenceContext; import jakarta.validation.Valid; import java.io.File; import java.io.IOException; @@ -45,7 +44,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.Resource; import org.springframework.data.domain.Page; @@ -62,11 +60,7 @@ import org.springframework.transaction.annotation.Transactional; public class DatasetService { private final DatasetCoreService datasetCoreService; - @PersistenceContext private EntityManager em; - private final ObjectMapper mapper = new ObjectMapper(); - - @Value("${file.dataset-dir}") - private String datasetDir; + private final DatasetBatchService datasetBatchService; private static final List LABEL_DIRS = List.of("label-json", "label", "input1", "input2"); private static final List REQUIRED_DIRS = Arrays.asList("train", "val", "test"); @@ -143,7 +137,6 @@ public class DatasetService { return datasetCoreService.searchDatasetObjectList(searchReq); } - @Transactional public UUID deleteDatasetObjByUuid(UUID uuid) { return datasetCoreService.deleteDatasetObjByUuid(uuid); } @@ -241,7 +234,7 @@ public class DatasetService { return new ResponseObj(ApiResponseCode.INTERNAL_SERVER_ERROR, e.getMessage()); } - datasetCoreService.updateDatasetUploadStatus(datasetUid); + datasetCoreService.updateDatasetUploadStatus(datasetUid, LearnDataRegister.COMPLETED); return new ResponseObj(ApiResponseCode.OK, "업로드 성공하였습니다."); } @@ -616,146 +609,65 @@ public class DatasetService { * 납품 데이터 등록 * * @param req 폴더경로, 메모 - * @return 성공/실패 여부 + * @return 성공/실패 여부0 */ - public ResponseObj insertDeliveriesDataset(AddDeliveriesReq req) { + public void insertDeliveriesDataset(AddDeliveriesReq req, Long datasetUid) { long startTime = System.currentTimeMillis(); - log.info("========== 납품 데이터셋 업로드 시작 =========="); - log.info("filePath: {}", req.getFilePath()); - - DatasetMngRegDto datasetMngRegDto = new DatasetMngRegDto(); - - String uid = UUID.randomUUID().toString().replace("-", "").toUpperCase(); - - datasetMngRegDto.setUid(uid); - datasetMngRegDto.setDataType("DELIVER"); - datasetMngRegDto.setCompareYyyy(req.getCompareYyyy() == null ? 0 : req.getCompareYyyy()); - datasetMngRegDto.setTargetYyyy(req.getTargetYyyy() == null ? 0 : req.getTargetYyyy()); - datasetMngRegDto.setRoundNo(req.getRoundNo() == null ? null : req.getRoundNo()); - datasetMngRegDto.setTitle(req.getTitle() == null ? null : req.getTitle()); - datasetMngRegDto.setMemo(req.getMemo() == null ? null : req.getMemo()); - datasetMngRegDto.setDatasetPath(req.getFilePath()); - - // 마스터 저장 (트랜잭션 내부) - Long datasetUid = datasetCoreService.insertDatasetMngData(datasetMngRegDto); - - log.info("납품 Dataset 마스터 저장 완료. datasetUid: {}", datasetUid); - - // 검증 - validateTrainValTestDirs(req.getFilePath()); - validateDirFileCount(req.getFilePath()); - // 처리 processType(req.getFilePath(), datasetUid, "train"); processType(req.getFilePath(), datasetUid, "val"); processType(req.getFilePath(), datasetUid, "test"); - datasetCoreService.updateDatasetUploadStatus(datasetUid); - log.info("========== 전체 완료. 총 소요시간: {} ms ==========", System.currentTimeMillis() - startTime); - - return new ResponseObj(ApiResponseCode.OK, "업로드 성공하였습니다."); } + /** + * 납품 데이터 등록 처리 + * + * @param path + * @param datasetUid + * @param type + */ private void processType(String path, Long datasetUid, String type) { - long start = System.currentTimeMillis(); - log.info("[{}] 데이터 처리 시작", type.toUpperCase()); + log.info("[납품 데이터 등록 처리][{}] 시작", type.toUpperCase()); List> list = getUnzipDatasetFiles(path, type); - log.info("[{}] 파일 개수: {}", type.toUpperCase(), list.size()); + int batchSize = 1000; + int total = list.size(); + int processed = 0; - int count = 0; + for (int i = 0; i < total; i += batchSize) { - for (Map map : list) { - insertTrainTestData(map, datasetUid, type); - count++; + List> batch = list.subList(i, Math.min(i + batchSize, total)); - if (count % 1000 == 0 || count == list.size()) { - log.info("[{}] 진행건수: {}", type.toUpperCase(), count); + try { + log.info("[납품 데이터 등록 처리][{}] batch 시작: {} ~ {}", type, i, i + batch.size()); + + datasetBatchService.saveBatch(batch, datasetUid, type); + + processed += batch.size(); + } catch (Exception e) { + log.error("batch 실패 row 데이터: {}", batch); + log.error( + "[납품 데이터 등록 처리][{}] batch 실패. range: {} ~ {}, datasetUid={}", + type, + i, + i + batch.size(), + datasetUid, + e); + throw e; } } log.info( - "[{}] 완료. 총 {}건, 소요시간: {} ms", - type.toUpperCase(), - count, + "[납품 데이터 등록 처리][{}] 완료. 총 {}건, 소요시간: {} ms", + type, + total, System.currentTimeMillis() - start); } - - @Transactional - public void insertTrainTestData(Map map, Long datasetUid, String subDir) { - - String comparePath = (String) map.get("input1"); - String targetPath = (String) map.get("input2"); - String labelPath = (String) map.get("label"); - String geojsonPath = (String) map.get("geojson_path"); - Object labelJson = map.get("label-json"); - - JsonNode json; - - try { - if (labelJson instanceof JsonNode jn) { - json = jn; - } else { - json = mapper.readTree(labelJson.toString()); - } - } catch (Exception e) { - throw new RuntimeException("label_json parse error", e); - } - - String fileName = Paths.get(comparePath).getFileName().toString(); - String[] fileNameStr = fileName.split("_"); - - String compareYyyy = fileNameStr[1]; - String targetYyyy = fileNameStr[2]; - String mapSheetNum = fileNameStr[3]; - - if (json != null && json.path("features") != null && !json.path("features").isEmpty()) { - - for (JsonNode feature : json.path("features")) { - - JsonNode prop = feature.path("properties"); - - String compareClassCd = prop.path("before").asText(null); - String targetClassCd = prop.path("after").asText(null); - - ObjectNode root = mapper.createObjectNode(); - root.put("type", "FeatureCollection"); - - ArrayNode features = mapper.createArrayNode(); - features.add(feature); - root.set("features", features); - - DatasetObjRegDto objRegDto = - DatasetObjRegDto.builder() - .datasetUid(datasetUid) - .compareYyyy(Integer.parseInt(compareYyyy)) - .compareClassCd(compareClassCd) - .targetYyyy(Integer.parseInt(targetYyyy)) - .targetClassCd(targetClassCd) - .comparePath(comparePath) - .targetPath(targetPath) - .labelPath(labelPath) - .mapSheetNum(mapSheetNum) - .geojson(root) - .geojsonPath(geojsonPath) - .fileName(fileName) - .build(); - - // insert - if (subDir.equals("train")) { - datasetCoreService.insertDatasetObj(objRegDto); - } else if (subDir.equals("val")) { - datasetCoreService.insertDatasetValObj(objRegDto); - } else { - datasetCoreService.insertDatasetTestObj(objRegDto); - } - } - } - } } diff --git a/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java b/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java index c01a0d4..330905e 100644 --- a/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java +++ b/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java @@ -330,8 +330,27 @@ public class ModelTrainDetailApiController { } @Operation( - summary = "모델관리 > 모델 상세 > best epoch 제외 삭제", - description = "best epoch 제외 pth 파일 삭제 API") + summary = "모델관리 > 모델 상세 > best epoch 제외 삭제 될 파일 미리보기", + description = "best epoch 제외 삭제 될 파일 미리보기 API") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "조회 성공", + content = + @Content( + mediaType = "application/json", + schema = @Schema(implementation = CleanupResult.class))), + @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content), + @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) + }) + @GetMapping("/{uuid}/cleanup/preview") + public ApiResponseDto previewCleanup( + @Parameter(description = "모델 uuid") @PathVariable UUID uuid) { + return ApiResponseDto.ok(modelTrainDetailService.previewCleanup(uuid)); + } + + @Operation(summary = "모델관리 > 모델 상세 > best epoch 제외 삭제", description = "best epoch 제외 파일 삭제 API") @ApiResponses( value = { @ApiResponse( diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java index a344332..42bc485 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java @@ -346,5 +346,8 @@ public class ModelTrainMngDto { // 유지된 파일명 (best epoch 기준) private String keptFile; + + // 삭제 될 파일 + private List deleteTargets; } } diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java index ed4dbe0..96aa777 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.model.service; import com.kamco.cd.training.common.enums.ModelType; +import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet; @@ -22,10 +23,17 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto; import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import java.io.IOException; +import java.nio.file.AccessDeniedException; +import java.nio.file.FileVisitOption; +import java.nio.file.FileVisitResult; import java.nio.file.Files; +import java.nio.file.LinkOption; import java.nio.file.Path; import java.nio.file.Paths; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.attribute.BasicFileAttributes; import java.util.ArrayList; +import java.util.EnumSet; import java.util.List; import java.util.UUID; import java.util.stream.Stream; @@ -167,15 +175,16 @@ public class ModelTrainDetailService { } /** - * 베스트 에폭 제외 *.pth 파일 삭제 + * 삭제될 파일목록 및 유지될 파일 목록 * - * @param uuid 학습 모델 uuid + * @param uuid + * @return */ - public CleanupResult cleanup(UUID uuid) { + public CleanupResult previewCleanup(UUID uuid) { CleanupResult result = new CleanupResult(); - // 학습 정보 조회 + // ===== 모델 조회 ===== ModelTrainMngDto.Basic model = modelTrainDetailCoreService.findByModelByUUID(uuid); if (model == null) { @@ -183,11 +192,98 @@ public class ModelTrainDetailService { "NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "모델을 찾을 수 없습니다. UUID: " + uuid); } - // 학습 결과 폴더 경로 - Path dir = Paths.get(responseDir, model.getUuid().toString()); + Path dir = Paths.get(responseDir, model.getUuid().toString()).toAbsolutePath().normalize(); + + if (!Files.exists(dir) || !Files.isDirectory(dir)) { + throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "디렉토리가 없습니다."); + } + + if (!Files.isReadable(dir)) { + throw new CustomApiException("FORBIDDEN", HttpStatus.FORBIDDEN, "디렉토리 읽기 권한이 없습니다."); + } + + try (Stream stream = Files.list(dir)) { + + List files = stream.toList(); + + if (files.isEmpty()) { + throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "파일이 없습니다."); + } + + // ===== keep 파일 찾기 ===== + Path keep = + files.stream() + .filter( + p -> { + String name = p.getFileName().toString(); + return name.endsWith(".zip") && name.contains(model.getUuid().toString()); + }) + .findFirst() + .orElseThrow( + () -> + new CustomApiException( + "NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "zip 파일이 없습니다.")); + + log.info("유지 파일: {}", keep.getFileName()); + + // ===== 결과 세팅 ===== + result.setTotalCount(files.size()); + result.setKeptFile(keep.getFileName().toString()); + + // ===== 삭제 대상 ===== + List deleteTargets = + files.stream() + .filter( + p -> !p.toAbsolutePath().normalize().equals(keep.toAbsolutePath().normalize())) + .map(p -> p.getFileName().toString()) + .toList(); + + result.setDeleteTargets(deleteTargets); + + log.info( + "previewCleanup 완료. total={}, deleteTargets={}", + result.getTotalCount(), + deleteTargets.size()); + + return result; + + } catch (IOException e) { + log.error("파일 목록 조회 실패: {}", dir, e); + throw new CustomApiException( + "INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "파일 목록 조회 실패"); + } + } + + public CleanupResult cleanup(UUID uuid) { + // ===== 모델 조회 ===== + ModelTrainMngDto.Basic model = modelTrainDetailCoreService.findByModelByUUID(uuid); + + if (model == null) { + throw new CustomApiException( + "NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "모델을 찾을 수 없습니다. UUID: " + uuid); + } + + if (!TrainStatusType.COMPLETED.getId().equals(model.getStep2Status())) { + throw new CustomApiException("CONFLICT", HttpStatus.CONFLICT, "테스트가 완료되지 않았습니다."); + } + + // ===== 경로 ===== + Path dir = Paths.get(responseDir, model.getUuid().toString()).toAbsolutePath().normalize(); + + return executeCleanup(model, dir); + } + + /** + * 베스트 에폭 제외 파일 삭제, 베스트 에폭 zip 파일만 남김 + * + * @param model model 정보 + * @param dir response 폴더 경로 + * @return 삭제 정보 + */ + public CleanupResult executeCleanup(ModelTrainMngDto.Basic model, Path dir) { + CleanupResult result = new CleanupResult(); if (!Files.exists(dir) || !Files.isDirectory(dir)) { - log.info("디렉토리가 없습니다.: {}", dir); throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "디렉토리가 없습니다."); } @@ -199,7 +295,6 @@ public class ModelTrainDetailService { throw new CustomApiException("FORBIDDEN", HttpStatus.FORBIDDEN, "디렉토리 삭제 권한이 없습니다."); } - // 저장된 best epoch int bestEpoch = model.getBestEpoch(); if (bestEpoch <= 0) { @@ -211,74 +306,74 @@ public class ModelTrainDetailService { try (Stream stream = Files.list(dir)) { - List pthFiles = - stream.filter(p -> p.getFileName().toString().endsWith(".pth")).toList(); + List files = stream.toList(); - if (pthFiles.isEmpty()) { - log.info("pth 파일이 없습니다."); - throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "pth 파일이 없습니다."); + if (files.isEmpty()) { + throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "파일이 없습니다."); } // ===== keep 파일 찾기 ===== Path keep = null; - // 우선순위 1: best_*_epoch_{bestEpoch}.pth - for (Path p : pthFiles) { + for (Path p : files) { String name = p.getFileName().toString(); - if (name.startsWith("best_") && name.contains("epoch_" + bestEpoch + ".pth")) { + if (name.endsWith(".zip") && name.contains(model.getUuid().toString())) { keep = p; break; } } - // 우선순위 2: epoch_{bestEpoch}.pth if (keep == null) { - for (Path p : pthFiles) { - if (p.getFileName().toString().equals("epoch_" + bestEpoch + ".pth")) { - keep = p; - break; - } - } + throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "zip 파일이 없습니다."); } - if (keep == null) { - log.info("bestEpoch에 해당하는 파일이 없습니다. epoch={}", bestEpoch); - throw new CustomApiException( - "NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "bestEpoch에 해당하는 파일이 없습니다. : " + bestEpoch); - } + log.info("유지 파일: {}", keep.getFileName()); - log.info("유지할 파일: {}", keep.getFileName()); - - // ===== 결과 세팅 ===== - result.setTotalCount(pthFiles.size()); + result.setTotalCount(files.size()); result.setKeptFile(keep.getFileName().toString()); int deletedCount = 0; List failed = new ArrayList<>(); - // ===== 삭제 처리 ===== - for (Path p : pthFiles) { + // ===== 삭제 ===== + for (Path p : files) { - if (!p.toAbsolutePath().normalize().equals(keep.toAbsolutePath().normalize())) { + if (p.equals(keep)) { + continue; + } - try { - boolean deleted = Files.deleteIfExists(p); + try { - if (deleted) { - deletedCount++; - log.info("삭제됨: {}", p.getFileName()); - } else { - log.info("이미 없음 (skip): {}", p.getFileName()); - } - - } catch (IOException e) { - failed.add(p.getFileName().toString()); - log.error("삭제 실패: {}", p.getFileName(), e); + // 심볼릭 링크 → 링크만 삭제 + if (Files.isSymbolicLink(p)) { + Files.deleteIfExists(p); + log.info("심볼릭 링크 삭제: {}", p.getFileName()); } + + // 디렉토리 → 재귀 삭제 + else if (Files.isDirectory(p, LinkOption.NOFOLLOW_LINKS)) { + log.info("디렉토리 재귀 삭제: {}", p.getFileName()); + deleteDirectory(p); + } + + // 일반 파일 + else { + Files.deleteIfExists(p); + log.info("파일 삭제: {}", p.getFileName()); + } + + deletedCount++; + + } catch (AccessDeniedException e) { + failed.add(p.getFileName().toString()); + log.error("권한 없음: {}", p.getFileName(), e); + + } catch (IOException e) { + failed.add(p.getFileName().toString()); + log.error("삭제 실패: {}", p.getFileName(), e); } } - // ===== 결과 저장 ===== result.setDeletedCount(deletedCount); result.setFailedCount(failed.size()); result.setFailedFiles(failed); @@ -297,4 +392,43 @@ public class ModelTrainDetailService { return result; } + + // 디렉토리 재귀 삭제 + private void deleteDirectory(Path dir) throws IOException { + + if (!Files.exists(dir, LinkOption.NOFOLLOW_LINKS)) { + return; + } + + // dir 자체가 심볼릭 링크면 링크만 삭제 + if (Files.isSymbolicLink(dir)) { + Files.delete(dir); + return; + } + + Files.walkFileTree( + dir, + EnumSet.noneOf(FileVisitOption.class), // NOFOLLOW_LINKS + Integer.MAX_VALUE, + new SimpleFileVisitor<>() { + + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) + throws IOException { + + Files.delete(file); // 링크면 링크만 삭제됨 + return FileVisitResult.CONTINUE; + } + + @Override + public FileVisitResult postVisitDirectory(Path directory, IOException exc) + throws IOException { + + if (exc != null) throw exc; + + Files.delete(directory); + return FileVisitResult.CONTINUE; + } + }); + } } diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java index e3890a1..dbd3613 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java @@ -3,6 +3,7 @@ package com.kamco.cd.training.model.service; import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.enums.HyperParamSelectType; import com.kamco.cd.training.common.enums.ModelType; +import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainType; import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; @@ -14,10 +15,21 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq; import com.kamco.cd.training.postgres.core.HyperParamCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.train.service.TrainJobService; +import java.io.IOException; +import java.nio.file.FileVisitOption; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.LinkOption; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.attribute.BasicFileAttributes; +import java.util.EnumSet; import java.util.List; import java.util.UUID; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; import org.springframework.data.domain.Page; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; @@ -32,6 +44,13 @@ public class ModelTrainMngService { private final ModelTrainMngCoreService modelTrainMngCoreService; private final HyperParamCoreService hyperParamCoreService; private final TrainJobService trainJobService; + private final ModelTrainDetailService modelTrainDetailService; + + @Value("${train.docker.basePath}") + private String basePath; + + @Value("${train.docker.responseDir}") + private String responseDir; /** * 모델학습 조회 @@ -46,11 +65,199 @@ public class ModelTrainMngService { /** * 모델학습 삭제 * - * @param uuid + *

순서: 1. tmp 구조 검증 (예외 발생 가능) 2. DB 삭제 (트랜잭션) 3. 파일 삭제 (실패해도 로그만) */ @Transactional public void deleteModelTrain(UUID uuid) { + + log.info("deleteModelTrain 시작. uuid={}", uuid); + + // ===== 1. 모델 조회 ===== + ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelByUuid(uuid); + + if (model == null) { + throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델 없음"); + } + + // ===== 2. 경로 생성 ===== + Path tmpBase = Path.of(basePath, "tmp").toAbsolutePath().normalize(); + Path tmp = tmpBase.resolve(model.getRequestPath()).normalize(); + + Path responseBase = Paths.get(responseDir).toAbsolutePath().normalize(); + Path response = responseBase.resolve(model.getUuid().toString()).normalize(); + + // ===== 3. 경로 탈출 방지 ===== + if (!tmp.startsWith(tmpBase)) { + throw new CustomApiException("INVALID_PATH", HttpStatus.BAD_REQUEST, "잘못된 tmp 경로"); + } + + if (!response.startsWith(responseBase)) { + throw new CustomApiException("INVALID_PATH", HttpStatus.BAD_REQUEST, "잘못된 response 경로"); + } + + // ===== 4. 상태 로그 ===== + log.info( + "tmp 상태: exists={}, isDir={}, isSymlink={}", + Files.exists(tmp, LinkOption.NOFOLLOW_LINKS), + Files.isDirectory(tmp, LinkOption.NOFOLLOW_LINKS), + Files.isSymbolicLink(tmp)); + + log.info( + "response 상태: exists={}, isDir={}, isSymlink={}", + Files.exists(response, LinkOption.NOFOLLOW_LINKS), + Files.isDirectory(response, LinkOption.NOFOLLOW_LINKS), + Files.isSymbolicLink(response)); + + // ===== 5. tmp 구조 검증 ===== + validateTmpStructure(tmp); + + // ===== 6. DB 삭제 ===== modelTrainMngCoreService.deleteModel(uuid); + log.info("DB 삭제 완료. uuid={}", uuid); + + // ===== 7. tmp 삭제 ===== + log.info("tmp 삭제 시작: {}", tmp); + try { + deleteTmpDirectory(tmp); + log.info("tmp 삭제 완료: {}", tmp); + } catch (Exception e) { + log.error("tmp 삭제 실패 (DB는 이미 삭제됨): {}", tmp, e); + } + + // ===== 8. response 삭제 ===== + log.info("response 삭제 시작: {}", response); + try { + // 테스트 완료되었으면 베스트 에폭은 삭제안함 + if (TrainStatusType.COMPLETED.getId().equals(model.getStep2Status())) { + modelTrainDetailService.executeCleanup(model, response); + } else { + deleteResponseDirectory(response); + } + + log.info("response 삭제 완료: {}", response); + } catch (Exception e) { + log.error("response 삭제 실패 (DB는 이미 삭제됨): {}", response, e); + } + + log.info("deleteModelTrain 완료. uuid={}", uuid); + } + + /** tmp 디렉토리 삭제 */ + private void deleteTmpDirectory(Path dir) throws IOException { + + if (!Files.exists(dir, LinkOption.NOFOLLOW_LINKS)) { + log.warn("삭제 대상 없음: {}", dir); + return; + } + + Files.walkFileTree( + dir, + EnumSet.noneOf(FileVisitOption.class), + Integer.MAX_VALUE, + new SimpleFileVisitor<>() { + + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) + throws IOException { + + Files.delete(file); + return FileVisitResult.CONTINUE; + } + + @Override + public FileVisitResult postVisitDirectory(Path directory, IOException exc) + throws IOException { + + if (exc != null) { + throw exc; + } + + Files.delete(directory); + return FileVisitResult.CONTINUE; + } + }); + } + + /** response 디렉토리 삭제 */ + private void deleteResponseDirectory(Path dir) throws IOException { + + if (!Files.exists(dir, LinkOption.NOFOLLOW_LINKS)) { + log.warn("삭제 대상 없음: {}", dir); + return; + } + + Files.walkFileTree( + dir, + EnumSet.noneOf(FileVisitOption.class), + Integer.MAX_VALUE, + new SimpleFileVisitor<>() { + + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) + throws IOException { + + Files.delete(file); + return FileVisitResult.CONTINUE; + } + + @Override + public FileVisitResult postVisitDirectory(Path directory, IOException exc) + throws IOException { + + if (exc != null) { + throw exc; + } + + Files.delete(directory); + return FileVisitResult.CONTINUE; + } + }); + } + + /** tmp 내부 구조 검증 - 내부는 반드시 symlink만 허용 */ + private void validateTmpStructure(Path dir) { + + if (!Files.exists(dir, LinkOption.NOFOLLOW_LINKS)) { + return; + } + + try { + Files.walkFileTree( + dir, + EnumSet.noneOf(FileVisitOption.class), + Integer.MAX_VALUE, + new SimpleFileVisitor<>() { + + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) + throws IOException { + + if (!Files.isSymbolicLink(file)) { + log.error("tmp 내부에 일반 파일 존재: {}", file); + throw new CustomApiException( + "BAD_REQUEST", HttpStatus.BAD_REQUEST, "tmp 내부는 symlink만 허용"); + } + + return FileVisitResult.CONTINUE; + } + + @Override + public FileVisitResult preVisitDirectory(Path directory, BasicFileAttributes attrs) + throws IOException { + + if (!directory.equals(dir) && Files.isSymbolicLink(directory)) { + log.error("tmp 내부에 symlink 디렉토리 존재: {}", directory); + throw new CustomApiException( + "BAD_REQUEST", HttpStatus.BAD_REQUEST, "tmp 내부에 symlink 디렉토리 금지"); + } + + return FileVisitResult.CONTINUE; + } + }); + } catch (IOException e) { + throw new CustomApiException( + "INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "tmp 구조 검증 실패"); + } } /** diff --git a/src/main/java/com/kamco/cd/training/postgres/core/DatasetCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/DatasetCoreService.java index 7e10d2b..4c82418 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/DatasetCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/DatasetCoreService.java @@ -1,6 +1,5 @@ package com.kamco.cd.training.postgres.core; -import com.fasterxml.jackson.databind.ObjectMapper; import com.kamco.cd.training.common.enums.LearnDataRegister; import com.kamco.cd.training.common.enums.LearnDataType; import com.kamco.cd.training.common.exception.NotFoundException; @@ -15,6 +14,7 @@ import com.kamco.cd.training.postgres.entity.DatasetEntity; import com.kamco.cd.training.postgres.entity.DatasetObjEntity; import com.kamco.cd.training.postgres.repository.dataset.DatasetObjRepository; import com.kamco.cd.training.postgres.repository.dataset.DatasetRepository; +import jakarta.transaction.Transactional; import java.time.ZonedDateTime; import java.util.List; import java.util.UUID; @@ -23,16 +23,15 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.data.domain.Page; import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; @Service @RequiredArgsConstructor @Slf4j public class DatasetCoreService implements BaseCoreService { + private final DatasetRepository datasetRepository; private final DatasetObjRepository datasetObjRepository; - private final ObjectMapper objectMapper; /** * 학습 데이터 삭제 @@ -221,7 +220,6 @@ public class DatasetCoreService return datasetRepository.insertDatasetMngData(mngRegDto); } - @Transactional public void insertDatasetObj(DatasetObjRegDto objRegDto) { datasetObjRepository.insertDatasetObj(objRegDto); } @@ -230,22 +228,26 @@ public class DatasetCoreService return datasetObjRepository.getFilePathByUUIDPathType(uuid, pathType); } - @Transactional public void insertDatasetTestObj(DatasetObjRegDto objRegDto) { datasetObjRepository.insertDatasetTestObj(objRegDto); } + /** + * 학습데이터셋 마스터 상태 변경 + * + * @param datasetUid 학습데이터셋 마스터 id + * @param register 상태 + */ @Transactional - public void updateDatasetUploadStatus(Long datasetUid) { + public void updateDatasetUploadStatus(Long datasetUid, LearnDataRegister register) { DatasetEntity entity = datasetRepository .findById(datasetUid) .orElseThrow(() -> new NotFoundException("데이터셋을 찾을 수 없습니다. ID: " + datasetUid)); - entity.setStatus(LearnDataRegister.COMPLETED.getId()); + entity.setStatus(register.getId()); } - @Transactional public void insertDatasetValObj(DatasetObjRegDto objRegDto) { datasetObjRepository.insertDatasetValObj(objRegDto); } @@ -253,4 +255,15 @@ public class DatasetCoreService public Long findDatasetByUidExistsCnt(String uid) { return datasetRepository.findDatasetByUidExistsCnt(uid); } + + /** + * 데이터셋 등록 실패시 Obj 데이터 정리 + * + * @param datasetUid 모델 마스터 id + */ + @Transactional + public void deleteAllDatasetObj(Long datasetUid) { + int cnt = datasetObjRepository.deleteAllDatasetObj(datasetUid); + log.info("datasetUid={} 데이터셋 실패 - 전체 삭제 완료. 총 {}건", datasetUid, cnt); + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index 33b53fd..63bee2c 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -341,6 +341,20 @@ public class ModelTrainMngCoreService { return entity.toDto(); } + /** + * 모델관리 조회 + * + * @param uuid + * @return + */ + public ModelTrainMngDto.Basic findModelByUuid(UUID uuid) { + ModelMasterEntity entity = + modelMngRepository + .findByUuid(uuid) + .orElseThrow(() -> new IllegalArgumentException("Model not found: " + uuid)); + return entity.toDto(); + } + /** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */ @Transactional public void markInProgress(Long modelId, Long jobId) { diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetObjRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetObjRepositoryCustom.java index 314bc00..8fe5cef 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetObjRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetObjRepositoryCustom.java @@ -24,4 +24,12 @@ public interface DatasetObjRepositoryCustom { void insertDatasetTestObj(DatasetObjRegDto objRegDto); void insertDatasetValObj(DatasetObjRegDto objRegDto); + + /** + * 데이터셋 등록 실패시 Obj 데이터 정리 + * + * @param datasetUid + * @return + */ + int deleteAllDatasetObj(Long datasetUid); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetObjRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetObjRepositoryImpl.java index ae21e48..151d0f5 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetObjRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetObjRepositoryImpl.java @@ -40,6 +40,7 @@ public class DatasetObjRepositoryImpl implements DatasetObjRepositoryCustom { private final JPAQueryFactory queryFactory; private final QDatasetEntity dataset = datasetEntity; + private final ObjectMapper objectMapper = new ObjectMapper(); @PersistenceContext EntityManager em; @@ -55,7 +56,6 @@ public class DatasetObjRepositoryImpl implements DatasetObjRepositoryCustom { @Override public void insertDatasetTestObj(DatasetObjRegDto objRegDto) { - ObjectMapper objectMapper = new ObjectMapper(); String json; Geometry geometry; String geometryJson; @@ -99,7 +99,6 @@ public class DatasetObjRepositoryImpl implements DatasetObjRepositoryCustom { @Override public void insertDatasetValObj(DatasetObjRegDto objRegDto) { - ObjectMapper objectMapper = new ObjectMapper(); String json; String geometryJson; try { @@ -219,7 +218,6 @@ public class DatasetObjRepositoryImpl implements DatasetObjRepositoryCustom { @Override public void insertDatasetObj(DatasetObjRegDto objRegDto) { - ObjectMapper objectMapper = new ObjectMapper(); String json; String geometryJson; try { @@ -276,4 +274,38 @@ public class DatasetObjRepositoryImpl implements DatasetObjRepositoryCustom { .where(datasetObjEntity.uuid.eq(uuid)) .fetchOne(); } + + @Override + public int deleteAllDatasetObj(Long datasetUid) { + int cnt = 0; + cnt = + em.createNativeQuery( + """ + delete from tb_dataset_obj + where dataset_uid = ? + """) + .setParameter(1, datasetUid) + .executeUpdate(); + + cnt += + em.createNativeQuery( + """ + delete from tb_dataset_val_obj + where dataset_uid = ? + """) + .setParameter(1, datasetUid) + .executeUpdate(); + + cnt += + em.createNativeQuery( + """ + delete from tb_dataset_test_obj + where dataset_uid = ? + """) + .setParameter(1, datasetUid) + .executeUpdate(); + + em.clear(); + return cnt; + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java index d1b8ea4..0a893fd 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java @@ -84,18 +84,35 @@ public class ModelTrainMetricsJobService { for (CSVRecord record : parser) { int epoch = Integer.parseInt(record.get("Epoch")); - float aAcc = Float.parseFloat(record.get("aAcc")); - float mFscore = Float.parseFloat(record.get("mFscore")); - float mPrecision = Float.parseFloat(record.get("mPrecision")); - float mRecall = Float.parseFloat(record.get("mRecall")); - float mIoU = Float.parseFloat(record.get("mIoU")); - float mAcc = Float.parseFloat(record.get("mAcc")); - float changed_fscore = Float.parseFloat(record.get("changed_fscore")); - float changed_precision = Float.parseFloat(record.get("changed_precision")); - float changed_recall = Float.parseFloat(record.get("changed_recall")); - float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore")); - float unchanged_precision = Float.parseFloat(record.get("unchanged_precision")); - float unchanged_recall = Float.parseFloat(record.get("unchanged_recall")); + + float aAcc = parseFloatSafe(record.get("aAcc")); + float mFscore = parseFloatSafe(record.get("mFscore")); + float mPrecision = parseFloatSafe(record.get("mPrecision")); + float mRecall = parseFloatSafe(record.get("mRecall")); + float mIoU = parseFloatSafe(record.get("mIoU")); + float mAcc = parseFloatSafe(record.get("mAcc")); + + float changed_fscore = parseFloatSafe(record.get("changed_fscore")); + float changed_precision = parseFloatSafe(record.get("changed_precision")); + float changed_recall = parseFloatSafe(record.get("changed_recall")); + + float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore")); + float unchanged_precision = parseFloatSafe(record.get("unchanged_precision")); + float unchanged_recall = parseFloatSafe(record.get("unchanged_recall")); + // int epoch = Integer.parseInt(record.get("Epoch")); + // float aAcc = Float.parseFloat(record.get("aAcc")); + // float mFscore = Float.parseFloat(record.get("mFscore")); + // float mPrecision = Float.parseFloat(record.get("mPrecision")); + // float mRecall = Float.parseFloat(record.get("mRecall")); + // float mIoU = Float.parseFloat(record.get("mIoU")); + // float mAcc = Float.parseFloat(record.get("mAcc")); + // float changed_fscore = Float.parseFloat(record.get("changed_fscore")); + // float changed_precision = Float.parseFloat(record.get("changed_precision")); + // float changed_recall = Float.parseFloat(record.get("changed_recall")); + // float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore")); + // float unchanged_precision = + // Float.parseFloat(record.get("unchanged_precision")); + // float unchanged_recall = Float.parseFloat(record.get("unchanged_recall")); batchArgs.add( new Object[] { @@ -153,4 +170,23 @@ public class ModelTrainMetricsJobService { modelInfo.getModelId(), "step1"); } } + + private Float parseFloatSafe(String value) { + try { + if (value == null) return null; + + value = value.trim(); + + if (value.isEmpty()) return null; + + if (value.equalsIgnoreCase("nan")) return null; + + float f = Float.parseFloat(value); + + return Float.isNaN(f) ? null : f; + + } catch (Exception e) { + return null; + } + } }