diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java new file mode 100644 index 0000000..4e7450f --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java @@ -0,0 +1,28 @@ +package com.kamco.cd.training.postgres.core; + +import com.kamco.cd.training.postgres.repository.schedule.ModelTestMetricsJobRepository; +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +public class ModelTestMetricsJobCoreService { + + private final ModelTestMetricsJobRepository modelTestMetricsJobRepository; + + @Transactional + public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) { + modelTestMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo); + } + + // Test 로직 시작 + public List getTestMetricSaveNotYetModelIds() { + return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds(); + } + + public void insertModelMetricsTest(List batchArgs) { + modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs); + } +} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryCustom.java index aff4e45..5a34eca 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryCustom.java @@ -4,11 +4,9 @@ import java.util.List; public interface ModelTestMetricsJobRepositoryCustom { - List getTrainMetricSaveNotYetModelIds(); - - void insertModelMetricsTrain(List batchArgs); - void updateModelMetricsTrainSaveYn(Long modelId, String stepNo); - void insertModelMetricsValidation(List batchArgs); + List getTestMetricSaveNotYetModelIds(); + + void insertModelMetricsTest(List batchArgs); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryImpl.java index 7029559..d30179f 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryImpl.java @@ -22,33 +22,6 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport this.jdbcTemplate = jdbcTemplate; } - @Override - public List getTrainMetricSaveNotYetModelIds() { - return queryFactory - .select(modelMasterEntity.id) - .from(modelMasterEntity) - .where( - modelMasterEntity.step1EndDttm.isNotNull(), - modelMasterEntity.step1State.eq(TrainStatusType.COMPLETED.getId()), - modelMasterEntity - .step1MetricSaveYn - .isNull() - .or(modelMasterEntity.step1MetricSaveYn.isFalse())) - .fetch(); - } - - @Override - public void insertModelMetricsTrain(List batchArgs) { - String sql = - """ - insert into tb_model_matrics_train - (model_id, epoch, iteration, loss, lr, duration_time) - values (?, ?, ?, ?, ?, ?) - """; - - jdbcTemplate.batchUpdate(sql, batchArgs); - } - @Override public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) { queryFactory @@ -63,15 +36,30 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport } @Override - public void insertModelMetricsValidation(List batchArgs) { + public List getTestMetricSaveNotYetModelIds() { + return queryFactory + .select(modelMasterEntity.id) + .from(modelMasterEntity) + .where( + modelMasterEntity.step2EndDttm.isNotNull(), + modelMasterEntity.step2State.eq(TrainStatusType.COMPLETED.getId()), + modelMasterEntity + .step2MetricSaveYn + .isNull() + .or(modelMasterEntity.step2MetricSaveYn.isFalse())) + .fetch(); + } + + @Override + public void insertModelMetricsTest(List batchArgs) { String sql = """ - insert into tb_model_matrics_validation - (model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall, - unchanged_fscore, unchanged_precision, unchanged_recall - ) - values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """; + insert into tb_model_metrics_test + (model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou, + detection_count, gt_count + ) + values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """; jdbcTemplate.batchUpdate(sql, batchArgs); } diff --git a/src/main/java/com/kamco/cd/training/schedule/service/ModelTestMetricsJobService.java b/src/main/java/com/kamco/cd/training/schedule/service/ModelTestMetricsJobService.java index e0ae083..15e5011 100644 --- a/src/main/java/com/kamco/cd/training/schedule/service/ModelTestMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/schedule/service/ModelTestMetricsJobService.java @@ -1,9 +1,6 @@ package com.kamco.cd.training.schedule.service; -import com.jcraft.jsch.ChannelSftp; -import com.jcraft.jsch.JSch; -import com.jcraft.jsch.Session; -import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService; +import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; import java.io.BufferedReader; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -24,7 +21,7 @@ import org.springframework.stereotype.Service; @RequiredArgsConstructor public class ModelTestMetricsJobService { - private final ModelTrainMetricsJobCoreService modelTrainMetricsJobCoreService; + private final ModelTestMetricsJobCoreService modelTestMetricsJobCoreService; @Value("${spring.profiles.active}") private String profile; @@ -38,42 +35,21 @@ public class ModelTestMetricsJobService { return "local".equalsIgnoreCase(profile); } - // @Scheduled(cron = "0 * * * * *") + // @Scheduled(cron = "0 0/10 * * * *") public void findTestValidMetricCsvFiles() { - // if (isLocalProfile()) { - // return; - // } + if (isLocalProfile()) { + return; + } List modelIds = - modelTrainMetricsJobCoreService - .getTrainMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함 + modelTestMetricsJobCoreService + .getTestMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함 if (modelIds.isEmpty()) { return; } - Session session = null; - ChannelSftp sftp = null; - - JSch jsch = new JSch(); - - // try { - // session = jsch.getSession("kcomu", "192.168.2.86", 22); - // session.setPassword("Kamco2025!"); - // - // Properties config = new Properties(); - // config.put("StrictHostKeyChecking", "no"); - // session.setConfig(config); - // - // session.connect(); - // - // sftp = (ChannelSftp) session.openChannel("sftp"); - // sftp.connect(); - - // InputStream csvInputStream = - // sftp.get("/home/kcomu/data/response/test/metrics/train.csv"); - - String localPath = "C:\\data\\upload\\train.csv"; + String localPath = "C:\\data\\upload\\test.csv"; try (BufferedReader reader = Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) { @@ -84,73 +60,41 @@ public class ModelTestMetricsJobService { for (CSVRecord record : parser) { - int epoch = Integer.parseInt(record.get("Epoch")); - long iteration = Long.parseLong(record.get("Iteration")); - double Loss = Double.parseDouble(record.get("Loss")); - double LR = Double.parseDouble(record.get("LR")); - float time = Float.parseFloat(record.get("Time")); - - batchArgs.add(new Object[] {modelIds.getFirst(), epoch, iteration, Loss, LR, time}); - } - - modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs); - - } catch (IOException e) { - throw new RuntimeException(e); - } - - String validationPath = "C:\\data\\upload\\val.csv"; - try (BufferedReader reader = - Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) { - - log.info("### validationPath={}", validationPath); - CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); - - List batchArgs = new ArrayList<>(); - - for (CSVRecord record : parser) { - - int epoch = Integer.parseInt(record.get("Epoch")); - float aAcc = Float.parseFloat(record.get("aAcc")); - float mFscore = Float.parseFloat(record.get("mFscore")); - float mPrecision = Float.parseFloat(record.get("mPrecision")); - float mRecall = Float.parseFloat(record.get("mRecall")); - float mIoU = Float.parseFloat(record.get("mIoU")); - float mAcc = Float.parseFloat(record.get("mAcc")); - float changed_fscore = Float.parseFloat(record.get("changed_fscore")); - float changed_precision = Float.parseFloat(record.get("changed_precision")); - float changed_recall = Float.parseFloat(record.get("changed_recall")); - float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore")); - float unchanged_precision = Float.parseFloat(record.get("unchanged_precision")); - float unchanged_recall = Float.parseFloat(record.get("unchanged_recall")); + String model = record.get("model"); + long TP = Long.parseLong(record.get("TP")); + long FP = Long.parseLong(record.get("FP")); + long FN = Long.parseLong(record.get("FN")); + float precision = Float.parseFloat(record.get("precision")); + float recall = Float.parseFloat(record.get("recall")); + float f1_score = Float.parseFloat(record.get("f1_score")); + float accuracy = Float.parseFloat(record.get("accuracy")); + float iou = Float.parseFloat(record.get("iou")); + long detection_count = Long.parseLong(record.get("detection_count")); + long gt_count = Long.parseLong(record.get("gt_count")); batchArgs.add( new Object[] { modelIds.getFirst(), - epoch, - aAcc, - mFscore, - mPrecision, - mRecall, - mIoU, - mAcc, - changed_fscore, - changed_precision, - changed_recall, - unchanged_fscore, - unchanged_precision, - unchanged_recall + model, + TP, + FP, + FN, + precision, + recall, + f1_score, + accuracy, + iou, + detection_count, + gt_count }); } - modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs); + modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs); } catch (IOException e) { throw new RuntimeException(e); } - // } catch (JSchException | SftpException e) { - // throw new RuntimeException(e); - // } - modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1"); + + modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step2"); } } diff --git a/src/main/java/com/kamco/cd/training/schedule/service/ModelTrainMetricsJobService.java b/src/main/java/com/kamco/cd/training/schedule/service/ModelTrainMetricsJobService.java index a1a9485..ecb7c27 100644 --- a/src/main/java/com/kamco/cd/training/schedule/service/ModelTrainMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/schedule/service/ModelTrainMetricsJobService.java @@ -1,8 +1,5 @@ package com.kamco.cd.training.schedule.service; -import com.jcraft.jsch.ChannelSftp; -import com.jcraft.jsch.JSch; -import com.jcraft.jsch.Session; import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService; import java.io.BufferedReader; import java.io.IOException; @@ -38,7 +35,7 @@ public class ModelTrainMetricsJobService { return "local".equalsIgnoreCase(profile); } - // @Scheduled(cron = "0 0/5 * * * *") + // @Scheduled(cron = "0 0/10 * * * *") public void findTrainValidMetricCsvFiles() { if (isLocalProfile()) { return; @@ -52,27 +49,6 @@ public class ModelTrainMetricsJobService { return; } - Session session = null; - ChannelSftp sftp = null; - - JSch jsch = new JSch(); - - // try { - // session = jsch.getSession("kcomu", "192.168.2.86", 22); - // session.setPassword("Kamco2025!"); - // - // Properties config = new Properties(); - // config.put("StrictHostKeyChecking", "no"); - // session.setConfig(config); - // - // session.connect(); - // - // sftp = (ChannelSftp) session.openChannel("sftp"); - // sftp.connect(); - - // InputStream csvInputStream = - // sftp.get("/home/kcomu/data/response/test/metrics/train.csv"); - String localPath = "C:\\data\\upload\\train.csv"; try (BufferedReader reader = Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) { @@ -148,9 +124,7 @@ public class ModelTrainMetricsJobService { } catch (IOException e) { throw new RuntimeException(e); } - // } catch (JSchException | SftpException e) { - // throw new RuntimeException(e); - // } + modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1"); } }