test metrics 스케줄 추가

This commit is contained in:
2026-02-11 19:09:58 +09:00
parent 207cc47f1b
commit 95548223cd
5 changed files with 89 additions and 157 deletions

View File

@@ -0,0 +1,28 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.schedule.ModelTestMetricsJobRepository;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
public class ModelTestMetricsJobCoreService {
private final ModelTestMetricsJobRepository modelTestMetricsJobRepository;
@Transactional
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
modelTestMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo);
}
// Test 로직 시작
public List<Long> getTestMetricSaveNotYetModelIds() {
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
}
public void insertModelMetricsTest(List<Object[]> batchArgs) {
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
}
}

View File

@@ -4,11 +4,9 @@ import java.util.List;
public interface ModelTestMetricsJobRepositoryCustom { public interface ModelTestMetricsJobRepositoryCustom {
List<Long> getTrainMetricSaveNotYetModelIds();
void insertModelMetricsTrain(List<Object[]> batchArgs);
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo); void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
void insertModelMetricsValidation(List<Object[]> batchArgs); List<Long> getTestMetricSaveNotYetModelIds();
void insertModelMetricsTest(List<Object[]> batchArgs);
} }

View File

@@ -22,33 +22,6 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
this.jdbcTemplate = jdbcTemplate; this.jdbcTemplate = jdbcTemplate;
} }
@Override
public List<Long> getTrainMetricSaveNotYetModelIds() {
return queryFactory
.select(modelMasterEntity.id)
.from(modelMasterEntity)
.where(
modelMasterEntity.step1EndDttm.isNotNull(),
modelMasterEntity.step1State.eq(TrainStatusType.COMPLETED.getId()),
modelMasterEntity
.step1MetricSaveYn
.isNull()
.or(modelMasterEntity.step1MetricSaveYn.isFalse()))
.fetch();
}
@Override
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
String sql =
"""
insert into tb_model_matrics_train
(model_id, epoch, iteration, loss, lr, duration_time)
values (?, ?, ?, ?, ?, ?)
""";
jdbcTemplate.batchUpdate(sql, batchArgs);
}
@Override @Override
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) { public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
queryFactory queryFactory
@@ -63,14 +36,29 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
} }
@Override @Override
public void insertModelMetricsValidation(List<Object[]> batchArgs) { public List<Long> getTestMetricSaveNotYetModelIds() {
return queryFactory
.select(modelMasterEntity.id)
.from(modelMasterEntity)
.where(
modelMasterEntity.step2EndDttm.isNotNull(),
modelMasterEntity.step2State.eq(TrainStatusType.COMPLETED.getId()),
modelMasterEntity
.step2MetricSaveYn
.isNull()
.or(modelMasterEntity.step2MetricSaveYn.isFalse()))
.fetch();
}
@Override
public void insertModelMetricsTest(List<Object[]> batchArgs) {
String sql = String sql =
""" """
insert into tb_model_matrics_validation insert into tb_model_metrics_test
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall, (model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
unchanged_fscore, unchanged_precision, unchanged_recall detection_count, gt_count
) )
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""; """;
jdbcTemplate.batchUpdate(sql, batchArgs); jdbcTemplate.batchUpdate(sql, batchArgs);

View File

@@ -1,9 +1,6 @@
package com.kamco.cd.training.schedule.service; package com.kamco.cd.training.schedule.service;
import com.jcraft.jsch.ChannelSftp; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@@ -24,7 +21,7 @@ import org.springframework.stereotype.Service;
@RequiredArgsConstructor @RequiredArgsConstructor
public class ModelTestMetricsJobService { public class ModelTestMetricsJobService {
private final ModelTrainMetricsJobCoreService modelTrainMetricsJobCoreService; private final ModelTestMetricsJobCoreService modelTestMetricsJobCoreService;
@Value("${spring.profiles.active}") @Value("${spring.profiles.active}")
private String profile; private String profile;
@@ -38,42 +35,21 @@ public class ModelTestMetricsJobService {
return "local".equalsIgnoreCase(profile); return "local".equalsIgnoreCase(profile);
} }
// @Scheduled(cron = "0 * * * * *") // @Scheduled(cron = "0 0/10 * * * *")
public void findTestValidMetricCsvFiles() { public void findTestValidMetricCsvFiles() {
// if (isLocalProfile()) { if (isLocalProfile()) {
// return; return;
// } }
List<Long> modelIds = List<Long> modelIds =
modelTrainMetricsJobCoreService modelTestMetricsJobCoreService
.getTrainMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함 .getTestMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
if (modelIds.isEmpty()) { if (modelIds.isEmpty()) {
return; return;
} }
Session session = null; String localPath = "C:\\data\\upload\\test.csv";
ChannelSftp sftp = null;
JSch jsch = new JSch();
// try {
// session = jsch.getSession("kcomu", "192.168.2.86", 22);
// session.setPassword("Kamco2025!");
//
// Properties config = new Properties();
// config.put("StrictHostKeyChecking", "no");
// session.setConfig(config);
//
// session.connect();
//
// sftp = (ChannelSftp) session.openChannel("sftp");
// sftp.connect();
// InputStream csvInputStream =
// sftp.get("/home/kcomu/data/response/test/metrics/train.csv");
String localPath = "C:\\data\\upload\\train.csv";
try (BufferedReader reader = try (BufferedReader reader =
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) { Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
@@ -84,73 +60,41 @@ public class ModelTestMetricsJobService {
for (CSVRecord record : parser) { for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch")); String model = record.get("model");
long iteration = Long.parseLong(record.get("Iteration")); long TP = Long.parseLong(record.get("TP"));
double Loss = Double.parseDouble(record.get("Loss")); long FP = Long.parseLong(record.get("FP"));
double LR = Double.parseDouble(record.get("LR")); long FN = Long.parseLong(record.get("FN"));
float time = Float.parseFloat(record.get("Time")); float precision = Float.parseFloat(record.get("precision"));
float recall = Float.parseFloat(record.get("recall"));
batchArgs.add(new Object[] {modelIds.getFirst(), epoch, iteration, Loss, LR, time}); float f1_score = Float.parseFloat(record.get("f1_score"));
} float accuracy = Float.parseFloat(record.get("accuracy"));
float iou = Float.parseFloat(record.get("iou"));
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs); long detection_count = Long.parseLong(record.get("detection_count"));
long gt_count = Long.parseLong(record.get("gt_count"));
} catch (IOException e) {
throw new RuntimeException(e);
}
String validationPath = "C:\\data\\upload\\val.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
log.info("### validationPath={}", validationPath);
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch"));
float aAcc = Float.parseFloat(record.get("aAcc"));
float mFscore = Float.parseFloat(record.get("mFscore"));
float mPrecision = Float.parseFloat(record.get("mPrecision"));
float mRecall = Float.parseFloat(record.get("mRecall"));
float mIoU = Float.parseFloat(record.get("mIoU"));
float mAcc = Float.parseFloat(record.get("mAcc"));
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
float changed_precision = Float.parseFloat(record.get("changed_precision"));
float changed_recall = Float.parseFloat(record.get("changed_recall"));
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
batchArgs.add( batchArgs.add(
new Object[] { new Object[] {
modelIds.getFirst(), modelIds.getFirst(),
epoch, model,
aAcc, TP,
mFscore, FP,
mPrecision, FN,
mRecall, precision,
mIoU, recall,
mAcc, f1_score,
changed_fscore, accuracy,
changed_precision, iou,
changed_recall, detection_count,
unchanged_fscore, gt_count
unchanged_precision,
unchanged_recall
}); });
} }
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs); modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
// } catch (JSchException | SftpException e) {
// throw new RuntimeException(e); modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step2");
// }
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1");
} }
} }

View File

@@ -1,8 +1,5 @@
package com.kamco.cd.training.schedule.service; package com.kamco.cd.training.schedule.service;
import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
@@ -38,7 +35,7 @@ public class ModelTrainMetricsJobService {
return "local".equalsIgnoreCase(profile); return "local".equalsIgnoreCase(profile);
} }
// @Scheduled(cron = "0 0/5 * * * *") // @Scheduled(cron = "0 0/10 * * * *")
public void findTrainValidMetricCsvFiles() { public void findTrainValidMetricCsvFiles() {
if (isLocalProfile()) { if (isLocalProfile()) {
return; return;
@@ -52,27 +49,6 @@ public class ModelTrainMetricsJobService {
return; return;
} }
Session session = null;
ChannelSftp sftp = null;
JSch jsch = new JSch();
// try {
// session = jsch.getSession("kcomu", "192.168.2.86", 22);
// session.setPassword("Kamco2025!");
//
// Properties config = new Properties();
// config.put("StrictHostKeyChecking", "no");
// session.setConfig(config);
//
// session.connect();
//
// sftp = (ChannelSftp) session.openChannel("sftp");
// sftp.connect();
// InputStream csvInputStream =
// sftp.get("/home/kcomu/data/response/test/metrics/train.csv");
String localPath = "C:\\data\\upload\\train.csv"; String localPath = "C:\\data\\upload\\train.csv";
try (BufferedReader reader = try (BufferedReader reader =
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) { Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
@@ -148,9 +124,7 @@ public class ModelTrainMetricsJobService {
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
// } catch (JSchException | SftpException e) {
// throw new RuntimeException(e);
// }
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1"); modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1");
} }
} }