diff --git a/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java b/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java index 1c49afb..d3145db 100644 --- a/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java +++ b/src/main/java/com/kamco/cd/training/model/ModelTrainDetailApiController.java @@ -369,9 +369,8 @@ public class ModelTrainDetailApiController { } @Operation( - summary = "학습 결과 ZIP 파일 전체 다운로드", - description = - "모델 UUID에 해당하는 모든 ZIP 파일 목록을 조회하고" + "생성된 모든 학습데이터의 ZIP 파일을 하나의 ZIP 파일로 압축하여 다운로드", + summary = "학습 결과 ZIP 파일 목록 조회", + description = "모델 UUID에 해당하는 모든 ZIP 파일 목록과 개별 다운로드 링크 반환", parameters = { @Parameter( name = "kamco-download-uuid", @@ -388,11 +387,12 @@ public class ModelTrainDetailApiController { value = { @ApiResponse( responseCode = "200", - description = "학습데이터 zip파일 다운로드", + description = "ZIP 파일 목록 조회 성공", content = @Content( - mediaType = "application/octet-stream", - schema = @Schema(type = "string", format = "binary"))), + mediaType = "application/json", + schema = + @Schema(implementation = ModelTrainDetailDto.ZipFileListResponse.class))), @ApiResponse(responseCode = "404", description = "모델 또는 파일 없음", content = @Content), @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) }) @@ -403,6 +403,7 @@ public class ModelTrainDetailApiController { HttpServletRequest request) throws IOException { - return modelTrainDetailService.downloadZipFile(uuid, downloadUuid, request); + return ResponseEntity.ok( + modelTrainDetailService.getZipFileListWithFullUrl(uuid, downloadUuid, request)); } } diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java index be8945b..2675493 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainDetailService.java @@ -444,6 +444,119 @@ public class ModelTrainDetailService { }); } + /** + * 모델 UUID로 ZIP 파일 목록 조회 및 전체 URL 다운로드 링크 생성 + * + * @param uuid 모델 UUID + * @param downloadUuid 다운로드 추적 UUID + * @param request HTTP 요청 + * @return ZIP 파일 목록 및 전체 URL 다운로드 링크 + */ + public ZipFileListResponse getZipFileListWithFullUrl( + UUID uuid, String downloadUuid, HttpServletRequest request) { + log.info("ZIP 파일 목록 조회 시작 (전체 URL): modelUuid={}, downloadUuid={}", uuid, downloadUuid); + + // 1. 모델 정보 조회 + Basic modelInfo; + try { + modelInfo = findByModelByUUID(uuid); + if (modelInfo == null) { + throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델을 찾을 수 없습니다: " + uuid); + } + } catch (NullPointerException e) { + log.error("모델 조회 실패: {}", uuid, e); + throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델을 찾을 수 없습니다: " + uuid); + } + + // 2. 실제 디렉토리 경로 찾기 (uuid 또는 uuid-out) + Path baseDir = findActualBasePath(uuid); + + if (baseDir == null || !Files.exists(baseDir)) { + log.warn( + "디렉토리를 찾을 수 없음: modelUuid={}, 시도한 경로: {} 또는 {}-out", + uuid, + responseDir + "/" + uuid, + responseDir + "/" + uuid); + throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델 결과 디렉토리가 존재하지 않습니다."); + } + + log.info("디렉토리 발견: basePath={}", baseDir.toString()); + + // 요청에서 도메인 정보 추출 + String scheme = request.getScheme(); // http 또는 https + String serverName = request.getServerName(); // localhost 또는 도메인 + int serverPort = request.getServerPort(); // 8080 등 + String contextPath = request.getContextPath(); // /api 등 + + // 기본 URL 구성 + String baseUrl; + if ((scheme.equals("http") && serverPort == 80) + || (scheme.equals("https") && serverPort == 443)) { + baseUrl = scheme + "://" + serverName + contextPath; + } else { + baseUrl = scheme + "://" + serverName + ":" + serverPort + contextPath; + } + + // 3. ZIP 파일 목록 검색 + List zipFiles = new ArrayList<>(); + long totalSize = 0L; + + try (Stream stream = Files.list(baseDir)) { + List files = + stream + .filter(Files::isRegularFile) + .filter(p -> p.getFileName().toString().endsWith(".zip")) + .filter(p -> p.getFileName().toString().contains(uuid.toString())) + .sorted( + Comparator.comparing(p -> p.getFileName().toString(), Comparator.reverseOrder())) + .toList(); + + for (Path file : files) { + String fileName = file.getFileName().toString(); + long fileSize = Files.size(file); + totalSize += fileSize; + + // 파일명에서 버전 추출 + String version = extractVersionFromZipFileName(fileName); + boolean isCurrent = version.equals(modelInfo.getModelVer()); + + // 전체 도메인 URL 포함한 다운로드 링크 생성 + String downloadUrl = baseUrl + "/api/models/download/" + uuid + "?file=" + fileName; + + zipFiles.add( + ZipFileInfo.builder() + .fileName(fileName) + .filePath(file.toString()) + .version(version) + .fileSize(fileSize) + .fileSizeFormatted(formatFileSize(fileSize)) + .lastModified( + Files.getLastModifiedTime(file).toInstant().atZone(ZoneId.systemDefault())) + .isCurrent(isCurrent) + .downloadUrl(downloadUrl) + .build()); + } + + log.info("ZIP 파일 {}개 발견", zipFiles.size()); + + } catch (IOException e) { + log.error("ZIP 파일 목록 조회 실패: {}", baseDir, e); + throw new CustomApiException( + "INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "ZIP 파일 목록 조회 실패"); + } + + return ZipFileListResponse.builder() + .modelUuid(uuid.toString()) + .modelNo(modelInfo.getModelNo()) + .modelVer(modelInfo.getModelVer()) + .basePath(baseDir.toString()) + .zipFiles(zipFiles) + .totalFiles(zipFiles.size()) + .totalSize(totalSize) + .totalSizeFormatted(formatFileSize(totalSize)) + .build(); + } + /** * 모델 UUID로 ZIP 파일 목록 조회 및 다운로드 링크 생성 * @@ -608,16 +721,24 @@ public class ModelTrainDetailService { } /** - * 모델 UUID로 모든 ZIP 파일 다운로드 (여러 파일이 있으면 하나의 zip으로 묶어서) + * 모델 UUID로 모든 ZIP 파일 다운로드 또는 목록 조회 * * @param uuid 모델 UUID * @param downloadUuid 다운로드 추적 UUID + * @param fileName 개별 파일명 (선택) + * @param accept Accept 헤더 (application/json 또는 application/octet-stream) * @param request HTTP 요청 - * @return ZIP 파일 다운로드 응답 + * @return JSON 응답 또는 ZIP 파일 Binary 응답 */ public ResponseEntity downloadZipFile( - UUID uuid, String downloadUuid, HttpServletRequest request) throws IOException { - log.info("ZIP 파일 다운로드 시작: modelUuid={}, downloadUuid={}", uuid, downloadUuid); + UUID uuid, String downloadUuid, String fileName, String accept, HttpServletRequest request) + throws IOException { + log.info( + "ZIP 파일 다운로드/조회 시작: modelUuid={}, downloadUuid={}, file={}, accept={}", + uuid, + downloadUuid, + fileName, + accept); // 1. 모델 정보 조회 Basic modelInfo; @@ -659,7 +780,134 @@ public class ModelTrainDetailService { baseDir, zipFiles.stream().map(p -> p.getFileName().toString()).toList()); - // 4. 파일이 1개면 바로 다운로드 + // 4. Accept 헤더에 따라 분기 + if (accept != null && accept.contains("application/json")) { + // JSON 응답: ZIP 파일 목록과 다운로드 링크 반환 + log.info("JSON 응답 모드: ZIP 파일 목록 반환"); + return ResponseEntity.ok(buildZipFileListResponse(modelInfo, baseDir, zipFiles, request)); + } + + // 5. Binary 응답: 파일 다운로드 + return handleBinaryDownload(modelInfo, baseDir, zipFiles, fileName, uuid, request); + } + + /** + * ZIP 파일 목록 응답 생성 (JSON) + * + * @param modelInfo 모델 정보 + * @param baseDir 기본 디렉토리 + * @param zipFiles ZIP 파일 목록 + * @param request HTTP 요청 + * @return ZipFileListResponse + */ + private ZipFileListResponse buildZipFileListResponse( + Basic modelInfo, Path baseDir, List zipFiles, HttpServletRequest request) + throws IOException { + + // 요청에서 도메인 정보 추출 + String scheme = request.getScheme(); // http 또는 https + String serverName = request.getServerName(); // localhost 또는 도메인 + int serverPort = request.getServerPort(); // 8080 등 + String contextPath = request.getContextPath(); // /api 등 + + // 기본 URL 구성 + String baseUrl; + if ((scheme.equals("http") && serverPort == 80) + || (scheme.equals("https") && serverPort == 443)) { + baseUrl = scheme + "://" + serverName + contextPath; + } else { + baseUrl = scheme + "://" + serverName + ":" + serverPort + contextPath; + } + + List zipFileInfos = new ArrayList<>(); + long totalSize = 0L; + + for (Path zipFile : zipFiles) { + String zipFileName = zipFile.getFileName().toString(); + long fileSize = Files.size(zipFile); + totalSize += fileSize; + + // 파일명에서 버전 추출 + String version = extractVersionFromZipFileName(zipFileName); + boolean isCurrent = version.equals(modelInfo.getModelVer()); + + // 전체 도메인 URL 포함한 다운로드 링크 생성 + String downloadUrl = + baseUrl + "/api/models/downloadzip/" + modelInfo.getUuid() + "?file=" + zipFileName; + + zipFileInfos.add( + ZipFileInfo.builder() + .fileName(zipFileName) + .filePath(zipFile.toString()) + .version(version) + .fileSize(fileSize) + .fileSizeFormatted(formatFileSize(fileSize)) + .lastModified( + Files.getLastModifiedTime(zipFile).toInstant().atZone(ZoneId.systemDefault())) + .isCurrent(isCurrent) + .downloadUrl(downloadUrl) + .build()); + } + + return ZipFileListResponse.builder() + .modelUuid(modelInfo.getUuid().toString()) + .modelNo(modelInfo.getModelNo()) + .modelVer(modelInfo.getModelVer()) + .basePath(baseDir.toString()) + .zipFiles(zipFileInfos) + .totalFiles(zipFileInfos.size()) + .totalSize(totalSize) + .totalSizeFormatted(formatFileSize(totalSize)) + .build(); + } + + /** + * Binary 다운로드 처리 + * + * @param modelInfo 모델 정보 + * @param baseDir 기본 디렉토리 + * @param zipFiles ZIP 파일 목록 + * @param fileName 개별 파일명 (선택) + * @param uuid 모델 UUID + * @param request HTTP 요청 + * @return Binary 응답 + */ + private ResponseEntity handleBinaryDownload( + Basic modelInfo, + Path baseDir, + List zipFiles, + String fileName, + UUID uuid, + HttpServletRequest request) + throws IOException { + + // file 파라미터가 있으면 개별 파일 다운로드 + if (fileName != null && !fileName.isEmpty()) { + log.info("개별 파일 다운로드 요청: fileName={}", fileName); + + Path targetFile = + zipFiles.stream() + .filter(p -> p.getFileName().toString().equals(fileName)) + .findFirst() + .orElse(null); + + if (targetFile == null) { + log.warn("요청한 파일을 찾을 수 없음: fileName={}, basePath={}", fileName, baseDir); + throw new CustomApiException( + "NOT_FOUND", HttpStatus.NOT_FOUND, "요청한 파일을 찾을 수 없습니다: " + fileName); + } + + log.info( + "개별 ZIP 파일 다운로드: fileName={}, fileSize={} bytes, basePath={}", + targetFile.getFileName(), + Files.size(targetFile), + baseDir); + + return rangeDownloadResponder.buildZipResponse(targetFile, fileName, request); + } + + // file 파라미터가 없으면 모든 파일 다운로드 + // 파일이 1개면 바로 다운로드 if (zipFiles.size() == 1) { Path zipPath = zipFiles.get(0); log.info( @@ -671,7 +919,7 @@ public class ModelTrainDetailService { return rangeDownloadResponder.buildZipResponse(zipPath, downloadFileName, request); } - // 5. 파일이 여러 개면 하나의 zip으로 묶어서 다운로드 + // 파일이 여러 개면 하나의 zip으로 묶어서 다운로드 log.info("여러 ZIP 파일을 하나로 묶어서 다운로드: 총 {}개 파일", zipFiles.size()); String combinedZipName = modelInfo.getModelNo() + "." + modelInfo.getModelVer() + ".all.zip"; diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java index 3be3328..24b0cc2 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java @@ -51,4 +51,16 @@ public class ModelTestMetricsJobCoreService { public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) { modelTestMetricsJobRepository.updatePackingEnd(modelId, now, failSuccState); } + + /** + * 특정 Epoch의 메트릭 정보 조회 (하이퍼파라미터별 ZIP 생성용) + * + * @param modelId 모델 ID + * @param epoch Epoch 번호 + * @param metricType 메트릭 타입 (fscore, precision, recall) + * @return 메트릭 JSON DTO + */ + public ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType) { + return modelTestMetricsJobRepository.getMetricsByEpoch(modelId, epoch, metricType); + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java index fd7a25f..72ff12f 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java @@ -23,4 +23,14 @@ public interface ModelTestMetricsJobRepositoryCustom { void updatePackingStart(Long modelId, ZonedDateTime now); void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState); + + /** + * 특정 Epoch의 메트릭 정보 조회 (하이퍼파라미터별 ZIP 생성용) + * + * @param modelId 모델 ID + * @param epoch Epoch 번호 + * @param metricType 메트릭 타입 (fscore, precision, recall) + * @return 메트릭 JSON DTO + */ + ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java index 687dbfb..4a30413 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java @@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.train; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity; import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity; +import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity; import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity; @@ -187,4 +188,49 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport .where(modelMasterEntity.id.eq(modelId)) .execute(); } + + @Override + public ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType) { + return queryFactory + .select( + Projections.constructor( + ModelMetricJsonDto.class, + modelMasterEntity.modelNo, + modelMasterEntity.modelVer, + com.querydsl.core.types.dsl.Expressions.constant(metricType), + com.querydsl.core.types.dsl.Expressions.constant(epoch), + Projections.constructor( + Properties.class, + // Changed 메트릭 + modelMetricsValidationEntity.changedFscore, + modelMetricsValidationEntity.changedPrecision, + modelMetricsValidationEntity.changedRecall, + // Unchanged 메트릭 + modelMetricsValidationEntity.unchangedFscore, + modelMetricsValidationEntity.unchangedPrecision, + modelMetricsValidationEntity.unchangedRecall, + // Mean 메트릭 + modelMetricsValidationEntity.mFscore, + modelMetricsValidationEntity.mPrecision, + modelMetricsValidationEntity.mRecall, + modelMetricsValidationEntity.mIou, + modelMetricsValidationEntity.mAcc, + // Overall Accuracy + modelMetricsValidationEntity.aAcc, + // Train 메트릭 (Loss, LR, Duration) + modelMetricsTrainEntity.loss, + modelMetricsTrainEntity.lr, + modelMetricsTrainEntity.durationTime))) + .from(modelMetricsValidationEntity) + .innerJoin(modelMasterEntity) + .on(modelMetricsValidationEntity.model.id.eq(modelMasterEntity.id)) + .leftJoin(modelMetricsTrainEntity) + .on( + modelMetricsTrainEntity.model.id.eq(modelMasterEntity.id), + modelMetricsTrainEntity.epoch.eq(epoch)) + .where( + modelMetricsValidationEntity.model.id.eq(modelId), + modelMetricsValidationEntity.epoch.eq(epoch)) + .fetchOne(); + } } diff --git a/src/main/java/com/kamco/cd/training/train/dto/ModelTrainMetricsDto.java b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainMetricsDto.java index 3e54e4e..05efdef 100644 --- a/src/main/java/com/kamco/cd/training/train/dto/ModelTrainMetricsDto.java +++ b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainMetricsDto.java @@ -2,6 +2,7 @@ package com.kamco.cd.training.train.dto; import com.fasterxml.jackson.annotation.JsonProperty; import io.swagger.v3.oas.annotations.media.Schema; +import java.nio.file.Path; import java.util.Properties; import java.util.UUID; import lombok.AllArgsConstructor; @@ -23,6 +24,17 @@ public class ModelTrainMetricsDto { private UUID uuid; } + @Schema(name = "BestPthInfo", description = "Best PTH 파일 정보") + @Getter + @AllArgsConstructor + public static class BestPthInfo { + + private String fileName; // best_changed_fscore_epoch_3.pth + private String metricType; // fscore, precision, recall + private Integer epoch; // 3 + private Path filePath; // 파일 전체 경로 + } + @Getter @AllArgsConstructor public static class ModelMetricJsonDto { @@ -33,6 +45,11 @@ public class ModelTrainMetricsDto { @JsonProperty("model_version") private String modelVersion; + @JsonProperty("metric_type") + private String metricType; // fscore, precision, recall + + private Integer epoch; // epoch 번호 + private Properties properties; } @@ -40,13 +57,55 @@ public class ModelTrainMetricsDto { @AllArgsConstructor public static class Properties { - @JsonProperty("f1_score") - private Float f1Score; + // 변화 탐지 관련 메트릭 (Changed) + @JsonProperty("changed_fscore") + private Float changedFscore; - private Float precision; - private Float recall; - private Float loss; - private Double iou; + @JsonProperty("changed_precision") + private Float changedPrecision; + + @JsonProperty("changed_recall") + private Float changedRecall; + + // 비변화 관련 메트릭 (Unchanged) + @JsonProperty("unchanged_fscore") + private Float unchangedFscore; + + @JsonProperty("unchanged_precision") + private Float unchangedPrecision; + + @JsonProperty("unchanged_recall") + private Float unchangedRecall; + + // 평균 메트릭 (Mean) + @JsonProperty("mean_fscore") + private Float mFscore; + + @JsonProperty("mean_precision") + private Float mPrecision; + + @JsonProperty("mean_recall") + private Float mRecall; + + @JsonProperty("mean_iou") + private Float mIou; + + @JsonProperty("mean_accuracy") + private Float mAcc; + + // 전체 정확도 (Overall Accuracy) + @JsonProperty("overall_accuracy") + private Float aAcc; + + // 학습 관련 메트릭 + @JsonProperty("loss") + private Double loss; + + @JsonProperty("learning_rate") + private Double lr; + + @JsonProperty("duration_time") + private Float durationTime; } @Getter diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java index a0c1447..af04036 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java @@ -4,8 +4,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializationFeature; import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.BestPthInfo; import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto; -import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName; import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import java.io.BufferedReader; import java.io.IOException; @@ -18,7 +18,8 @@ import java.nio.file.StandardOpenOption; import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.List; -import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.zip.ZipEntry; @@ -76,9 +77,9 @@ public class ModelTestMetricsJobService { } /** - * 베스트 에폭 zip파일 생성, 테스트결과 db등록 + * 베스트 에폭별 개별 ZIP 파일 생성, 테스트결과 db등록 * - * @param modelInfo + * @param modelInfo 모델 정보 */ private void createFile(ResponsePathDto modelInfo) { @@ -130,59 +131,42 @@ public class ModelTestMetricsJobService { throw new RuntimeException(e); } - // 패키징할 파일 만들기 + // 패키징 시작 modelTestMetricsJobCoreService.updatePackingStart(modelInfo.getModelId(), ZonedDateTime.now()); - ModelMetricJsonDto jsonDto = - modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId()); - try { - writeJsonFile( - jsonDto, - Paths.get( - responseDir + "/" + modelInfo.getUuid() + "/" + jsonDto.getModelVersion() + ".json")); - } catch (IOException e) { - throw new RuntimeException(e); - } - - Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid()); - - ModelTestFileName fileInfo = - modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId()); - - Path zipPath = - Paths.get( - responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip"); - Set targetNames = - Set.of( - "model_config.py", - fileInfo.getBestEpochFileName() + ".pth", - fileInfo.getModelVersion() + ".json"); - - List files = new ArrayList<>(); - try (Stream s = Files.list(responsePath)) { - files.addAll( - s.filter(Files::isRegularFile) - .filter(p -> targetNames.contains(p.getFileName().toString())) - .collect(Collectors.toList())); - } catch (IOException e) { - throw new RuntimeException(e); - } - - try (Stream s = Files.list(Path.of(ptPathDir))) { - files.addAll( - s.filter(Files::isRegularFile) - .limit(1) // yolov8_6th-6m.pt 파일 1개만 - .collect(Collectors.toList())); - } catch (IOException e) { - throw new RuntimeException(e); - } + Path responsePath = Paths.get(responseDir, modelInfo.getUuid().toString()); try { - zipFiles(files, zipPath); + // 1. 모든 Best PTH 파일 찾기 + List bestPthFiles = findAllBestPthFiles(responsePath); + + if (bestPthFiles.isEmpty()) { + log.warn( + "Best PTH 파일을 찾을 수 없습니다: modelId={}, uuid={}", + modelInfo.getModelId(), + modelInfo.getUuid()); + modelTestMetricsJobCoreService.updatePackingEnd( + modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId()); + return; + } + + log.info( + "Best PTH 파일 {}개 발견: modelId={}, files={}", + bestPthFiles.size(), + modelInfo.getModelId(), + bestPthFiles.stream().map(BestPthInfo::getFileName).collect(Collectors.toList())); + + // 2. 각 Best PTH별로 개별 ZIP 생성 + createIndividualZipFiles(modelInfo, bestPthFiles, responsePath); modelTestMetricsJobCoreService.updatePackingEnd( modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId()); - } catch (IOException e) { + + log.info( + "모든 ZIP 파일 생성 완료: modelId={}, zipCount={}", modelInfo.getModelId(), bestPthFiles.size()); + + } catch (Exception e) { + log.error("ZIP 파일 생성 실패: modelId={}", modelInfo.getModelId(), e); modelTestMetricsJobCoreService.updatePackingEnd( modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId()); throw new RuntimeException(e); @@ -227,4 +211,130 @@ public class ModelTestMetricsJobService { } } } + + /** + * Response 폴더에서 모든 best_changed_*.pth 파일 찾기 + * + * @param responsePath Response 디렉토리 경로 + * @return Best PTH 파일 정보 리스트 + */ + private List findAllBestPthFiles(Path responsePath) throws IOException { + List bestFiles = new ArrayList<>(); + Pattern pattern = Pattern.compile("best_changed_(fscore|precision|recall)_epoch_(\\d+)\\.pth"); + + try (Stream stream = Files.list(responsePath)) { + stream + .filter(Files::isRegularFile) + .forEach( + file -> { + String fileName = file.getFileName().toString(); + Matcher matcher = pattern.matcher(fileName); + if (matcher.matches()) { + String metricType = matcher.group(1); // fscore, precision, recall + Integer epoch = Integer.parseInt(matcher.group(2)); + + bestFiles.add(new BestPthInfo(fileName, metricType, epoch, file)); + } + }); + } + + return bestFiles; + } + + /** + * 각 Best PTH 파일별로 개별 ZIP 파일 생성 + * + * @param modelInfo 모델 정보 + * @param bestPthFiles Best PTH 파일 리스트 + * @param responsePath Response 디렉토리 경로 + */ + private void createIndividualZipFiles( + ResponsePathDto modelInfo, List bestPthFiles, Path responsePath) { + + // 공통 파일 경로 + Path modelConfigPath = responsePath.resolve("model_config.py"); + Path ptPath = findPretrainedModel(); + + // 공통 파일 존재 확인 + if (!Files.exists(modelConfigPath)) { + log.error("model_config.py 파일이 존재하지 않습니다: {}", modelConfigPath); + throw new RuntimeException("필수 파일 누락: model_config.py"); + } + + if (ptPath == null) { + log.error("사전학습 모델(.pt) 파일을 찾을 수 없습니다."); + throw new RuntimeException("필수 파일 누락: .pt"); + } + + for (BestPthInfo bestPth : bestPthFiles) { + try { + // 1. 메트릭 JSON 생성 + ModelMetricJsonDto metricJson = + modelTestMetricsJobCoreService.getMetricsByEpoch( + modelInfo.getModelId(), bestPth.getEpoch(), bestPth.getMetricType()); + + if (metricJson == null) { + log.warn( + "메트릭 정보 없음: modelId={}, epoch={}, metricType={}", + modelInfo.getModelId(), + bestPth.getEpoch(), + bestPth.getMetricType()); + continue; + } + + String jsonFileName = bestPth.getMetricType() + "_metrics.json"; + Path jsonPath = responsePath.resolve(jsonFileName); + writeJsonFile(metricJson, jsonPath); + + // 2. ZIP에 포함할 파일 리스트 + List filesToZip = new ArrayList<>(); + filesToZip.add(modelConfigPath); // model_config.py + filesToZip.add(bestPth.getFilePath()); // best_changed_{type}_epoch_{n}.pth + filesToZip.add(jsonPath); // {type}_metrics.json + filesToZip.add(ptPath); // yolov8_6th-6m.pt + + // 3. ZIP 파일명 생성 + String zipFileName = + String.format( + "%s.%s.%s.%s.zip", + metricJson.getCdModelType(), // G1 + metricJson.getModelVersion(), // G1_000001 + modelInfo.getUuid(), // uuid + bestPth.getMetricType()); // fscore/precision/recall + + Path zipPath = responsePath.resolve(zipFileName); + + // 4. ZIP 압축 + zipFiles(filesToZip, zipPath); + + log.info("ZIP 파일 생성 완료: {}", zipPath); + + } catch (Exception e) { + log.error( + "ZIP 파일 생성 실패: metricType={}, epoch={}", + bestPth.getMetricType(), + bestPth.getEpoch(), + e); + throw new RuntimeException(e); + } + } + } + + /** + * 사전학습 모델(PT 파일) 찾기 + * + * @return PT 파일 경로 + */ + private Path findPretrainedModel() { + try (Stream stream = Files.list(Path.of(ptPathDir))) { + return stream + .filter(Files::isRegularFile) + .filter(p -> p.getFileName().toString().endsWith(".pt")) + .findFirst() + .orElse(null); + } catch (IOException e) { + log.error("사전학습 모델 찾기 실패", e); + return null; + } + } }