From 2df4a7a80b5a0ebc97c544904602b0c5da36c143 Mon Sep 17 00:00:00 2001 From: "gayoun.park" Date: Thu, 12 Feb 2026 15:24:30 +0900 Subject: [PATCH] =?UTF-8?q?csv=20=ED=8C=8C=EC=9D=BC=20=EC=9D=BD=EB=8A=94?= =?UTF-8?q?=20=EA=B2=BD=EB=A1=9C=20=EC=9D=BD=EC=96=B4=EC=84=9C=20=EC=88=98?= =?UTF-8?q?=EC=A0=95,=20train=EC=9D=80=20epoch=20+=201=20=ED=95=B4?= =?UTF-8?q?=EC=84=9C=20=EC=A0=80=EC=9E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/ModelTestMetricsJobCoreService.java | 3 +- .../core/ModelTrainMetricsJobCoreService.java | 3 +- .../ModelTestMetricsJobRepositoryCustom.java | 3 +- .../ModelTestMetricsJobRepositoryImpl.java | 8 +- .../ModelTrainMetricsJobRepositoryCustom.java | 3 +- .../ModelTrainMetricsJobRepositoryImpl.java | 12 +- .../train/dto/ModelTrainMetricsDto.java | 21 +++ .../service/ModelTestMetricsJobService.java | 94 ++++++------ .../service/ModelTrainMetricsJobService.java | 136 +++++++++--------- 9 files changed, 160 insertions(+), 123 deletions(-) create mode 100644 src/main/java/com/kamco/cd/training/train/dto/ModelTrainMetricsDto.java diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java index 6ba1e56..a19e34c 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.postgres.core; import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import java.util.List; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; @@ -18,7 +19,7 @@ public class ModelTestMetricsJobCoreService { } // Test 로직 시작 - public List getTestMetricSaveNotYetModelIds() { + public List getTestMetricSaveNotYetModelIds() { return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds(); } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java index 5692017..d7823b1 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.postgres.core; import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import java.util.List; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; @@ -12,7 +13,7 @@ public class ModelTrainMetricsJobCoreService { private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository; - public List getTrainMetricSaveNotYetModelIds() { + public List getTrainMetricSaveNotYetModelIds() { return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds(); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java index bd993e1..a1a6a04 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java @@ -1,12 +1,13 @@ package com.kamco.cd.training.postgres.repository.train; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import java.util.List; public interface ModelTestMetricsJobRepositoryCustom { void updateModelMetricsTrainSaveYn(Long modelId, String stepNo); - List getTestMetricSaveNotYetModelIds(); + List getTestMetricSaveNotYetModelIds(); void insertModelMetricsTest(List batchArgs); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java index 7804c52..f189150 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java @@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; +import com.querydsl.core.types.Projections; import com.querydsl.jpa.impl.JPAQueryFactory; import java.util.List; import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport; @@ -36,9 +38,11 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport } @Override - public List getTestMetricSaveNotYetModelIds() { + public List getTestMetricSaveNotYetModelIds() { return queryFactory - .select(modelMasterEntity.id) + .select( + Projections.constructor( + ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath)) .from(modelMasterEntity) .where( modelMasterEntity.step2EndDttm.isNotNull(), diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java index a10caa8..67517fe 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java @@ -1,10 +1,11 @@ package com.kamco.cd.training.postgres.repository.train; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import java.util.List; public interface ModelTrainMetricsJobRepositoryCustom { - List getTrainMetricSaveNotYetModelIds(); + List getTrainMetricSaveNotYetModelIds(); void insertModelMetricsTrain(List batchArgs); diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java index c20bc73..d650bc7 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java @@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; +import com.querydsl.core.types.Projections; import com.querydsl.jpa.impl.JPAQueryFactory; import java.util.List; import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport; @@ -23,9 +25,11 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor } @Override - public List getTrainMetricSaveNotYetModelIds() { + public List getTrainMetricSaveNotYetModelIds() { return queryFactory - .select(modelMasterEntity.id) + .select( + Projections.constructor( + ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath)) .from(modelMasterEntity) .where( modelMasterEntity.step1EndDttm.isNotNull(), @@ -41,7 +45,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor public void insertModelMetricsTrain(List batchArgs) { String sql = """ - insert into tb_model_matrics_train + insert into tb_model_metrics_train (model_id, epoch, iteration, loss, lr, duration_time) values (?, ?, ?, ?, ?, ?) """; @@ -66,7 +70,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor public void insertModelMetricsValidation(List batchArgs) { String sql = """ - insert into tb_model_matrics_validation + insert into tb_model_metrics_validation (model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall, unchanged_fscore, unchanged_precision, unchanged_recall ) diff --git a/src/main/java/com/kamco/cd/training/train/dto/ModelTrainMetricsDto.java b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainMetricsDto.java new file mode 100644 index 0000000..ff6bd90 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainMetricsDto.java @@ -0,0 +1,21 @@ +package com.kamco.cd.training.train.dto; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +public class ModelTrainMetricsDto { + + @Schema(name = "ResponsePathDto", description = "AI 결과 저장된 path 경로 정보") + @Getter + @Setter + @NoArgsConstructor + @AllArgsConstructor + public static class ResponsePathDto { + + private Long modelId; + private String responsePath; + } +} diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java index c5936bc..167d3a4 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.train.service; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import java.io.BufferedReader; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -35,66 +36,67 @@ public class ModelTestMetricsJobService { return "local".equalsIgnoreCase(profile); } - // @Scheduled(cron = "0 0/10 * * * *") + // @Scheduled(cron = "0 * * * * *") public void findTestValidMetricCsvFiles() { - if (isLocalProfile()) { - return; - } + // if (isLocalProfile()) { + // return; + // } - List modelIds = - modelTestMetricsJobCoreService - .getTestMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함 + List modelIds = + modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds(); if (modelIds.isEmpty()) { return; } - String localPath = "C:\\data\\upload\\test.csv"; - try (BufferedReader reader = - Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) { + for (ResponsePathDto modelInfo : modelIds) { - log.info("### localPath={}", localPath); - CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); + String testPath = modelInfo.getResponsePath() + "/metrics/test.csv"; + try (BufferedReader reader = + Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) { - List batchArgs = new ArrayList<>(); + CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); - for (CSVRecord record : parser) { + List batchArgs = new ArrayList<>(); - String model = record.get("model"); - long TP = Long.parseLong(record.get("TP")); - long FP = Long.parseLong(record.get("FP")); - long FN = Long.parseLong(record.get("FN")); - float precision = Float.parseFloat(record.get("precision")); - float recall = Float.parseFloat(record.get("recall")); - float f1_score = Float.parseFloat(record.get("f1_score")); - float accuracy = Float.parseFloat(record.get("accuracy")); - float iou = Float.parseFloat(record.get("iou")); - long detection_count = Long.parseLong(record.get("detection_count")); - long gt_count = Long.parseLong(record.get("gt_count")); + for (CSVRecord record : parser) { - batchArgs.add( - new Object[] { - modelIds.getFirst(), - model, - TP, - FP, - FN, - precision, - recall, - f1_score, - accuracy, - iou, - detection_count, - gt_count - }); + String model = record.get("model"); + long TP = Long.parseLong(record.get("TP")); + long FP = Long.parseLong(record.get("FP")); + long FN = Long.parseLong(record.get("FN")); + float precision = Float.parseFloat(record.get("precision")); + float recall = Float.parseFloat(record.get("recall")); + float f1_score = Float.parseFloat(record.get("f1_score")); + float accuracy = Float.parseFloat(record.get("accuracy")); + float iou = Float.parseFloat(record.get("iou")); + long detection_count = Long.parseLong(record.get("detection_count")); + long gt_count = Long.parseLong(record.get("gt_count")); + + batchArgs.add( + new Object[] { + modelInfo.getModelId(), + model, + TP, + FP, + FN, + precision, + recall, + f1_score, + accuracy, + iou, + detection_count, + gt_count + }); + } + + modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs); + + } catch (IOException e) { + throw new RuntimeException(e); } - modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs); - - } catch (IOException e) { - throw new RuntimeException(e); + modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2"); } - - modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step2"); } } diff --git a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java index 319a5fd..2132af7 100644 --- a/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.train.service; import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService; +import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import java.io.BufferedReader; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -35,96 +36,97 @@ public class ModelTrainMetricsJobService { return "local".equalsIgnoreCase(profile); } - // @Scheduled(cron = "0 0/10 * * * *") + // @Scheduled(cron = "0 * * * * *") public void findTrainValidMetricCsvFiles() { - if (isLocalProfile()) { - return; - } + // if (isLocalProfile()) { + // return; + // } - List modelIds = - modelTrainMetricsJobCoreService - .getTrainMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함 + List modelIds = + modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds(); if (modelIds.isEmpty()) { return; } - String localPath = "C:\\data\\upload\\train.csv"; - try (BufferedReader reader = - Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) { + for (ResponsePathDto modelInfo : modelIds) { - log.info("### localPath={}", localPath); - CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); + String trainPath = modelInfo.getResponsePath() + "/metrics/train.csv"; + try (BufferedReader reader = + Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) { - List batchArgs = new ArrayList<>(); + CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); - for (CSVRecord record : parser) { + List batchArgs = new ArrayList<>(); - int epoch = Integer.parseInt(record.get("Epoch")); - long iteration = Long.parseLong(record.get("Iteration")); - double Loss = Double.parseDouble(record.get("Loss")); - double LR = Double.parseDouble(record.get("LR")); - float time = Float.parseFloat(record.get("Time")); + for (CSVRecord record : parser) { - batchArgs.add(new Object[] {modelIds.getFirst(), epoch, iteration, Loss, LR, time}); + int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기 + long iteration = Long.parseLong(record.get("Iteration")); + double Loss = Double.parseDouble(record.get("Loss")); + double LR = Double.parseDouble(record.get("LR")); + float time = Float.parseFloat(record.get("Time")); + + batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time}); + } + + modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs); + + } catch (IOException e) { + throw new RuntimeException(e); } - modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs); + String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv"; + try (BufferedReader reader = + Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) { - } catch (IOException e) { - throw new RuntimeException(e); - } + CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); - String validationPath = "C:\\data\\upload\\val.csv"; - try (BufferedReader reader = - Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) { + List batchArgs = new ArrayList<>(); - log.info("### validationPath={}", validationPath); - CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); + for (CSVRecord record : parser) { - List batchArgs = new ArrayList<>(); + int epoch = Integer.parseInt(record.get("Epoch")); + float aAcc = Float.parseFloat(record.get("aAcc")); + float mFscore = Float.parseFloat(record.get("mFscore")); + float mPrecision = Float.parseFloat(record.get("mPrecision")); + float mRecall = Float.parseFloat(record.get("mRecall")); + float mIoU = Float.parseFloat(record.get("mIoU")); + float mAcc = Float.parseFloat(record.get("mAcc")); + float changed_fscore = Float.parseFloat(record.get("changed_fscore")); + float changed_precision = Float.parseFloat(record.get("changed_precision")); + float changed_recall = Float.parseFloat(record.get("changed_recall")); + float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore")); + float unchanged_precision = Float.parseFloat(record.get("unchanged_precision")); + float unchanged_recall = Float.parseFloat(record.get("unchanged_recall")); - for (CSVRecord record : parser) { + batchArgs.add( + new Object[] { + modelInfo.getModelId(), + epoch, + aAcc, + mFscore, + mPrecision, + mRecall, + mIoU, + mAcc, + changed_fscore, + changed_precision, + changed_recall, + unchanged_fscore, + unchanged_precision, + unchanged_recall + }); + } - int epoch = Integer.parseInt(record.get("Epoch")); - float aAcc = Float.parseFloat(record.get("aAcc")); - float mFscore = Float.parseFloat(record.get("mFscore")); - float mPrecision = Float.parseFloat(record.get("mPrecision")); - float mRecall = Float.parseFloat(record.get("mRecall")); - float mIoU = Float.parseFloat(record.get("mIoU")); - float mAcc = Float.parseFloat(record.get("mAcc")); - float changed_fscore = Float.parseFloat(record.get("changed_fscore")); - float changed_precision = Float.parseFloat(record.get("changed_precision")); - float changed_recall = Float.parseFloat(record.get("changed_recall")); - float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore")); - float unchanged_precision = Float.parseFloat(record.get("unchanged_precision")); - float unchanged_recall = Float.parseFloat(record.get("unchanged_recall")); + modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs); - batchArgs.add( - new Object[] { - modelIds.getFirst(), - epoch, - aAcc, - mFscore, - mPrecision, - mRecall, - mIoU, - mAcc, - changed_fscore, - changed_precision, - changed_recall, - unchanged_fscore, - unchanged_precision, - unchanged_recall - }); + } catch (IOException e) { + throw new RuntimeException(e); } - modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs); - - } catch (IOException e) { - throw new RuntimeException(e); + modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn( + modelInfo.getModelId(), "step1"); } - - modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1"); } }