Merge pull request 'test metrics 스케줄 추가' (#37) from feat/training_260202 into develop
Reviewed-on: #37
This commit was merged in pull request #37.
This commit is contained in:
@@ -0,0 +1,28 @@
|
|||||||
|
package com.kamco.cd.training.postgres.core;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.postgres.repository.schedule.ModelTestMetricsJobRepository;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class ModelTestMetricsJobCoreService {
|
||||||
|
|
||||||
|
private final ModelTestMetricsJobRepository modelTestMetricsJobRepository;
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
|
||||||
|
modelTestMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 로직 시작
|
||||||
|
public List<Long> getTestMetricSaveNotYetModelIds() {
|
||||||
|
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
||||||
|
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,11 +4,9 @@ import java.util.List;
|
|||||||
|
|
||||||
public interface ModelTestMetricsJobRepositoryCustom {
|
public interface ModelTestMetricsJobRepositoryCustom {
|
||||||
|
|
||||||
List<Long> getTrainMetricSaveNotYetModelIds();
|
|
||||||
|
|
||||||
void insertModelMetricsTrain(List<Object[]> batchArgs);
|
|
||||||
|
|
||||||
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
||||||
|
|
||||||
void insertModelMetricsValidation(List<Object[]> batchArgs);
|
List<Long> getTestMetricSaveNotYetModelIds();
|
||||||
|
|
||||||
|
void insertModelMetricsTest(List<Object[]> batchArgs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,33 +22,6 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
|||||||
this.jdbcTemplate = jdbcTemplate;
|
this.jdbcTemplate = jdbcTemplate;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Long> getTrainMetricSaveNotYetModelIds() {
|
|
||||||
return queryFactory
|
|
||||||
.select(modelMasterEntity.id)
|
|
||||||
.from(modelMasterEntity)
|
|
||||||
.where(
|
|
||||||
modelMasterEntity.step1EndDttm.isNotNull(),
|
|
||||||
modelMasterEntity.step1State.eq(TrainStatusType.COMPLETED.getId()),
|
|
||||||
modelMasterEntity
|
|
||||||
.step1MetricSaveYn
|
|
||||||
.isNull()
|
|
||||||
.or(modelMasterEntity.step1MetricSaveYn.isFalse()))
|
|
||||||
.fetch();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
|
||||||
String sql =
|
|
||||||
"""
|
|
||||||
insert into tb_model_matrics_train
|
|
||||||
(model_id, epoch, iteration, loss, lr, duration_time)
|
|
||||||
values (?, ?, ?, ?, ?, ?)
|
|
||||||
""";
|
|
||||||
|
|
||||||
jdbcTemplate.batchUpdate(sql, batchArgs);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
|
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
|
||||||
queryFactory
|
queryFactory
|
||||||
@@ -63,15 +36,30 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
public List<Long> getTestMetricSaveNotYetModelIds() {
|
||||||
|
return queryFactory
|
||||||
|
.select(modelMasterEntity.id)
|
||||||
|
.from(modelMasterEntity)
|
||||||
|
.where(
|
||||||
|
modelMasterEntity.step2EndDttm.isNotNull(),
|
||||||
|
modelMasterEntity.step2State.eq(TrainStatusType.COMPLETED.getId()),
|
||||||
|
modelMasterEntity
|
||||||
|
.step2MetricSaveYn
|
||||||
|
.isNull()
|
||||||
|
.or(modelMasterEntity.step2MetricSaveYn.isFalse()))
|
||||||
|
.fetch();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
||||||
String sql =
|
String sql =
|
||||||
"""
|
"""
|
||||||
insert into tb_model_matrics_validation
|
insert into tb_model_metrics_test
|
||||||
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
|
(model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
|
||||||
unchanged_fscore, unchanged_precision, unchanged_recall
|
detection_count, gt_count
|
||||||
)
|
)
|
||||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""";
|
""";
|
||||||
|
|
||||||
jdbcTemplate.batchUpdate(sql, batchArgs);
|
jdbcTemplate.batchUpdate(sql, batchArgs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
package com.kamco.cd.training.schedule.service;
|
package com.kamco.cd.training.schedule.service;
|
||||||
|
|
||||||
import com.jcraft.jsch.ChannelSftp;
|
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
||||||
import com.jcraft.jsch.JSch;
|
|
||||||
import com.jcraft.jsch.Session;
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
@@ -24,7 +21,7 @@ import org.springframework.stereotype.Service;
|
|||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class ModelTestMetricsJobService {
|
public class ModelTestMetricsJobService {
|
||||||
|
|
||||||
private final ModelTrainMetricsJobCoreService modelTrainMetricsJobCoreService;
|
private final ModelTestMetricsJobCoreService modelTestMetricsJobCoreService;
|
||||||
|
|
||||||
@Value("${spring.profiles.active}")
|
@Value("${spring.profiles.active}")
|
||||||
private String profile;
|
private String profile;
|
||||||
@@ -38,42 +35,21 @@ public class ModelTestMetricsJobService {
|
|||||||
return "local".equalsIgnoreCase(profile);
|
return "local".equalsIgnoreCase(profile);
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Scheduled(cron = "0 * * * * *")
|
// @Scheduled(cron = "0 0/10 * * * *")
|
||||||
public void findTestValidMetricCsvFiles() {
|
public void findTestValidMetricCsvFiles() {
|
||||||
// if (isLocalProfile()) {
|
if (isLocalProfile()) {
|
||||||
// return;
|
return;
|
||||||
// }
|
}
|
||||||
|
|
||||||
List<Long> modelIds =
|
List<Long> modelIds =
|
||||||
modelTrainMetricsJobCoreService
|
modelTestMetricsJobCoreService
|
||||||
.getTrainMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
|
.getTestMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
|
||||||
|
|
||||||
if (modelIds.isEmpty()) {
|
if (modelIds.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Session session = null;
|
String localPath = "C:\\data\\upload\\test.csv";
|
||||||
ChannelSftp sftp = null;
|
|
||||||
|
|
||||||
JSch jsch = new JSch();
|
|
||||||
|
|
||||||
// try {
|
|
||||||
// session = jsch.getSession("kcomu", "192.168.2.86", 22);
|
|
||||||
// session.setPassword("Kamco2025!");
|
|
||||||
//
|
|
||||||
// Properties config = new Properties();
|
|
||||||
// config.put("StrictHostKeyChecking", "no");
|
|
||||||
// session.setConfig(config);
|
|
||||||
//
|
|
||||||
// session.connect();
|
|
||||||
//
|
|
||||||
// sftp = (ChannelSftp) session.openChannel("sftp");
|
|
||||||
// sftp.connect();
|
|
||||||
|
|
||||||
// InputStream csvInputStream =
|
|
||||||
// sftp.get("/home/kcomu/data/response/test/metrics/train.csv");
|
|
||||||
|
|
||||||
String localPath = "C:\\data\\upload\\train.csv";
|
|
||||||
try (BufferedReader reader =
|
try (BufferedReader reader =
|
||||||
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
|
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
|
||||||
|
|
||||||
@@ -84,73 +60,41 @@ public class ModelTestMetricsJobService {
|
|||||||
|
|
||||||
for (CSVRecord record : parser) {
|
for (CSVRecord record : parser) {
|
||||||
|
|
||||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
String model = record.get("model");
|
||||||
long iteration = Long.parseLong(record.get("Iteration"));
|
long TP = Long.parseLong(record.get("TP"));
|
||||||
double Loss = Double.parseDouble(record.get("Loss"));
|
long FP = Long.parseLong(record.get("FP"));
|
||||||
double LR = Double.parseDouble(record.get("LR"));
|
long FN = Long.parseLong(record.get("FN"));
|
||||||
float time = Float.parseFloat(record.get("Time"));
|
float precision = Float.parseFloat(record.get("precision"));
|
||||||
|
float recall = Float.parseFloat(record.get("recall"));
|
||||||
batchArgs.add(new Object[] {modelIds.getFirst(), epoch, iteration, Loss, LR, time});
|
float f1_score = Float.parseFloat(record.get("f1_score"));
|
||||||
}
|
float accuracy = Float.parseFloat(record.get("accuracy"));
|
||||||
|
float iou = Float.parseFloat(record.get("iou"));
|
||||||
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
|
long detection_count = Long.parseLong(record.get("detection_count"));
|
||||||
|
long gt_count = Long.parseLong(record.get("gt_count"));
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
String validationPath = "C:\\data\\upload\\val.csv";
|
|
||||||
try (BufferedReader reader =
|
|
||||||
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
|
||||||
|
|
||||||
log.info("### validationPath={}", validationPath);
|
|
||||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
|
||||||
|
|
||||||
List<Object[]> batchArgs = new ArrayList<>();
|
|
||||||
|
|
||||||
for (CSVRecord record : parser) {
|
|
||||||
|
|
||||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
|
||||||
float aAcc = Float.parseFloat(record.get("aAcc"));
|
|
||||||
float mFscore = Float.parseFloat(record.get("mFscore"));
|
|
||||||
float mPrecision = Float.parseFloat(record.get("mPrecision"));
|
|
||||||
float mRecall = Float.parseFloat(record.get("mRecall"));
|
|
||||||
float mIoU = Float.parseFloat(record.get("mIoU"));
|
|
||||||
float mAcc = Float.parseFloat(record.get("mAcc"));
|
|
||||||
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
|
|
||||||
float changed_precision = Float.parseFloat(record.get("changed_precision"));
|
|
||||||
float changed_recall = Float.parseFloat(record.get("changed_recall"));
|
|
||||||
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
|
|
||||||
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
|
|
||||||
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
|
|
||||||
|
|
||||||
batchArgs.add(
|
batchArgs.add(
|
||||||
new Object[] {
|
new Object[] {
|
||||||
modelIds.getFirst(),
|
modelIds.getFirst(),
|
||||||
epoch,
|
model,
|
||||||
aAcc,
|
TP,
|
||||||
mFscore,
|
FP,
|
||||||
mPrecision,
|
FN,
|
||||||
mRecall,
|
precision,
|
||||||
mIoU,
|
recall,
|
||||||
mAcc,
|
f1_score,
|
||||||
changed_fscore,
|
accuracy,
|
||||||
changed_precision,
|
iou,
|
||||||
changed_recall,
|
detection_count,
|
||||||
unchanged_fscore,
|
gt_count
|
||||||
unchanged_precision,
|
|
||||||
unchanged_recall
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
|
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
|
||||||
|
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
// } catch (JSchException | SftpException e) {
|
|
||||||
// throw new RuntimeException(e);
|
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step2");
|
||||||
// }
|
|
||||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
package com.kamco.cd.training.schedule.service;
|
package com.kamco.cd.training.schedule.service;
|
||||||
|
|
||||||
import com.jcraft.jsch.ChannelSftp;
|
|
||||||
import com.jcraft.jsch.JSch;
|
|
||||||
import com.jcraft.jsch.Session;
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@@ -38,7 +35,7 @@ public class ModelTrainMetricsJobService {
|
|||||||
return "local".equalsIgnoreCase(profile);
|
return "local".equalsIgnoreCase(profile);
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Scheduled(cron = "0 0/5 * * * *")
|
// @Scheduled(cron = "0 0/10 * * * *")
|
||||||
public void findTrainValidMetricCsvFiles() {
|
public void findTrainValidMetricCsvFiles() {
|
||||||
if (isLocalProfile()) {
|
if (isLocalProfile()) {
|
||||||
return;
|
return;
|
||||||
@@ -52,27 +49,6 @@ public class ModelTrainMetricsJobService {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Session session = null;
|
|
||||||
ChannelSftp sftp = null;
|
|
||||||
|
|
||||||
JSch jsch = new JSch();
|
|
||||||
|
|
||||||
// try {
|
|
||||||
// session = jsch.getSession("kcomu", "192.168.2.86", 22);
|
|
||||||
// session.setPassword("Kamco2025!");
|
|
||||||
//
|
|
||||||
// Properties config = new Properties();
|
|
||||||
// config.put("StrictHostKeyChecking", "no");
|
|
||||||
// session.setConfig(config);
|
|
||||||
//
|
|
||||||
// session.connect();
|
|
||||||
//
|
|
||||||
// sftp = (ChannelSftp) session.openChannel("sftp");
|
|
||||||
// sftp.connect();
|
|
||||||
|
|
||||||
// InputStream csvInputStream =
|
|
||||||
// sftp.get("/home/kcomu/data/response/test/metrics/train.csv");
|
|
||||||
|
|
||||||
String localPath = "C:\\data\\upload\\train.csv";
|
String localPath = "C:\\data\\upload\\train.csv";
|
||||||
try (BufferedReader reader =
|
try (BufferedReader reader =
|
||||||
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
|
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
|
||||||
@@ -148,9 +124,7 @@ public class ModelTrainMetricsJobService {
|
|||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
// } catch (JSchException | SftpException e) {
|
|
||||||
// throw new RuntimeException(e);
|
|
||||||
// }
|
|
||||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1");
|
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user