test metrics 원복

This commit is contained in:
2026-04-09 15:59:17 +09:00
parent e64b1f15ba
commit ca080bf77b

View File

@@ -4,7 +4,6 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.BestPthInfo;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
@@ -20,8 +19,6 @@ import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
@@ -51,6 +48,9 @@ public class ModelTestMetricsJobService {
@Value("${file.pt-path}")
private String ptPathDir;
@Value("${file.pt-FileName}")
private String ptFileName;
/** 결과 csv 파일 정보 등록 */
public void findTestValidMetricCsvFiles() {
@@ -79,9 +79,9 @@ public class ModelTestMetricsJobService {
}
/**
* 베스트 에폭별 개별 ZIP 파일 생성, 테스트결과 db등록
* 베스트 에폭 zip파일 생성, 테스트결과 db등록
*
* @param modelInfo 모델 정보
* @param modelInfo
*/
private void createFile(ResponsePathDto modelInfo) {
@@ -133,51 +133,159 @@ public class ModelTestMetricsJobService {
throw new RuntimeException(e);
}
// 패키징 시작
// 패키징할 파일 만들기
modelTestMetricsJobCoreService.updatePackingStart(modelInfo.getModelId(), ZonedDateTime.now());
Path responsePath = Paths.get(responseDir, modelInfo.getUuid().toString());
ModelMetricJsonDto jsonDto =
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
try {
// 1. 모든 Best PTH 파일 찾기
List<BestPthInfo> bestPthFiles = findAllBestPthFiles(responsePath);
if (bestPthFiles.isEmpty()) {
log.warn(
"Best PTH 파일을 찾을 수 없습니다: modelId={}, uuid={}",
modelInfo.getModelId(),
modelInfo.getUuid());
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
return;
writeJsonFile(
jsonDto,
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + jsonDto.getModelVersion() + ".json"));
} catch (IOException e) {
throw new RuntimeException(e);
}
log.info(
"Best PTH 파일 {}개 발견: modelId={}, files={}",
bestPthFiles.size(),
modelInfo.getModelId(),
bestPthFiles.stream().map(BestPthInfo::getFileName).collect(Collectors.toList()));
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
// 2. 기존 방식: F-Score 기준 통합 ZIP 생성 (1개)
createLegacyZipFile(modelInfo, responsePath);
ModelTestFileName fileInfo =
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
// 3. 신규 방식: 각 Best PTH별로 개별 ZIP 생성 (3개)
createIndividualZipFiles(modelInfo, bestPthFiles, responsePath);
Path zipPath =
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
Set<String> targetNames =
Set.of(
"model_config.py",
fileInfo.getBestEpochFileName() + ".pth",
fileInfo.getModelVersion() + ".json");
List<Path> files = new ArrayList<>();
try (Stream<Path> s = Files.list(responsePath)) {
files.addAll(
s.filter(Files::isRegularFile)
.filter(p -> targetNames.contains(p.getFileName().toString()))
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
files.addAll(
s.filter(Files::isRegularFile)
.limit(1) // yolov8_6th-6m.pt 파일 1개만
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try {
zipFiles(files, zipPath);
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
log.info(
"모든 ZIP 파일 생성 완료: modelId={}, 기존방식=1개, 개별방식={}개",
modelInfo.getModelId(),
bestPthFiles.size());
} catch (Exception e) {
log.error("ZIP 파일 생성 실패: modelId={}", modelInfo.getModelId(), e);
} catch (IOException e) {
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
throw new RuntimeException(e);
}
// ===== 추가: 각 best*.pth 파일별 개별 ZIP 생성 =====
createIndividualBestPthZips(modelInfo, responsePath, jsonDto);
}
/**
* Response 폴더의 모든 best*.pth 파일을 각각 개별 ZIP 파일로 생성
*
* @param modelInfo 모델 정보
* @param responsePath Response 디렉토리 경로
* @param jsonDto JSON 메타데이터
*/
private void createIndividualBestPthZips(
ResponsePathDto modelInfo, Path responsePath, ModelMetricJsonDto jsonDto) {
log.info("=== 개별 best*.pth ZIP 파일 생성 시작: modelId={} ===", modelInfo.getModelId());
try {
// 1. Response 폴더에서 모든 best*.pth 파일 찾기
List<Path> bestPthFiles;
try (Stream<Path> stream = Files.list(responsePath)) {
bestPthFiles =
stream
.filter(Files::isRegularFile)
.filter(p -> p.getFileName().toString().startsWith("best"))
.filter(p -> p.getFileName().toString().endsWith(".pth"))
.collect(Collectors.toList());
}
if (bestPthFiles.isEmpty()) {
log.warn("best*.pth 파일을 찾을 수 없습니다: path={}", responsePath);
return;
}
log.info("발견된 best*.pth 파일 개수: {}", bestPthFiles.size());
// 2. PT 파일 경로 (모든 ZIP에 공통으로 포함)
Path ptFile = Paths.get(ptPathDir, ptFileName);
if (!Files.exists(ptFile)) {
log.warn("PT 파일을 찾을 수 없습니다: {}", ptFile);
return;
}
// 3. 각 best*.pth 파일별로 개별 ZIP 생성
for (Path bestPthFile : bestPthFiles) {
String pthFileName = bestPthFile.getFileName().toString();
log.info("처리 중인 best PTH 파일: {}", pthFileName);
try {
// 3-1. 개별 JSON 파일 생성
String individualJsonName = pthFileName.replace(".pth", ".json");
Path individualJsonPath = responsePath.resolve(individualJsonName);
writeJsonFile(jsonDto, individualJsonPath);
log.info("개별 JSON 생성: {}", individualJsonName);
// 3-2. 개별 ZIP 파일명 생성
// 형식: {modelVersion}.{pthFileNameWithoutExt}.zip
// 예: G1_000001.best_epoch_3.zip
String pthFileNameWithoutExt = pthFileName.replace(".pth", "");
String individualZipName =
jsonDto.getModelVersion() + "." + pthFileNameWithoutExt + ".zip";
Path individualZipPath = responsePath.resolve(individualZipName);
// 3-3. ZIP에 포함될 파일 목록 구성
List<Path> zipFileList = new ArrayList<>();
zipFileList.add(bestPthFile); // best*.pth 파일
zipFileList.add(individualJsonPath); // 개별 JSON 파일
zipFileList.add(ptFile); // PT 파일
// model_config.py 파일이 있으면 추가
Path modelConfigPath = responsePath.resolve("model_config.py");
if (Files.exists(modelConfigPath)) {
zipFileList.add(modelConfigPath);
}
// 3-4. 개별 ZIP 생성
zipFiles(zipFileList, individualZipPath);
log.info(
"✅ 개별 ZIP 생성 완료: fileName={}, pthFile={}, size={} bytes",
individualZipName,
pthFileName,
Files.size(individualZipPath));
} catch (IOException e) {
log.error("개별 ZIP 생성 실패: pthFile={}", pthFileName, e);
// 개별 ZIP 실패는 전체 프로세스를 중단하지 않음
}
}
log.info("=== 개별 best*.pth ZIP 파일 생성 완료: 총 {}개 ===", bestPthFiles.size());
} catch (IOException e) {
log.error("개별 ZIP 생성 중 오류 발생", e);
// 에러 발생해도 기존 ZIP은 이미 생성되었으므로 예외를 던지지 않음
}
}
private void writeJsonFile(Object data, Path outputPath) throws IOException {
@@ -218,289 +326,4 @@ public class ModelTestMetricsJobService {
}
}
}
/**
* 기존 방식: F-Score 기준 통합 ZIP 파일 생성 파일명: {modelVersion}.zip (예: G4_000001.zip) 포함 파일:
* model_config.py, best_changed_fscore_epoch_N.pth, {modelVersion}.json, yolov8_6th-6m.pt
*
* @param modelInfo 모델 정보
* @param responsePath Response 디렉토리 경로
*/
private void createLegacyZipFile(ResponsePathDto modelInfo, Path responsePath) {
try {
log.info("기존 방식 ZIP 파일 생성 시작: modelId={}", modelInfo.getModelId());
// 1. Test 메트릭 기반 JSON 생성 (기존 getTestMetricPackingInfo 사용)
ModelMetricJsonDto jsonDto =
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
if (jsonDto == null) {
log.warn("Test 메트릭 정보를 찾을 수 없습니다: modelId={}", modelInfo.getModelId());
return;
}
// 2. JSON 파일 생성: {modelVersion}.json (예: G4_000001.json)
try {
writeJsonFile(
jsonDto,
Paths.get(
responseDir
+ "/"
+ modelInfo.getUuid()
+ "/"
+ jsonDto.getModelVersion()
+ ".json"));
} catch (IOException e) {
throw new RuntimeException(e);
}
log.info("JSON 파일 생성 완료: {}.json", jsonDto.getModelVersion());
// 3. Best Epoch 파일명 찾기 (F-Score 기준)
ModelTestFileName fileInfo =
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
if (fileInfo == null || fileInfo.getBestEpochFileName() == null) {
log.warn("Best Epoch 파일명을 찾을 수 없습니다: modelId={}", modelInfo.getModelId());
return;
}
log.info("Best Epoch 파일명: {}.pth", fileInfo.getBestEpochFileName());
// 4. ZIP 파일 경로: {modelVersion}.zip (예: G4_000001.zip)
Path zipPath =
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
// 5. ZIP에 포함할 파일 리스트
Set<String> targetNames =
Set.of(
"model_config.py",
fileInfo.getBestEpochFileName() + ".pth", // best_changed_fscore_epoch_N.pth
fileInfo.getModelVersion() + ".json"); // {modelVersion}.json
List<Path> files = new ArrayList<>();
// Response 폴더에서 파일 수집
try (Stream<Path> stream = Files.list(responsePath)) {
files.addAll(
stream
.filter(Files::isRegularFile)
.filter(p -> targetNames.contains(p.getFileName().toString()))
.collect(Collectors.toList()));
}
log.info("Response 폴더에서 수집한 파일 {}개", files.size());
// PT 파일 추가 (사전학습 모델)
try (Stream<Path> stream = Files.list(Path.of(ptPathDir))) {
files.addAll(
stream
.filter(Files::isRegularFile)
.limit(1) // yolov8_6th-6m.pt 파일 1개만
.collect(Collectors.toList()));
}
log.info("최종 ZIP에 포함할 파일 {}개", files.size());
// 6. ZIP 압축
zipFiles(files, zipPath);
long zipSize = Files.size(zipPath);
log.info("기존 방식 ZIP 파일 생성 완료: fileName={}, size={} bytes", zipPath.getFileName(), zipSize);
} catch (Exception e) {
log.error("기존 방식 ZIP 파일 생성 실패: modelId={}", modelInfo.getModelId(), e);
// 에러가 발생해도 신규 방식 ZIP은 계속 생성되도록 throw 하지 않음
}
}
/**
* Response 폴더에서 모든 best_changed_*.pth 파일 찾기 패턴: best_changed_{metricType}_epoch_{N}.pth 예:
* best_changed_fscore_epoch_3.pth, best_changed_precision_epoch_2.pth,
* best_changed_accuracy_epoch_5.pth 등
*
* @param responsePath Response 디렉토리 경로
* @return Best PTH 파일 정보 리스트
*/
private List<BestPthInfo> findAllBestPthFiles(Path responsePath) throws IOException {
List<BestPthInfo> bestFiles = new ArrayList<>();
// 개선: 모든 메트릭 타입을 자동으로 감지하는 유연한 패턴
// best_changed_{어떤문자든}_epoch_{숫자}.pth 형식의 모든 파일 검색
Pattern pattern = Pattern.compile("best_changed_(.+?)_epoch_(\\d+)\\.pth");
log.info("Best PTH 파일 검색 시작: path={}", responsePath);
log.info("검색 패턴: best_changed_{{metricType}}_epoch_{{N}}.pth");
try (Stream<Path> stream = Files.list(responsePath)) {
stream
.filter(Files::isRegularFile)
.forEach(
file -> {
String fileName = file.getFileName().toString();
Matcher matcher = pattern.matcher(fileName);
if (matcher.matches()) {
String metricType =
matcher.group(1); // 메트릭 타입 (fscore, precision, recall, accuracy 등)
Integer epoch = Integer.parseInt(matcher.group(2));
log.info(
"Best PTH 파일 발견: file={}, metricType={}, epoch={}",
fileName,
metricType,
epoch);
bestFiles.add(new BestPthInfo(fileName, metricType, epoch, file));
}
});
}
log.info("Best PTH 파일 검색 완료: 총 {}개 발견", bestFiles.size());
if (bestFiles.isEmpty()) {
log.warn("Best PTH 파일이 하나도 발견되지 않았습니다. 파일명 패턴을 확인하세요.");
log.warn("예상 패턴: best_changed_{{metricType}}_epoch_{{N}}.pth");
// 디버깅: response 폴더의 모든 .pth 파일 출력
try (Stream<Path> debugStream = Files.list(responsePath)) {
List<String> allPthFiles =
debugStream
.filter(Files::isRegularFile)
.map(p -> p.getFileName().toString())
.filter(name -> name.endsWith(".pth"))
.collect(Collectors.toList());
if (!allPthFiles.isEmpty()) {
log.info("Response 폴더의 .pth 파일 목록:");
allPthFiles.forEach(name -> log.info(" - {}", name));
} else {
log.warn("Response 폴더에 .pth 파일이 전혀 없습니다.");
}
} catch (IOException e) {
log.error("디버깅 중 에러 발생", e);
}
}
return bestFiles;
}
/**
* 각 Best PTH 파일별로 개별 ZIP 파일 생성
*
* @param modelInfo 모델 정보
* @param bestPthFiles Best PTH 파일 리스트
* @param responsePath Response 디렉토리 경로
*/
private void createIndividualZipFiles(
ResponsePathDto modelInfo, List<BestPthInfo> bestPthFiles, Path responsePath) {
// 공통 파일 경로
Path modelConfigPath = responsePath.resolve("model_config.py");
Path ptPath = findPretrainedModel();
// 공통 파일 존재 확인
if (!Files.exists(modelConfigPath)) {
log.error("model_config.py 파일이 존재하지 않습니다: {}", modelConfigPath);
throw new RuntimeException("필수 파일 누락: model_config.py");
}
if (ptPath == null) {
log.error("사전학습 모델(.pt) 파일을 찾을 수 없습니다.");
throw new RuntimeException("필수 파일 누락: .pt");
}
for (BestPthInfo bestPth : bestPthFiles) {
try {
log.info(
"ZIP 파일 생성 시작: modelId={}, metricType={}, epoch={}, pthFile={}",
modelInfo.getModelId(),
bestPth.getMetricType(),
bestPth.getEpoch(),
bestPth.getFileName());
// 1. 메트릭 JSON 생성
ModelMetricJsonDto metricJson =
modelTestMetricsJobCoreService.getMetricsByEpoch(
modelInfo.getModelId(), bestPth.getEpoch(), bestPth.getMetricType());
if (metricJson == null) {
log.warn(
"메트릭 정보 없음: modelId={}, epoch={}, metricType={}",
modelInfo.getModelId(),
bestPth.getEpoch(),
bestPth.getMetricType());
continue;
}
log.info(
"메트릭 JSON 조회 성공: modelId={}, metricType={}, epoch={}, cdModelType={}, modelVersion={}",
modelInfo.getModelId(),
metricJson.getMetricType(),
metricJson.getEpoch(),
metricJson.getCdModelType(),
metricJson.getModelVersion());
String jsonFileName = bestPth.getMetricType() + "_metrics.json";
Path jsonPath = responsePath.resolve(jsonFileName);
writeJsonFile(metricJson, jsonPath);
log.info("JSON 파일 생성 완료: {}", jsonFileName);
// 2. ZIP에 포함할 파일 리스트
List<Path> filesToZip = new ArrayList<>();
filesToZip.add(modelConfigPath); // model_config.py
filesToZip.add(bestPth.getFilePath()); // best_changed_{type}_epoch_{n}.pth
filesToZip.add(jsonPath); // {type}_metrics.json
filesToZip.add(ptPath); // yolov8_6th-6m.pt
// 3. ZIP 파일명 생성 (하이퍼파라미터별 구분)
String zipFileName =
String.format(
"%s.%s.zip",
metricJson.getModelVersion(), // G1_000001
bestPth.getMetricType()); // fscore/precision/recall
Path zipPath = responsePath.resolve(zipFileName);
// 4. ZIP 압축
zipFiles(filesToZip, zipPath);
long zipSize = Files.size(zipPath);
log.info(
"ZIP 파일 생성 완료: path={}, size={} bytes, metricType={}, epoch={}",
zipPath,
zipSize,
bestPth.getMetricType(),
bestPth.getEpoch());
} catch (Exception e) {
log.error(
"ZIP 파일 생성 실패: metricType={}, epoch={}, pthFile={}",
bestPth.getMetricType(),
bestPth.getEpoch(),
bestPth.getFileName(),
e);
throw new RuntimeException(e);
}
}
}
/**
* 사전학습 모델(PT 파일) 찾기
*
* @return PT 파일 경로
*/
private Path findPretrainedModel() {
try (Stream<Path> stream = Files.list(Path.of(ptPathDir))) {
return stream
.filter(Files::isRegularFile)
.filter(p -> p.getFileName().toString().endsWith(".pt"))
.findFirst()
.orElse(null);
} catch (IOException e) {
log.error("사전학습 모델 찾기 실패", e);
return null;
}
}
}