csv 파일 읽는 경로 읽어서 수정, train은 epoch + 1 해서 저장
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -18,7 +19,7 @@ public class ModelTestMetricsJobCoreService {
|
||||
}
|
||||
|
||||
// Test 로직 시작
|
||||
public List<Long> getTestMetricSaveNotYetModelIds() {
|
||||
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
|
||||
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -12,7 +13,7 @@ public class ModelTrainMetricsJobCoreService {
|
||||
|
||||
private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository;
|
||||
|
||||
public List<Long> getTrainMetricSaveNotYetModelIds() {
|
||||
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
|
||||
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
|
||||
public interface ModelTestMetricsJobRepositoryCustom {
|
||||
|
||||
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
||||
|
||||
List<Long> getTestMetricSaveNotYetModelIds();
|
||||
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
|
||||
|
||||
void insertModelMetricsTest(List<Object[]> batchArgs);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
|
||||
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
|
||||
@@ -36,9 +38,11 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Long> getTestMetricSaveNotYetModelIds() {
|
||||
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
|
||||
return queryFactory
|
||||
.select(modelMasterEntity.id)
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity.step2EndDttm.isNotNull(),
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
|
||||
public interface ModelTrainMetricsJobRepositoryCustom {
|
||||
|
||||
List<Long> getTrainMetricSaveNotYetModelIds();
|
||||
List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
|
||||
|
||||
void insertModelMetricsTrain(List<Object[]> batchArgs);
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
|
||||
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
|
||||
@@ -23,9 +25,11 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Long> getTrainMetricSaveNotYetModelIds() {
|
||||
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
|
||||
return queryFactory
|
||||
.select(modelMasterEntity.id)
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity.step1EndDttm.isNotNull(),
|
||||
@@ -41,7 +45,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
||||
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
||||
String sql =
|
||||
"""
|
||||
insert into tb_model_matrics_train
|
||||
insert into tb_model_metrics_train
|
||||
(model_id, epoch, iteration, loss, lr, duration_time)
|
||||
values (?, ?, ?, ?, ?, ?)
|
||||
""";
|
||||
@@ -66,7 +70,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
||||
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
||||
String sql =
|
||||
"""
|
||||
insert into tb_model_matrics_validation
|
||||
insert into tb_model_metrics_validation
|
||||
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
|
||||
unchanged_fscore, unchanged_precision, unchanged_recall
|
||||
)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
public class ModelTrainMetricsDto {
|
||||
|
||||
@Schema(name = "ResponsePathDto", description = "AI 결과 저장된 path 경로 정보")
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class ResponsePathDto {
|
||||
|
||||
private Long modelId;
|
||||
private String responsePath;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
@@ -35,66 +36,67 @@ public class ModelTestMetricsJobService {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 0/10 * * * *")
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
public void findTestValidMetricCsvFiles() {
|
||||
if (isLocalProfile()) {
|
||||
return;
|
||||
}
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
List<Long> modelIds =
|
||||
modelTestMetricsJobCoreService
|
||||
.getTestMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
|
||||
List<ResponsePathDto> modelIds =
|
||||
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
||||
|
||||
if (modelIds.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String localPath = "C:\\data\\upload\\test.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
|
||||
for (ResponsePathDto modelInfo : modelIds) {
|
||||
|
||||
log.info("### localPath={}", localPath);
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
|
||||
for (CSVRecord record : parser) {
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
|
||||
String model = record.get("model");
|
||||
long TP = Long.parseLong(record.get("TP"));
|
||||
long FP = Long.parseLong(record.get("FP"));
|
||||
long FN = Long.parseLong(record.get("FN"));
|
||||
float precision = Float.parseFloat(record.get("precision"));
|
||||
float recall = Float.parseFloat(record.get("recall"));
|
||||
float f1_score = Float.parseFloat(record.get("f1_score"));
|
||||
float accuracy = Float.parseFloat(record.get("accuracy"));
|
||||
float iou = Float.parseFloat(record.get("iou"));
|
||||
long detection_count = Long.parseLong(record.get("detection_count"));
|
||||
long gt_count = Long.parseLong(record.get("gt_count"));
|
||||
for (CSVRecord record : parser) {
|
||||
|
||||
batchArgs.add(
|
||||
new Object[] {
|
||||
modelIds.getFirst(),
|
||||
model,
|
||||
TP,
|
||||
FP,
|
||||
FN,
|
||||
precision,
|
||||
recall,
|
||||
f1_score,
|
||||
accuracy,
|
||||
iou,
|
||||
detection_count,
|
||||
gt_count
|
||||
});
|
||||
String model = record.get("model");
|
||||
long TP = Long.parseLong(record.get("TP"));
|
||||
long FP = Long.parseLong(record.get("FP"));
|
||||
long FN = Long.parseLong(record.get("FN"));
|
||||
float precision = Float.parseFloat(record.get("precision"));
|
||||
float recall = Float.parseFloat(record.get("recall"));
|
||||
float f1_score = Float.parseFloat(record.get("f1_score"));
|
||||
float accuracy = Float.parseFloat(record.get("accuracy"));
|
||||
float iou = Float.parseFloat(record.get("iou"));
|
||||
long detection_count = Long.parseLong(record.get("detection_count"));
|
||||
long gt_count = Long.parseLong(record.get("gt_count"));
|
||||
|
||||
batchArgs.add(
|
||||
new Object[] {
|
||||
modelInfo.getModelId(),
|
||||
model,
|
||||
TP,
|
||||
FP,
|
||||
FN,
|
||||
precision,
|
||||
recall,
|
||||
f1_score,
|
||||
accuracy,
|
||||
iou,
|
||||
detection_count,
|
||||
gt_count
|
||||
});
|
||||
}
|
||||
|
||||
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
|
||||
}
|
||||
|
||||
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step2");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
@@ -35,96 +36,97 @@ public class ModelTrainMetricsJobService {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 0/10 * * * *")
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
public void findTrainValidMetricCsvFiles() {
|
||||
if (isLocalProfile()) {
|
||||
return;
|
||||
}
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
List<Long> modelIds =
|
||||
modelTrainMetricsJobCoreService
|
||||
.getTrainMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
|
||||
List<ResponsePathDto> modelIds =
|
||||
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
||||
|
||||
if (modelIds.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String localPath = "C:\\data\\upload\\train.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
|
||||
for (ResponsePathDto modelInfo : modelIds) {
|
||||
|
||||
log.info("### localPath={}", localPath);
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
String trainPath = modelInfo.getResponsePath() + "/metrics/train.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
|
||||
for (CSVRecord record : parser) {
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
|
||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
||||
long iteration = Long.parseLong(record.get("Iteration"));
|
||||
double Loss = Double.parseDouble(record.get("Loss"));
|
||||
double LR = Double.parseDouble(record.get("LR"));
|
||||
float time = Float.parseFloat(record.get("Time"));
|
||||
for (CSVRecord record : parser) {
|
||||
|
||||
batchArgs.add(new Object[] {modelIds.getFirst(), epoch, iteration, Loss, LR, time});
|
||||
int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기
|
||||
long iteration = Long.parseLong(record.get("Iteration"));
|
||||
double Loss = Double.parseDouble(record.get("Loss"));
|
||||
double LR = Double.parseDouble(record.get("LR"));
|
||||
float time = Float.parseFloat(record.get("Time"));
|
||||
|
||||
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
|
||||
}
|
||||
|
||||
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
|
||||
String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
|
||||
String validationPath = "C:\\data\\upload\\val.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
|
||||
log.info("### validationPath={}", validationPath);
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
for (CSVRecord record : parser) {
|
||||
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
||||
float aAcc = Float.parseFloat(record.get("aAcc"));
|
||||
float mFscore = Float.parseFloat(record.get("mFscore"));
|
||||
float mPrecision = Float.parseFloat(record.get("mPrecision"));
|
||||
float mRecall = Float.parseFloat(record.get("mRecall"));
|
||||
float mIoU = Float.parseFloat(record.get("mIoU"));
|
||||
float mAcc = Float.parseFloat(record.get("mAcc"));
|
||||
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
|
||||
float changed_precision = Float.parseFloat(record.get("changed_precision"));
|
||||
float changed_recall = Float.parseFloat(record.get("changed_recall"));
|
||||
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
|
||||
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
|
||||
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
|
||||
|
||||
for (CSVRecord record : parser) {
|
||||
batchArgs.add(
|
||||
new Object[] {
|
||||
modelInfo.getModelId(),
|
||||
epoch,
|
||||
aAcc,
|
||||
mFscore,
|
||||
mPrecision,
|
||||
mRecall,
|
||||
mIoU,
|
||||
mAcc,
|
||||
changed_fscore,
|
||||
changed_precision,
|
||||
changed_recall,
|
||||
unchanged_fscore,
|
||||
unchanged_precision,
|
||||
unchanged_recall
|
||||
});
|
||||
}
|
||||
|
||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
||||
float aAcc = Float.parseFloat(record.get("aAcc"));
|
||||
float mFscore = Float.parseFloat(record.get("mFscore"));
|
||||
float mPrecision = Float.parseFloat(record.get("mPrecision"));
|
||||
float mRecall = Float.parseFloat(record.get("mRecall"));
|
||||
float mIoU = Float.parseFloat(record.get("mIoU"));
|
||||
float mAcc = Float.parseFloat(record.get("mAcc"));
|
||||
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
|
||||
float changed_precision = Float.parseFloat(record.get("changed_precision"));
|
||||
float changed_recall = Float.parseFloat(record.get("changed_recall"));
|
||||
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
|
||||
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
|
||||
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
|
||||
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
|
||||
|
||||
batchArgs.add(
|
||||
new Object[] {
|
||||
modelIds.getFirst(),
|
||||
epoch,
|
||||
aAcc,
|
||||
mFscore,
|
||||
mPrecision,
|
||||
mRecall,
|
||||
mIoU,
|
||||
mAcc,
|
||||
changed_fscore,
|
||||
changed_precision,
|
||||
changed_recall,
|
||||
unchanged_fscore,
|
||||
unchanged_precision,
|
||||
unchanged_recall
|
||||
});
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
|
||||
modelInfo.getModelId(), "step1");
|
||||
}
|
||||
|
||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user