daniel 작업 - best epoch의 pth 만큼 zip파일 생성
This commit is contained in:
@@ -369,9 +369,8 @@ public class ModelTrainDetailApiController {
|
||||
}
|
||||
|
||||
@Operation(
|
||||
summary = "학습 결과 ZIP 파일 전체 다운로드",
|
||||
description =
|
||||
"모델 UUID에 해당하는 모든 ZIP 파일 목록을 조회하고" + "생성된 모든 학습데이터의 ZIP 파일을 하나의 ZIP 파일로 압축하여 다운로드",
|
||||
summary = "학습 결과 ZIP 파일 목록 조회",
|
||||
description = "모델 UUID에 해당하는 모든 ZIP 파일 목록과 개별 다운로드 링크 반환",
|
||||
parameters = {
|
||||
@Parameter(
|
||||
name = "kamco-download-uuid",
|
||||
@@ -388,11 +387,12 @@ public class ModelTrainDetailApiController {
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "학습데이터 zip파일 다운로드",
|
||||
description = "ZIP 파일 목록 조회 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/octet-stream",
|
||||
schema = @Schema(type = "string", format = "binary"))),
|
||||
mediaType = "application/json",
|
||||
schema =
|
||||
@Schema(implementation = ModelTrainDetailDto.ZipFileListResponse.class))),
|
||||
@ApiResponse(responseCode = "404", description = "모델 또는 파일 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@@ -403,6 +403,7 @@ public class ModelTrainDetailApiController {
|
||||
HttpServletRequest request)
|
||||
throws IOException {
|
||||
|
||||
return modelTrainDetailService.downloadZipFile(uuid, downloadUuid, request);
|
||||
return ResponseEntity.ok(
|
||||
modelTrainDetailService.getZipFileListWithFullUrl(uuid, downloadUuid, request));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -444,6 +444,119 @@ public class ModelTrainDetailService {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델 UUID로 ZIP 파일 목록 조회 및 전체 URL 다운로드 링크 생성
|
||||
*
|
||||
* @param uuid 모델 UUID
|
||||
* @param downloadUuid 다운로드 추적 UUID
|
||||
* @param request HTTP 요청
|
||||
* @return ZIP 파일 목록 및 전체 URL 다운로드 링크
|
||||
*/
|
||||
public ZipFileListResponse getZipFileListWithFullUrl(
|
||||
UUID uuid, String downloadUuid, HttpServletRequest request) {
|
||||
log.info("ZIP 파일 목록 조회 시작 (전체 URL): modelUuid={}, downloadUuid={}", uuid, downloadUuid);
|
||||
|
||||
// 1. 모델 정보 조회
|
||||
Basic modelInfo;
|
||||
try {
|
||||
modelInfo = findByModelByUUID(uuid);
|
||||
if (modelInfo == null) {
|
||||
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델을 찾을 수 없습니다: " + uuid);
|
||||
}
|
||||
} catch (NullPointerException e) {
|
||||
log.error("모델 조회 실패: {}", uuid, e);
|
||||
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델을 찾을 수 없습니다: " + uuid);
|
||||
}
|
||||
|
||||
// 2. 실제 디렉토리 경로 찾기 (uuid 또는 uuid-out)
|
||||
Path baseDir = findActualBasePath(uuid);
|
||||
|
||||
if (baseDir == null || !Files.exists(baseDir)) {
|
||||
log.warn(
|
||||
"디렉토리를 찾을 수 없음: modelUuid={}, 시도한 경로: {} 또는 {}-out",
|
||||
uuid,
|
||||
responseDir + "/" + uuid,
|
||||
responseDir + "/" + uuid);
|
||||
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델 결과 디렉토리가 존재하지 않습니다.");
|
||||
}
|
||||
|
||||
log.info("디렉토리 발견: basePath={}", baseDir.toString());
|
||||
|
||||
// 요청에서 도메인 정보 추출
|
||||
String scheme = request.getScheme(); // http 또는 https
|
||||
String serverName = request.getServerName(); // localhost 또는 도메인
|
||||
int serverPort = request.getServerPort(); // 8080 등
|
||||
String contextPath = request.getContextPath(); // /api 등
|
||||
|
||||
// 기본 URL 구성
|
||||
String baseUrl;
|
||||
if ((scheme.equals("http") && serverPort == 80)
|
||||
|| (scheme.equals("https") && serverPort == 443)) {
|
||||
baseUrl = scheme + "://" + serverName + contextPath;
|
||||
} else {
|
||||
baseUrl = scheme + "://" + serverName + ":" + serverPort + contextPath;
|
||||
}
|
||||
|
||||
// 3. ZIP 파일 목록 검색
|
||||
List<ZipFileInfo> zipFiles = new ArrayList<>();
|
||||
long totalSize = 0L;
|
||||
|
||||
try (Stream<Path> stream = Files.list(baseDir)) {
|
||||
List<Path> files =
|
||||
stream
|
||||
.filter(Files::isRegularFile)
|
||||
.filter(p -> p.getFileName().toString().endsWith(".zip"))
|
||||
.filter(p -> p.getFileName().toString().contains(uuid.toString()))
|
||||
.sorted(
|
||||
Comparator.comparing(p -> p.getFileName().toString(), Comparator.reverseOrder()))
|
||||
.toList();
|
||||
|
||||
for (Path file : files) {
|
||||
String fileName = file.getFileName().toString();
|
||||
long fileSize = Files.size(file);
|
||||
totalSize += fileSize;
|
||||
|
||||
// 파일명에서 버전 추출
|
||||
String version = extractVersionFromZipFileName(fileName);
|
||||
boolean isCurrent = version.equals(modelInfo.getModelVer());
|
||||
|
||||
// 전체 도메인 URL 포함한 다운로드 링크 생성
|
||||
String downloadUrl = baseUrl + "/api/models/download/" + uuid + "?file=" + fileName;
|
||||
|
||||
zipFiles.add(
|
||||
ZipFileInfo.builder()
|
||||
.fileName(fileName)
|
||||
.filePath(file.toString())
|
||||
.version(version)
|
||||
.fileSize(fileSize)
|
||||
.fileSizeFormatted(formatFileSize(fileSize))
|
||||
.lastModified(
|
||||
Files.getLastModifiedTime(file).toInstant().atZone(ZoneId.systemDefault()))
|
||||
.isCurrent(isCurrent)
|
||||
.downloadUrl(downloadUrl)
|
||||
.build());
|
||||
}
|
||||
|
||||
log.info("ZIP 파일 {}개 발견", zipFiles.size());
|
||||
|
||||
} catch (IOException e) {
|
||||
log.error("ZIP 파일 목록 조회 실패: {}", baseDir, e);
|
||||
throw new CustomApiException(
|
||||
"INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "ZIP 파일 목록 조회 실패");
|
||||
}
|
||||
|
||||
return ZipFileListResponse.builder()
|
||||
.modelUuid(uuid.toString())
|
||||
.modelNo(modelInfo.getModelNo())
|
||||
.modelVer(modelInfo.getModelVer())
|
||||
.basePath(baseDir.toString())
|
||||
.zipFiles(zipFiles)
|
||||
.totalFiles(zipFiles.size())
|
||||
.totalSize(totalSize)
|
||||
.totalSizeFormatted(formatFileSize(totalSize))
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델 UUID로 ZIP 파일 목록 조회 및 다운로드 링크 생성
|
||||
*
|
||||
@@ -608,16 +721,24 @@ public class ModelTrainDetailService {
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델 UUID로 모든 ZIP 파일 다운로드 (여러 파일이 있으면 하나의 zip으로 묶어서)
|
||||
* 모델 UUID로 모든 ZIP 파일 다운로드 또는 목록 조회
|
||||
*
|
||||
* @param uuid 모델 UUID
|
||||
* @param downloadUuid 다운로드 추적 UUID
|
||||
* @param fileName 개별 파일명 (선택)
|
||||
* @param accept Accept 헤더 (application/json 또는 application/octet-stream)
|
||||
* @param request HTTP 요청
|
||||
* @return ZIP 파일 다운로드 응답
|
||||
* @return JSON 응답 또는 ZIP 파일 Binary 응답
|
||||
*/
|
||||
public ResponseEntity<?> downloadZipFile(
|
||||
UUID uuid, String downloadUuid, HttpServletRequest request) throws IOException {
|
||||
log.info("ZIP 파일 다운로드 시작: modelUuid={}, downloadUuid={}", uuid, downloadUuid);
|
||||
UUID uuid, String downloadUuid, String fileName, String accept, HttpServletRequest request)
|
||||
throws IOException {
|
||||
log.info(
|
||||
"ZIP 파일 다운로드/조회 시작: modelUuid={}, downloadUuid={}, file={}, accept={}",
|
||||
uuid,
|
||||
downloadUuid,
|
||||
fileName,
|
||||
accept);
|
||||
|
||||
// 1. 모델 정보 조회
|
||||
Basic modelInfo;
|
||||
@@ -659,7 +780,134 @@ public class ModelTrainDetailService {
|
||||
baseDir,
|
||||
zipFiles.stream().map(p -> p.getFileName().toString()).toList());
|
||||
|
||||
// 4. 파일이 1개면 바로 다운로드
|
||||
// 4. Accept 헤더에 따라 분기
|
||||
if (accept != null && accept.contains("application/json")) {
|
||||
// JSON 응답: ZIP 파일 목록과 다운로드 링크 반환
|
||||
log.info("JSON 응답 모드: ZIP 파일 목록 반환");
|
||||
return ResponseEntity.ok(buildZipFileListResponse(modelInfo, baseDir, zipFiles, request));
|
||||
}
|
||||
|
||||
// 5. Binary 응답: 파일 다운로드
|
||||
return handleBinaryDownload(modelInfo, baseDir, zipFiles, fileName, uuid, request);
|
||||
}
|
||||
|
||||
/**
|
||||
* ZIP 파일 목록 응답 생성 (JSON)
|
||||
*
|
||||
* @param modelInfo 모델 정보
|
||||
* @param baseDir 기본 디렉토리
|
||||
* @param zipFiles ZIP 파일 목록
|
||||
* @param request HTTP 요청
|
||||
* @return ZipFileListResponse
|
||||
*/
|
||||
private ZipFileListResponse buildZipFileListResponse(
|
||||
Basic modelInfo, Path baseDir, List<Path> zipFiles, HttpServletRequest request)
|
||||
throws IOException {
|
||||
|
||||
// 요청에서 도메인 정보 추출
|
||||
String scheme = request.getScheme(); // http 또는 https
|
||||
String serverName = request.getServerName(); // localhost 또는 도메인
|
||||
int serverPort = request.getServerPort(); // 8080 등
|
||||
String contextPath = request.getContextPath(); // /api 등
|
||||
|
||||
// 기본 URL 구성
|
||||
String baseUrl;
|
||||
if ((scheme.equals("http") && serverPort == 80)
|
||||
|| (scheme.equals("https") && serverPort == 443)) {
|
||||
baseUrl = scheme + "://" + serverName + contextPath;
|
||||
} else {
|
||||
baseUrl = scheme + "://" + serverName + ":" + serverPort + contextPath;
|
||||
}
|
||||
|
||||
List<ZipFileInfo> zipFileInfos = new ArrayList<>();
|
||||
long totalSize = 0L;
|
||||
|
||||
for (Path zipFile : zipFiles) {
|
||||
String zipFileName = zipFile.getFileName().toString();
|
||||
long fileSize = Files.size(zipFile);
|
||||
totalSize += fileSize;
|
||||
|
||||
// 파일명에서 버전 추출
|
||||
String version = extractVersionFromZipFileName(zipFileName);
|
||||
boolean isCurrent = version.equals(modelInfo.getModelVer());
|
||||
|
||||
// 전체 도메인 URL 포함한 다운로드 링크 생성
|
||||
String downloadUrl =
|
||||
baseUrl + "/api/models/downloadzip/" + modelInfo.getUuid() + "?file=" + zipFileName;
|
||||
|
||||
zipFileInfos.add(
|
||||
ZipFileInfo.builder()
|
||||
.fileName(zipFileName)
|
||||
.filePath(zipFile.toString())
|
||||
.version(version)
|
||||
.fileSize(fileSize)
|
||||
.fileSizeFormatted(formatFileSize(fileSize))
|
||||
.lastModified(
|
||||
Files.getLastModifiedTime(zipFile).toInstant().atZone(ZoneId.systemDefault()))
|
||||
.isCurrent(isCurrent)
|
||||
.downloadUrl(downloadUrl)
|
||||
.build());
|
||||
}
|
||||
|
||||
return ZipFileListResponse.builder()
|
||||
.modelUuid(modelInfo.getUuid().toString())
|
||||
.modelNo(modelInfo.getModelNo())
|
||||
.modelVer(modelInfo.getModelVer())
|
||||
.basePath(baseDir.toString())
|
||||
.zipFiles(zipFileInfos)
|
||||
.totalFiles(zipFileInfos.size())
|
||||
.totalSize(totalSize)
|
||||
.totalSizeFormatted(formatFileSize(totalSize))
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary 다운로드 처리
|
||||
*
|
||||
* @param modelInfo 모델 정보
|
||||
* @param baseDir 기본 디렉토리
|
||||
* @param zipFiles ZIP 파일 목록
|
||||
* @param fileName 개별 파일명 (선택)
|
||||
* @param uuid 모델 UUID
|
||||
* @param request HTTP 요청
|
||||
* @return Binary 응답
|
||||
*/
|
||||
private ResponseEntity<?> handleBinaryDownload(
|
||||
Basic modelInfo,
|
||||
Path baseDir,
|
||||
List<Path> zipFiles,
|
||||
String fileName,
|
||||
UUID uuid,
|
||||
HttpServletRequest request)
|
||||
throws IOException {
|
||||
|
||||
// file 파라미터가 있으면 개별 파일 다운로드
|
||||
if (fileName != null && !fileName.isEmpty()) {
|
||||
log.info("개별 파일 다운로드 요청: fileName={}", fileName);
|
||||
|
||||
Path targetFile =
|
||||
zipFiles.stream()
|
||||
.filter(p -> p.getFileName().toString().equals(fileName))
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
|
||||
if (targetFile == null) {
|
||||
log.warn("요청한 파일을 찾을 수 없음: fileName={}, basePath={}", fileName, baseDir);
|
||||
throw new CustomApiException(
|
||||
"NOT_FOUND", HttpStatus.NOT_FOUND, "요청한 파일을 찾을 수 없습니다: " + fileName);
|
||||
}
|
||||
|
||||
log.info(
|
||||
"개별 ZIP 파일 다운로드: fileName={}, fileSize={} bytes, basePath={}",
|
||||
targetFile.getFileName(),
|
||||
Files.size(targetFile),
|
||||
baseDir);
|
||||
|
||||
return rangeDownloadResponder.buildZipResponse(targetFile, fileName, request);
|
||||
}
|
||||
|
||||
// file 파라미터가 없으면 모든 파일 다운로드
|
||||
// 파일이 1개면 바로 다운로드
|
||||
if (zipFiles.size() == 1) {
|
||||
Path zipPath = zipFiles.get(0);
|
||||
log.info(
|
||||
@@ -671,7 +919,7 @@ public class ModelTrainDetailService {
|
||||
return rangeDownloadResponder.buildZipResponse(zipPath, downloadFileName, request);
|
||||
}
|
||||
|
||||
// 5. 파일이 여러 개면 하나의 zip으로 묶어서 다운로드
|
||||
// 파일이 여러 개면 하나의 zip으로 묶어서 다운로드
|
||||
log.info("여러 ZIP 파일을 하나로 묶어서 다운로드: 총 {}개 파일", zipFiles.size());
|
||||
|
||||
String combinedZipName = modelInfo.getModelNo() + "." + modelInfo.getModelVer() + ".all.zip";
|
||||
|
||||
@@ -51,4 +51,16 @@ public class ModelTestMetricsJobCoreService {
|
||||
public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) {
|
||||
modelTestMetricsJobRepository.updatePackingEnd(modelId, now, failSuccState);
|
||||
}
|
||||
|
||||
/**
|
||||
* 특정 Epoch의 메트릭 정보 조회 (하이퍼파라미터별 ZIP 생성용)
|
||||
*
|
||||
* @param modelId 모델 ID
|
||||
* @param epoch Epoch 번호
|
||||
* @param metricType 메트릭 타입 (fscore, precision, recall)
|
||||
* @return 메트릭 JSON DTO
|
||||
*/
|
||||
public ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType) {
|
||||
return modelTestMetricsJobRepository.getMetricsByEpoch(modelId, epoch, metricType);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,4 +23,14 @@ public interface ModelTestMetricsJobRepositoryCustom {
|
||||
void updatePackingStart(Long modelId, ZonedDateTime now);
|
||||
|
||||
void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState);
|
||||
|
||||
/**
|
||||
* 특정 Epoch의 메트릭 정보 조회 (하이퍼파라미터별 ZIP 생성용)
|
||||
*
|
||||
* @param modelId 모델 ID
|
||||
* @param epoch Epoch 번호
|
||||
* @param metricType 메트릭 타입 (fscore, precision, recall)
|
||||
* @return 메트릭 JSON DTO
|
||||
*/
|
||||
ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType);
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.kamco.cd.training.postgres.repository.train;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
||||
@@ -187,4 +188,49 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
||||
.where(modelMasterEntity.id.eq(modelId))
|
||||
.execute();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelMetricJsonDto getMetricsByEpoch(Long modelId, Integer epoch, String metricType) {
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelMetricJsonDto.class,
|
||||
modelMasterEntity.modelNo,
|
||||
modelMasterEntity.modelVer,
|
||||
com.querydsl.core.types.dsl.Expressions.constant(metricType),
|
||||
com.querydsl.core.types.dsl.Expressions.constant(epoch),
|
||||
Projections.constructor(
|
||||
Properties.class,
|
||||
// Changed 메트릭
|
||||
modelMetricsValidationEntity.changedFscore,
|
||||
modelMetricsValidationEntity.changedPrecision,
|
||||
modelMetricsValidationEntity.changedRecall,
|
||||
// Unchanged 메트릭
|
||||
modelMetricsValidationEntity.unchangedFscore,
|
||||
modelMetricsValidationEntity.unchangedPrecision,
|
||||
modelMetricsValidationEntity.unchangedRecall,
|
||||
// Mean 메트릭
|
||||
modelMetricsValidationEntity.mFscore,
|
||||
modelMetricsValidationEntity.mPrecision,
|
||||
modelMetricsValidationEntity.mRecall,
|
||||
modelMetricsValidationEntity.mIou,
|
||||
modelMetricsValidationEntity.mAcc,
|
||||
// Overall Accuracy
|
||||
modelMetricsValidationEntity.aAcc,
|
||||
// Train 메트릭 (Loss, LR, Duration)
|
||||
modelMetricsTrainEntity.loss,
|
||||
modelMetricsTrainEntity.lr,
|
||||
modelMetricsTrainEntity.durationTime)))
|
||||
.from(modelMetricsValidationEntity)
|
||||
.innerJoin(modelMasterEntity)
|
||||
.on(modelMetricsValidationEntity.model.id.eq(modelMasterEntity.id))
|
||||
.leftJoin(modelMetricsTrainEntity)
|
||||
.on(
|
||||
modelMetricsTrainEntity.model.id.eq(modelMasterEntity.id),
|
||||
modelMetricsTrainEntity.epoch.eq(epoch))
|
||||
.where(
|
||||
modelMetricsValidationEntity.model.id.eq(modelId),
|
||||
modelMetricsValidationEntity.epoch.eq(epoch))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.kamco.cd.training.train.dto;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Properties;
|
||||
import java.util.UUID;
|
||||
import lombok.AllArgsConstructor;
|
||||
@@ -23,6 +24,17 @@ public class ModelTrainMetricsDto {
|
||||
private UUID uuid;
|
||||
}
|
||||
|
||||
@Schema(name = "BestPthInfo", description = "Best PTH 파일 정보")
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public static class BestPthInfo {
|
||||
|
||||
private String fileName; // best_changed_fscore_epoch_3.pth
|
||||
private String metricType; // fscore, precision, recall
|
||||
private Integer epoch; // 3
|
||||
private Path filePath; // 파일 전체 경로
|
||||
}
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public static class ModelMetricJsonDto {
|
||||
@@ -33,6 +45,11 @@ public class ModelTrainMetricsDto {
|
||||
@JsonProperty("model_version")
|
||||
private String modelVersion;
|
||||
|
||||
@JsonProperty("metric_type")
|
||||
private String metricType; // fscore, precision, recall
|
||||
|
||||
private Integer epoch; // epoch 번호
|
||||
|
||||
private Properties properties;
|
||||
}
|
||||
|
||||
@@ -40,13 +57,55 @@ public class ModelTrainMetricsDto {
|
||||
@AllArgsConstructor
|
||||
public static class Properties {
|
||||
|
||||
@JsonProperty("f1_score")
|
||||
private Float f1Score;
|
||||
// 변화 탐지 관련 메트릭 (Changed)
|
||||
@JsonProperty("changed_fscore")
|
||||
private Float changedFscore;
|
||||
|
||||
private Float precision;
|
||||
private Float recall;
|
||||
private Float loss;
|
||||
private Double iou;
|
||||
@JsonProperty("changed_precision")
|
||||
private Float changedPrecision;
|
||||
|
||||
@JsonProperty("changed_recall")
|
||||
private Float changedRecall;
|
||||
|
||||
// 비변화 관련 메트릭 (Unchanged)
|
||||
@JsonProperty("unchanged_fscore")
|
||||
private Float unchangedFscore;
|
||||
|
||||
@JsonProperty("unchanged_precision")
|
||||
private Float unchangedPrecision;
|
||||
|
||||
@JsonProperty("unchanged_recall")
|
||||
private Float unchangedRecall;
|
||||
|
||||
// 평균 메트릭 (Mean)
|
||||
@JsonProperty("mean_fscore")
|
||||
private Float mFscore;
|
||||
|
||||
@JsonProperty("mean_precision")
|
||||
private Float mPrecision;
|
||||
|
||||
@JsonProperty("mean_recall")
|
||||
private Float mRecall;
|
||||
|
||||
@JsonProperty("mean_iou")
|
||||
private Float mIou;
|
||||
|
||||
@JsonProperty("mean_accuracy")
|
||||
private Float mAcc;
|
||||
|
||||
// 전체 정확도 (Overall Accuracy)
|
||||
@JsonProperty("overall_accuracy")
|
||||
private Float aAcc;
|
||||
|
||||
// 학습 관련 메트릭
|
||||
@JsonProperty("loss")
|
||||
private Double loss;
|
||||
|
||||
@JsonProperty("learning_rate")
|
||||
private Double lr;
|
||||
|
||||
@JsonProperty("duration_time")
|
||||
private Float durationTime;
|
||||
}
|
||||
|
||||
@Getter
|
||||
|
||||
@@ -4,8 +4,8 @@ import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.BestPthInfo;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
@@ -18,7 +18,8 @@ import java.nio.file.StandardOpenOption;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import java.util.zip.ZipEntry;
|
||||
@@ -76,9 +77,9 @@ public class ModelTestMetricsJobService {
|
||||
}
|
||||
|
||||
/**
|
||||
* 베스트 에폭 zip파일 생성, 테스트결과 db등록
|
||||
* 베스트 에폭별 개별 ZIP 파일 생성, 테스트결과 db등록
|
||||
*
|
||||
* @param modelInfo
|
||||
* @param modelInfo 모델 정보
|
||||
*/
|
||||
private void createFile(ResponsePathDto modelInfo) {
|
||||
|
||||
@@ -130,59 +131,42 @@ public class ModelTestMetricsJobService {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// 패키징할 파일 만들기
|
||||
// 패키징 시작
|
||||
modelTestMetricsJobCoreService.updatePackingStart(modelInfo.getModelId(), ZonedDateTime.now());
|
||||
|
||||
ModelMetricJsonDto jsonDto =
|
||||
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
|
||||
try {
|
||||
writeJsonFile(
|
||||
jsonDto,
|
||||
Paths.get(
|
||||
responseDir + "/" + modelInfo.getUuid() + "/" + jsonDto.getModelVersion() + ".json"));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
|
||||
|
||||
ModelTestFileName fileInfo =
|
||||
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
|
||||
|
||||
Path zipPath =
|
||||
Paths.get(
|
||||
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
|
||||
Set<String> targetNames =
|
||||
Set.of(
|
||||
"model_config.py",
|
||||
fileInfo.getBestEpochFileName() + ".pth",
|
||||
fileInfo.getModelVersion() + ".json");
|
||||
|
||||
List<Path> files = new ArrayList<>();
|
||||
try (Stream<Path> s = Files.list(responsePath)) {
|
||||
files.addAll(
|
||||
s.filter(Files::isRegularFile)
|
||||
.filter(p -> targetNames.contains(p.getFileName().toString()))
|
||||
.collect(Collectors.toList()));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
|
||||
files.addAll(
|
||||
s.filter(Files::isRegularFile)
|
||||
.limit(1) // yolov8_6th-6m.pt 파일 1개만
|
||||
.collect(Collectors.toList()));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
Path responsePath = Paths.get(responseDir, modelInfo.getUuid().toString());
|
||||
|
||||
try {
|
||||
zipFiles(files, zipPath);
|
||||
// 1. 모든 Best PTH 파일 찾기
|
||||
List<BestPthInfo> bestPthFiles = findAllBestPthFiles(responsePath);
|
||||
|
||||
if (bestPthFiles.isEmpty()) {
|
||||
log.warn(
|
||||
"Best PTH 파일을 찾을 수 없습니다: modelId={}, uuid={}",
|
||||
modelInfo.getModelId(),
|
||||
modelInfo.getUuid());
|
||||
modelTestMetricsJobCoreService.updatePackingEnd(
|
||||
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
|
||||
return;
|
||||
}
|
||||
|
||||
log.info(
|
||||
"Best PTH 파일 {}개 발견: modelId={}, files={}",
|
||||
bestPthFiles.size(),
|
||||
modelInfo.getModelId(),
|
||||
bestPthFiles.stream().map(BestPthInfo::getFileName).collect(Collectors.toList()));
|
||||
|
||||
// 2. 각 Best PTH별로 개별 ZIP 생성
|
||||
createIndividualZipFiles(modelInfo, bestPthFiles, responsePath);
|
||||
|
||||
modelTestMetricsJobCoreService.updatePackingEnd(
|
||||
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
|
||||
} catch (IOException e) {
|
||||
|
||||
log.info(
|
||||
"모든 ZIP 파일 생성 완료: modelId={}, zipCount={}", modelInfo.getModelId(), bestPthFiles.size());
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("ZIP 파일 생성 실패: modelId={}", modelInfo.getModelId(), e);
|
||||
modelTestMetricsJobCoreService.updatePackingEnd(
|
||||
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
|
||||
throw new RuntimeException(e);
|
||||
@@ -227,4 +211,130 @@ public class ModelTestMetricsJobService {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Response 폴더에서 모든 best_changed_*.pth 파일 찾기
|
||||
*
|
||||
* @param responsePath Response 디렉토리 경로
|
||||
* @return Best PTH 파일 정보 리스트
|
||||
*/
|
||||
private List<BestPthInfo> findAllBestPthFiles(Path responsePath) throws IOException {
|
||||
List<BestPthInfo> bestFiles = new ArrayList<>();
|
||||
Pattern pattern = Pattern.compile("best_changed_(fscore|precision|recall)_epoch_(\\d+)\\.pth");
|
||||
|
||||
try (Stream<Path> stream = Files.list(responsePath)) {
|
||||
stream
|
||||
.filter(Files::isRegularFile)
|
||||
.forEach(
|
||||
file -> {
|
||||
String fileName = file.getFileName().toString();
|
||||
Matcher matcher = pattern.matcher(fileName);
|
||||
if (matcher.matches()) {
|
||||
String metricType = matcher.group(1); // fscore, precision, recall
|
||||
Integer epoch = Integer.parseInt(matcher.group(2));
|
||||
|
||||
bestFiles.add(new BestPthInfo(fileName, metricType, epoch, file));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return bestFiles;
|
||||
}
|
||||
|
||||
/**
|
||||
* 각 Best PTH 파일별로 개별 ZIP 파일 생성
|
||||
*
|
||||
* @param modelInfo 모델 정보
|
||||
* @param bestPthFiles Best PTH 파일 리스트
|
||||
* @param responsePath Response 디렉토리 경로
|
||||
*/
|
||||
private void createIndividualZipFiles(
|
||||
ResponsePathDto modelInfo, List<BestPthInfo> bestPthFiles, Path responsePath) {
|
||||
|
||||
// 공통 파일 경로
|
||||
Path modelConfigPath = responsePath.resolve("model_config.py");
|
||||
Path ptPath = findPretrainedModel();
|
||||
|
||||
// 공통 파일 존재 확인
|
||||
if (!Files.exists(modelConfigPath)) {
|
||||
log.error("model_config.py 파일이 존재하지 않습니다: {}", modelConfigPath);
|
||||
throw new RuntimeException("필수 파일 누락: model_config.py");
|
||||
}
|
||||
|
||||
if (ptPath == null) {
|
||||
log.error("사전학습 모델(.pt) 파일을 찾을 수 없습니다.");
|
||||
throw new RuntimeException("필수 파일 누락: .pt");
|
||||
}
|
||||
|
||||
for (BestPthInfo bestPth : bestPthFiles) {
|
||||
try {
|
||||
// 1. 메트릭 JSON 생성
|
||||
ModelMetricJsonDto metricJson =
|
||||
modelTestMetricsJobCoreService.getMetricsByEpoch(
|
||||
modelInfo.getModelId(), bestPth.getEpoch(), bestPth.getMetricType());
|
||||
|
||||
if (metricJson == null) {
|
||||
log.warn(
|
||||
"메트릭 정보 없음: modelId={}, epoch={}, metricType={}",
|
||||
modelInfo.getModelId(),
|
||||
bestPth.getEpoch(),
|
||||
bestPth.getMetricType());
|
||||
continue;
|
||||
}
|
||||
|
||||
String jsonFileName = bestPth.getMetricType() + "_metrics.json";
|
||||
Path jsonPath = responsePath.resolve(jsonFileName);
|
||||
writeJsonFile(metricJson, jsonPath);
|
||||
|
||||
// 2. ZIP에 포함할 파일 리스트
|
||||
List<Path> filesToZip = new ArrayList<>();
|
||||
filesToZip.add(modelConfigPath); // model_config.py
|
||||
filesToZip.add(bestPth.getFilePath()); // best_changed_{type}_epoch_{n}.pth
|
||||
filesToZip.add(jsonPath); // {type}_metrics.json
|
||||
filesToZip.add(ptPath); // yolov8_6th-6m.pt
|
||||
|
||||
// 3. ZIP 파일명 생성
|
||||
String zipFileName =
|
||||
String.format(
|
||||
"%s.%s.%s.%s.zip",
|
||||
metricJson.getCdModelType(), // G1
|
||||
metricJson.getModelVersion(), // G1_000001
|
||||
modelInfo.getUuid(), // uuid
|
||||
bestPth.getMetricType()); // fscore/precision/recall
|
||||
|
||||
Path zipPath = responsePath.resolve(zipFileName);
|
||||
|
||||
// 4. ZIP 압축
|
||||
zipFiles(filesToZip, zipPath);
|
||||
|
||||
log.info("ZIP 파일 생성 완료: {}", zipPath);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error(
|
||||
"ZIP 파일 생성 실패: metricType={}, epoch={}",
|
||||
bestPth.getMetricType(),
|
||||
bestPth.getEpoch(),
|
||||
e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 사전학습 모델(PT 파일) 찾기
|
||||
*
|
||||
* @return PT 파일 경로
|
||||
*/
|
||||
private Path findPretrainedModel() {
|
||||
try (Stream<Path> stream = Files.list(Path.of(ptPathDir))) {
|
||||
return stream
|
||||
.filter(Files::isRegularFile)
|
||||
.filter(p -> p.getFileName().toString().endsWith(".pt"))
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
} catch (IOException e) {
|
||||
log.error("사전학습 모델 찾기 실패", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user