daniel 작업 - best epoch의 pth 만큼 zip파일 생성

This commit is contained in:
2026-04-09 13:15:49 +09:00
parent 02954aa439
commit 8b7ff0162d
7 changed files with 555 additions and 69 deletions

View File

@@ -369,9 +369,8 @@ public class ModelTrainDetailApiController {
} }
@Operation( @Operation(
summary = "학습 결과 ZIP 파일 전체 다운로드", summary = "학습 결과 ZIP 파일 목록 조회",
description = description = "모델 UUID에 해당하는 모든 ZIP 파일 목록과 개별 다운로드 링크 반환",
"모델 UUID에 해당하는 모든 ZIP 파일 목록을 조회하고" + "생성된 모든 학습데이터의 ZIP 파일을 하나의 ZIP 파일로 압축하여 다운로드",
parameters = { parameters = {
@Parameter( @Parameter(
name = "kamco-download-uuid", name = "kamco-download-uuid",
@@ -388,11 +387,12 @@ public class ModelTrainDetailApiController {
value = { value = {
@ApiResponse( @ApiResponse(
responseCode = "200", responseCode = "200",
description = "학습데이터 zip파일 다운로드", description = "ZIP 파일 목록 조회 성공",
content = content =
@Content( @Content(
mediaType = "application/octet-stream", mediaType = "application/json",
schema = @Schema(type = "string", format = "binary"))), schema =
@Schema(implementation = ModelTrainDetailDto.ZipFileListResponse.class))),
@ApiResponse(responseCode = "404", description = "모델 또는 파일 없음", content = @Content), @ApiResponse(responseCode = "404", description = "모델 또는 파일 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@@ -403,6 +403,7 @@ public class ModelTrainDetailApiController {
HttpServletRequest request) HttpServletRequest request)
throws IOException { throws IOException {
return modelTrainDetailService.downloadZipFile(uuid, downloadUuid, request); return ResponseEntity.ok(
modelTrainDetailService.getZipFileListWithFullUrl(uuid, downloadUuid, request));
} }
} }

View File

@@ -444,6 +444,119 @@ public class ModelTrainDetailService {
}); });
} }
/**
* 모델 UUID로 ZIP 파일 목록 조회 및 전체 URL 다운로드 링크 생성
*
* @param uuid 모델 UUID
* @param downloadUuid 다운로드 추적 UUID
* @param request HTTP 요청
* @return ZIP 파일 목록 및 전체 URL 다운로드 링크
*/
public ZipFileListResponse getZipFileListWithFullUrl(
UUID uuid, String downloadUuid, HttpServletRequest request) {
log.info("ZIP 파일 목록 조회 시작 (전체 URL): modelUuid={}, downloadUuid={}", uuid, downloadUuid);
// 1. 모델 정보 조회
Basic modelInfo;
try {
modelInfo = findByModelByUUID(uuid);
if (modelInfo == null) {
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델을 찾을 수 없습니다: " + uuid);
}
} catch (NullPointerException e) {
log.error("모델 조회 실패: {}", uuid, e);
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델을 찾을 수 없습니다: " + uuid);
}
// 2. 실제 디렉토리 경로 찾기 (uuid 또는 uuid-out)
Path baseDir = findActualBasePath(uuid);
if (baseDir == null || !Files.exists(baseDir)) {
log.warn(
"디렉토리를 찾을 수 없음: modelUuid={}, 시도한 경로: {} 또는 {}-out",
uuid,
responseDir + "/" + uuid,
responseDir + "/" + uuid);
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델 결과 디렉토리가 존재하지 않습니다.");
}
log.info("디렉토리 발견: basePath={}", baseDir.toString());
// 요청에서 도메인 정보 추출
String scheme = request.getScheme(); // http 또는 https
String serverName = request.getServerName(); // localhost 또는 도메인
int serverPort = request.getServerPort(); // 8080 등
String contextPath = request.getContextPath(); // /api 등
// 기본 URL 구성
String baseUrl;
if ((scheme.equals("http") && serverPort == 80)
|| (scheme.equals("https") && serverPort == 443)) {
baseUrl = scheme + "://" + serverName + contextPath;
} else {
baseUrl = scheme + "://" + serverName + ":" + serverPort + contextPath;
}
// 3. ZIP 파일 목록 검색
List<ZipFileInfo> zipFiles = new ArrayList<>();
long totalSize = 0L;
try (Stream<Path> stream = Files.list(baseDir)) {
List<Path> files =
stream
.filter(Files::isRegularFile)
.filter(p -> p.getFileName().toString().endsWith(".zip"))
.filter(p -> p.getFileName().toString().contains(uuid.toString()))
.sorted(
Comparator.comparing(p -> p.getFileName().toString(), Comparator.reverseOrder()))
.toList();
for (Path file : files) {
String fileName = file.getFileName().toString();
long fileSize = Files.size(file);
totalSize += fileSize;
// 파일명에서 버전 추출
String version = extractVersionFromZipFileName(fileName);
boolean isCurrent = version.equals(modelInfo.getModelVer());
// 전체 도메인 URL 포함한 다운로드 링크 생성
String downloadUrl = baseUrl + "/api/models/download/" + uuid + "?file=" + fileName;
zipFiles.add(
ZipFileInfo.builder()
.fileName(fileName)
.filePath(file.toString())
.version(version)
.fileSize(fileSize)
.fileSizeFormatted(formatFileSize(fileSize))
.lastModified(
Files.getLastModifiedTime(file).toInstant().atZone(ZoneId.systemDefault()))
.isCurrent(isCurrent)
.downloadUrl(downloadUrl)
.build());
}
log.info("ZIP 파일 {}개 발견", zipFiles.size());
} catch (IOException e) {
log.error("ZIP 파일 목록 조회 실패: {}", baseDir, e);
throw new CustomApiException(
"INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "ZIP 파일 목록 조회 실패");
}
return ZipFileListResponse.builder()
.modelUuid(uuid.toString())
.modelNo(modelInfo.getModelNo())
.modelVer(modelInfo.getModelVer())
.basePath(baseDir.toString())
.zipFiles(zipFiles)
.totalFiles(zipFiles.size())
.totalSize(totalSize)
.totalSizeFormatted(formatFileSize(totalSize))
.build();
}
/** /**
* 모델 UUID로 ZIP 파일 목록 조회 및 다운로드 링크 생성 * 모델 UUID로 ZIP 파일 목록 조회 및 다운로드 링크 생성
* *
@@ -608,16 +721,24 @@ public class ModelTrainDetailService {
} }
/** /**
* 모델 UUID로 모든 ZIP 파일 다운로드 (여러 파일이 있으면 하나의 zip으로 묶어서) * 모델 UUID로 모든 ZIP 파일 다운로드 또는 목록 조회
* *
* @param uuid 모델 UUID * @param uuid 모델 UUID
* @param downloadUuid 다운로드 추적 UUID * @param downloadUuid 다운로드 추적 UUID
* @param fileName 개별 파일명 (선택)
* @param accept Accept 헤더 (application/json 또는 application/octet-stream)
* @param request HTTP 요청 * @param request HTTP 요청
* @return ZIP 파일 다운로드 응답 * @return JSON 응답 또는 ZIP 파일 Binary 응답
*/ */
public ResponseEntity<?> downloadZipFile( public ResponseEntity<?> downloadZipFile(
UUID uuid, String downloadUuid, HttpServletRequest request) throws IOException { UUID uuid, String downloadUuid, String fileName, String accept, HttpServletRequest request)
log.info("ZIP 파일 다운로드 시작: modelUuid={}, downloadUuid={}", uuid, downloadUuid); throws IOException {
log.info(
"ZIP 파일 다운로드/조회 시작: modelUuid={}, downloadUuid={}, file={}, accept={}",
uuid,
downloadUuid,
fileName,
accept);
// 1. 모델 정보 조회 // 1. 모델 정보 조회
Basic modelInfo; Basic modelInfo;
@@ -659,7 +780,134 @@ public class ModelTrainDetailService {
baseDir, baseDir,
zipFiles.stream().map(p -> p.getFileName().toString()).toList()); zipFiles.stream().map(p -> p.getFileName().toString()).toList());
// 4. 파일이 1개면 바로 다운로드 // 4. Accept 헤더에 따라 분기
if (accept != null && accept.contains("application/json")) {
// JSON 응답: ZIP 파일 목록과 다운로드 링크 반환
log.info("JSON 응답 모드: ZIP 파일 목록 반환");
return ResponseEntity.ok(buildZipFileListResponse(modelInfo, baseDir, zipFiles, request));
}
// 5. Binary 응답: 파일 다운로드
return handleBinaryDownload(modelInfo, baseDir, zipFiles, fileName, uuid, request);
}
/**
* ZIP 파일 목록 응답 생성 (JSON)
*
* @param modelInfo 모델 정보
* @param baseDir 기본 디렉토리
* @param zipFiles ZIP 파일 목록
* @param request HTTP 요청
* @return ZipFileListResponse
*/
private ZipFileListResponse buildZipFileListResponse(
Basic modelInfo, Path baseDir, List<Path> zipFiles, HttpServletRequest request)
throws IOException {
// 요청에서 도메인 정보 추출
String scheme = request.getScheme(); // http 또는 https
String serverName = request.getServerName(); // localhost 또는 도메인
int serverPort = request.getServerPort(); // 8080 등
String contextPath = request.getContextPath(); // /api 등
// 기본 URL 구성
String baseUrl;
if ((scheme.equals("http") && serverPort == 80)
|| (scheme.equals("https") && serverPort == 443)) {
baseUrl = scheme + "://" + serverName + contextPath;
} else {
baseUrl = scheme + "://" + serverName + ":" + serverPort + contextPath;
}
List<ZipFileInfo> zipFileInfos = new ArrayList<>();
long totalSize = 0L;
for (Path zipFile : zipFiles) {
String zipFileName = zipFile.getFileName().toString();
long fileSize = Files.size(zipFile);
totalSize += fileSize;
// 파일명에서 버전 추출
String version = extractVersionFromZipFileName(zipFileName);
boolean isCurrent = version.equals(modelInfo.getModelVer());
// 전체 도메인 URL 포함한 다운로드 링크 생성
String downloadUrl =
baseUrl + "/api/models/downloadzip/" + modelInfo.getUuid() + "?file=" + zipFileName;
zipFileInfos.add(
ZipFileInfo.builder()
.fileName(zipFileName)
.filePath(zipFile.toString())
.version(version)
.fileSize(fileSize)
.fileSizeFormatted(formatFileSize(fileSize))
.lastModified(
Files.getLastModifiedTime(zipFile).toInstant().atZone(ZoneId.systemDefault()))
.isCurrent(isCurrent)
.downloadUrl(downloadUrl)
.build());
}
return ZipFileListResponse.builder()
.modelUuid(modelInfo.getUuid().toString())
.modelNo(modelInfo.getModelNo())
.modelVer(modelInfo.getModelVer())
.basePath(baseDir.toString())
.zipFiles(zipFileInfos)
.totalFiles(zipFileInfos.size())
.totalSize(totalSize)
.totalSizeFormatted(formatFileSize(totalSize))
.build();
}
/**
* Binary 다운로드 처리
*
* @param modelInfo 모델 정보
* @param baseDir 기본 디렉토리
* @param zipFiles ZIP 파일 목록
* @param fileName 개별 파일명 (선택)
* @param uuid 모델 UUID
* @param request HTTP 요청
* @return Binary 응답
*/
private ResponseEntity<?> handleBinaryDownload(
Basic modelInfo,
Path baseDir,
List<Path> zipFiles,
String fileName,
UUID uuid,
HttpServletRequest request)
throws IOException {
// file 파라미터가 있으면 개별 파일 다운로드
if (fileName != null && !fileName.isEmpty()) {
log.info("개별 파일 다운로드 요청: fileName={}", fileName);
Path targetFile =
zipFiles.stream()
.filter(p -> p.getFileName().toString().equals(fileName))
.findFirst()
.orElse(null);
if (targetFile == null) {
log.warn("요청한 파일을 찾을 수 없음: fileName={}, basePath={}", fileName, baseDir);
throw new CustomApiException(
"NOT_FOUND", HttpStatus.NOT_FOUND, "요청한 파일을 찾을 수 없습니다: " + fileName);
}
log.info(
"개별 ZIP 파일 다운로드: fileName={}, fileSize={} bytes, basePath={}",
targetFile.getFileName(),
Files.size(targetFile),
baseDir);
return rangeDownloadResponder.buildZipResponse(targetFile, fileName, request);
}
// file 파라미터가 없으면 모든 파일 다운로드
// 파일이 1개면 바로 다운로드
if (zipFiles.size() == 1) { if (zipFiles.size() == 1) {
Path zipPath = zipFiles.get(0); Path zipPath = zipFiles.get(0);
log.info( log.info(
@@ -671,7 +919,7 @@ public class ModelTrainDetailService {
return rangeDownloadResponder.buildZipResponse(zipPath, downloadFileName, request); return rangeDownloadResponder.buildZipResponse(zipPath, downloadFileName, request);
} }
// 5. 파일이 여러 개면 하나의 zip으로 묶어서 다운로드 // 파일이 여러 개면 하나의 zip으로 묶어서 다운로드
log.info("여러 ZIP 파일을 하나로 묶어서 다운로드: 총 {}개 파일", zipFiles.size()); log.info("여러 ZIP 파일을 하나로 묶어서 다운로드: 총 {}개 파일", zipFiles.size());
String combinedZipName = modelInfo.getModelNo() + "." + modelInfo.getModelVer() + ".all.zip"; String combinedZipName = modelInfo.getModelNo() + "." + modelInfo.getModelVer() + ".all.zip";

View File

@@ -51,4 +51,16 @@ public class ModelTestMetricsJobCoreService {
public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) { public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) {
modelTestMetricsJobRepository.updatePackingEnd(modelId, now, failSuccState); modelTestMetricsJobRepository.updatePackingEnd(modelId, now, failSuccState);
} }
/**
* 특정 Epoch의 메트릭 정보 조회 (하이퍼파라미터별 ZIP 생성용)
*
* @param modelId 모델 ID
* @param epoch Epoch 번호
* @param metricType 메트릭 타입 (fscore, precision, recall)
* @return 메트릭 JSON DTO
*/
public ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType) {
return modelTestMetricsJobRepository.getMetricsByEpoch(modelId, epoch, metricType);
}
} }

View File

@@ -23,4 +23,14 @@ public interface ModelTestMetricsJobRepositoryCustom {
void updatePackingStart(Long modelId, ZonedDateTime now); void updatePackingStart(Long modelId, ZonedDateTime now);
void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState); void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState);
/**
* 특정 Epoch의 메트릭 정보 조회 (하이퍼파라미터별 ZIP 생성용)
*
* @param modelId 모델 ID
* @param epoch Epoch 번호
* @param metricType 메트릭 타입 (fscore, precision, recall)
* @return 메트릭 JSON DTO
*/
ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType);
} }

View File

@@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity; import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity; import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity;
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity; import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
@@ -187,4 +188,49 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
.where(modelMasterEntity.id.eq(modelId)) .where(modelMasterEntity.id.eq(modelId))
.execute(); .execute();
} }
@Override
public ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType) {
return queryFactory
.select(
Projections.constructor(
ModelMetricJsonDto.class,
modelMasterEntity.modelNo,
modelMasterEntity.modelVer,
com.querydsl.core.types.dsl.Expressions.constant(metricType),
com.querydsl.core.types.dsl.Expressions.constant(epoch),
Projections.constructor(
Properties.class,
// Changed 메트릭
modelMetricsValidationEntity.changedFscore,
modelMetricsValidationEntity.changedPrecision,
modelMetricsValidationEntity.changedRecall,
// Unchanged 메트릭
modelMetricsValidationEntity.unchangedFscore,
modelMetricsValidationEntity.unchangedPrecision,
modelMetricsValidationEntity.unchangedRecall,
// Mean 메트릭
modelMetricsValidationEntity.mFscore,
modelMetricsValidationEntity.mPrecision,
modelMetricsValidationEntity.mRecall,
modelMetricsValidationEntity.mIou,
modelMetricsValidationEntity.mAcc,
// Overall Accuracy
modelMetricsValidationEntity.aAcc,
// Train 메트릭 (Loss, LR, Duration)
modelMetricsTrainEntity.loss,
modelMetricsTrainEntity.lr,
modelMetricsTrainEntity.durationTime)))
.from(modelMetricsValidationEntity)
.innerJoin(modelMasterEntity)
.on(modelMetricsValidationEntity.model.id.eq(modelMasterEntity.id))
.leftJoin(modelMetricsTrainEntity)
.on(
modelMetricsTrainEntity.model.id.eq(modelMasterEntity.id),
modelMetricsTrainEntity.epoch.eq(epoch))
.where(
modelMetricsValidationEntity.model.id.eq(modelId),
modelMetricsValidationEntity.epoch.eq(epoch))
.fetchOne();
}
} }

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.train.dto;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.nio.file.Path;
import java.util.Properties; import java.util.Properties;
import java.util.UUID; import java.util.UUID;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@@ -23,6 +24,17 @@ public class ModelTrainMetricsDto {
private UUID uuid; private UUID uuid;
} }
@Schema(name = "BestPthInfo", description = "Best PTH 파일 정보")
@Getter
@AllArgsConstructor
public static class BestPthInfo {
private String fileName; // best_changed_fscore_epoch_3.pth
private String metricType; // fscore, precision, recall
private Integer epoch; // 3
private Path filePath; // 파일 전체 경로
}
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public static class ModelMetricJsonDto { public static class ModelMetricJsonDto {
@@ -33,6 +45,11 @@ public class ModelTrainMetricsDto {
@JsonProperty("model_version") @JsonProperty("model_version")
private String modelVersion; private String modelVersion;
@JsonProperty("metric_type")
private String metricType; // fscore, precision, recall
private Integer epoch; // epoch 번호
private Properties properties; private Properties properties;
} }
@@ -40,13 +57,55 @@ public class ModelTrainMetricsDto {
@AllArgsConstructor @AllArgsConstructor
public static class Properties { public static class Properties {
@JsonProperty("f1_score") // 변화 탐지 관련 메트릭 (Changed)
private Float f1Score; @JsonProperty("changed_fscore")
private Float changedFscore;
private Float precision; @JsonProperty("changed_precision")
private Float recall; private Float changedPrecision;
private Float loss;
private Double iou; @JsonProperty("changed_recall")
private Float changedRecall;
// 비변화 관련 메트릭 (Unchanged)
@JsonProperty("unchanged_fscore")
private Float unchangedFscore;
@JsonProperty("unchanged_precision")
private Float unchangedPrecision;
@JsonProperty("unchanged_recall")
private Float unchangedRecall;
// 평균 메트릭 (Mean)
@JsonProperty("mean_fscore")
private Float mFscore;
@JsonProperty("mean_precision")
private Float mPrecision;
@JsonProperty("mean_recall")
private Float mRecall;
@JsonProperty("mean_iou")
private Float mIou;
@JsonProperty("mean_accuracy")
private Float mAcc;
// 전체 정확도 (Overall Accuracy)
@JsonProperty("overall_accuracy")
private Float aAcc;
// 학습 관련 메트릭
@JsonProperty("loss")
private Double loss;
@JsonProperty("learning_rate")
private Double lr;
@JsonProperty("duration_time")
private Float durationTime;
} }
@Getter @Getter

View File

@@ -4,8 +4,8 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.databind.SerializationFeature;
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.BestPthInfo;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto; import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
@@ -18,7 +18,8 @@ import java.nio.file.StandardOpenOption;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import java.util.zip.ZipEntry; import java.util.zip.ZipEntry;
@@ -76,9 +77,9 @@ public class ModelTestMetricsJobService {
} }
/** /**
* 베스트 에폭 zip파일 생성, 테스트결과 db등록 * 베스트 에폭별 개별 ZIP 파일 생성, 테스트결과 db등록
* *
* @param modelInfo * @param modelInfo 모델 정보
*/ */
private void createFile(ResponsePathDto modelInfo) { private void createFile(ResponsePathDto modelInfo) {
@@ -130,59 +131,42 @@ public class ModelTestMetricsJobService {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
// 패키징할 파일 만들기 // 패키징 시작
modelTestMetricsJobCoreService.updatePackingStart(modelInfo.getModelId(), ZonedDateTime.now()); modelTestMetricsJobCoreService.updatePackingStart(modelInfo.getModelId(), ZonedDateTime.now());
ModelMetricJsonDto jsonDto = Path responsePath = Paths.get(responseDir, modelInfo.getUuid().toString());
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
try {
writeJsonFile(
jsonDto,
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + jsonDto.getModelVersion() + ".json"));
} catch (IOException e) {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
ModelTestFileName fileInfo =
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
Path zipPath =
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
Set<String> targetNames =
Set.of(
"model_config.py",
fileInfo.getBestEpochFileName() + ".pth",
fileInfo.getModelVersion() + ".json");
List<Path> files = new ArrayList<>();
try (Stream<Path> s = Files.list(responsePath)) {
files.addAll(
s.filter(Files::isRegularFile)
.filter(p -> targetNames.contains(p.getFileName().toString()))
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
files.addAll(
s.filter(Files::isRegularFile)
.limit(1) // yolov8_6th-6m.pt 파일 1개만
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try { try {
zipFiles(files, zipPath); // 1. 모든 Best PTH 파일 찾기
List<BestPthInfo> bestPthFiles = findAllBestPthFiles(responsePath);
if (bestPthFiles.isEmpty()) {
log.warn(
"Best PTH 파일을 찾을 수 없습니다: modelId={}, uuid={}",
modelInfo.getModelId(),
modelInfo.getUuid());
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
return;
}
log.info(
"Best PTH 파일 {}개 발견: modelId={}, files={}",
bestPthFiles.size(),
modelInfo.getModelId(),
bestPthFiles.stream().map(BestPthInfo::getFileName).collect(Collectors.toList()));
// 2. 각 Best PTH별로 개별 ZIP 생성
createIndividualZipFiles(modelInfo, bestPthFiles, responsePath);
modelTestMetricsJobCoreService.updatePackingEnd( modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId()); modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
} catch (IOException e) {
log.info(
"모든 ZIP 파일 생성 완료: modelId={}, zipCount={}", modelInfo.getModelId(), bestPthFiles.size());
} catch (Exception e) {
log.error("ZIP 파일 생성 실패: modelId={}", modelInfo.getModelId(), e);
modelTestMetricsJobCoreService.updatePackingEnd( modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId()); modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
throw new RuntimeException(e); throw new RuntimeException(e);
@@ -227,4 +211,130 @@ public class ModelTestMetricsJobService {
} }
} }
} }
/**
* Response 폴더에서 모든 best_changed_*.pth 파일 찾기
*
* @param responsePath Response 디렉토리 경로
* @return Best PTH 파일 정보 리스트
*/
private List<BestPthInfo> findAllBestPthFiles(Path responsePath) throws IOException {
List<BestPthInfo> bestFiles = new ArrayList<>();
Pattern pattern = Pattern.compile("best_changed_(fscore|precision|recall)_epoch_(\\d+)\\.pth");
try (Stream<Path> stream = Files.list(responsePath)) {
stream
.filter(Files::isRegularFile)
.forEach(
file -> {
String fileName = file.getFileName().toString();
Matcher matcher = pattern.matcher(fileName);
if (matcher.matches()) {
String metricType = matcher.group(1); // fscore, precision, recall
Integer epoch = Integer.parseInt(matcher.group(2));
bestFiles.add(new BestPthInfo(fileName, metricType, epoch, file));
}
});
}
return bestFiles;
}
/**
* 각 Best PTH 파일별로 개별 ZIP 파일 생성
*
* @param modelInfo 모델 정보
* @param bestPthFiles Best PTH 파일 리스트
* @param responsePath Response 디렉토리 경로
*/
private void createIndividualZipFiles(
ResponsePathDto modelInfo, List<BestPthInfo> bestPthFiles, Path responsePath) {
// 공통 파일 경로
Path modelConfigPath = responsePath.resolve("model_config.py");
Path ptPath = findPretrainedModel();
// 공통 파일 존재 확인
if (!Files.exists(modelConfigPath)) {
log.error("model_config.py 파일이 존재하지 않습니다: {}", modelConfigPath);
throw new RuntimeException("필수 파일 누락: model_config.py");
}
if (ptPath == null) {
log.error("사전학습 모델(.pt) 파일을 찾을 수 없습니다.");
throw new RuntimeException("필수 파일 누락: .pt");
}
for (BestPthInfo bestPth : bestPthFiles) {
try {
// 1. 메트릭 JSON 생성
ModelMetricJsonDto metricJson =
modelTestMetricsJobCoreService.getMetricsByEpoch(
modelInfo.getModelId(), bestPth.getEpoch(), bestPth.getMetricType());
if (metricJson == null) {
log.warn(
"메트릭 정보 없음: modelId={}, epoch={}, metricType={}",
modelInfo.getModelId(),
bestPth.getEpoch(),
bestPth.getMetricType());
continue;
}
String jsonFileName = bestPth.getMetricType() + "_metrics.json";
Path jsonPath = responsePath.resolve(jsonFileName);
writeJsonFile(metricJson, jsonPath);
// 2. ZIP에 포함할 파일 리스트
List<Path> filesToZip = new ArrayList<>();
filesToZip.add(modelConfigPath); // model_config.py
filesToZip.add(bestPth.getFilePath()); // best_changed_{type}_epoch_{n}.pth
filesToZip.add(jsonPath); // {type}_metrics.json
filesToZip.add(ptPath); // yolov8_6th-6m.pt
// 3. ZIP 파일명 생성
String zipFileName =
String.format(
"%s.%s.%s.%s.zip",
metricJson.getCdModelType(), // G1
metricJson.getModelVersion(), // G1_000001
modelInfo.getUuid(), // uuid
bestPth.getMetricType()); // fscore/precision/recall
Path zipPath = responsePath.resolve(zipFileName);
// 4. ZIP 압축
zipFiles(filesToZip, zipPath);
log.info("ZIP 파일 생성 완료: {}", zipPath);
} catch (Exception e) {
log.error(
"ZIP 파일 생성 실패: metricType={}, epoch={}",
bestPth.getMetricType(),
bestPth.getEpoch(),
e);
throw new RuntimeException(e);
}
}
}
/**
* 사전학습 모델(PT 파일) 찾기
*
* @return PT 파일 경로
*/
private Path findPretrainedModel() {
try (Stream<Path> stream = Files.list(Path.of(ptPathDir))) {
return stream
.filter(Files::isRegularFile)
.filter(p -> p.getFileName().toString().endsWith(".pt"))
.findFirst()
.orElse(null);
} catch (IOException e) {
log.error("사전학습 모델 찾기 실패", e);
return null;
}
}
} }