csv 파일 읽는 경로 읽어서 수정, train은 epoch + 1 해서 저장

This commit is contained in:
2026-02-12 15:24:30 +09:00
parent b451f697bc
commit 2df4a7a80b
9 changed files with 160 additions and 123 deletions

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -18,7 +19,7 @@ public class ModelTestMetricsJobCoreService {
} }
// Test 로직 시작 // Test 로직 시작
public List<Long> getTestMetricSaveNotYetModelIds() { public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds(); return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -12,7 +13,7 @@ public class ModelTrainMetricsJobCoreService {
private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository; private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository;
public List<Long> getTrainMetricSaveNotYetModelIds() { public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds(); return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
} }

View File

@@ -1,12 +1,13 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List; import java.util.List;
public interface ModelTestMetricsJobRepositoryCustom { public interface ModelTestMetricsJobRepositoryCustom {
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo); void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
List<Long> getTestMetricSaveNotYetModelIds(); List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
void insertModelMetricsTest(List<Object[]> batchArgs); void insertModelMetricsTest(List<Object[]> batchArgs);
} }

View File

@@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity; import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List; import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport; import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
@@ -36,9 +38,11 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
} }
@Override @Override
public List<Long> getTestMetricSaveNotYetModelIds() { public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
return queryFactory return queryFactory
.select(modelMasterEntity.id) .select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
.from(modelMasterEntity) .from(modelMasterEntity)
.where( .where(
modelMasterEntity.step2EndDttm.isNotNull(), modelMasterEntity.step2EndDttm.isNotNull(),

View File

@@ -1,10 +1,11 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List; import java.util.List;
public interface ModelTrainMetricsJobRepositoryCustom { public interface ModelTrainMetricsJobRepositoryCustom {
List<Long> getTrainMetricSaveNotYetModelIds(); List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
void insertModelMetricsTrain(List<Object[]> batchArgs); void insertModelMetricsTrain(List<Object[]> batchArgs);

View File

@@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity; import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List; import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport; import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
@@ -23,9 +25,11 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
} }
@Override @Override
public List<Long> getTrainMetricSaveNotYetModelIds() { public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
return queryFactory return queryFactory
.select(modelMasterEntity.id) .select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
.from(modelMasterEntity) .from(modelMasterEntity)
.where( .where(
modelMasterEntity.step1EndDttm.isNotNull(), modelMasterEntity.step1EndDttm.isNotNull(),
@@ -41,7 +45,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
public void insertModelMetricsTrain(List<Object[]> batchArgs) { public void insertModelMetricsTrain(List<Object[]> batchArgs) {
String sql = String sql =
""" """
insert into tb_model_matrics_train insert into tb_model_metrics_train
(model_id, epoch, iteration, loss, lr, duration_time) (model_id, epoch, iteration, loss, lr, duration_time)
values (?, ?, ?, ?, ?, ?) values (?, ?, ?, ?, ?, ?)
"""; """;
@@ -66,7 +70,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
public void insertModelMetricsValidation(List<Object[]> batchArgs) { public void insertModelMetricsValidation(List<Object[]> batchArgs) {
String sql = String sql =
""" """
insert into tb_model_matrics_validation insert into tb_model_metrics_validation
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall, (model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
unchanged_fscore, unchanged_precision, unchanged_recall unchanged_fscore, unchanged_precision, unchanged_recall
) )

View File

@@ -0,0 +1,21 @@
package com.kamco.cd.training.train.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
public class ModelTrainMetricsDto {
@Schema(name = "ResponsePathDto", description = "AI 결과 저장된 path 경로 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ResponsePathDto {
private Long modelId;
private String responsePath;
}
}

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@@ -35,66 +36,67 @@ public class ModelTestMetricsJobService {
return "local".equalsIgnoreCase(profile); return "local".equalsIgnoreCase(profile);
} }
// @Scheduled(cron = "0 0/10 * * * *") // @Scheduled(cron = "0 * * * * *")
public void findTestValidMetricCsvFiles() { public void findTestValidMetricCsvFiles() {
if (isLocalProfile()) { // if (isLocalProfile()) {
return; // return;
} // }
List<Long> modelIds = List<ResponsePathDto> modelIds =
modelTestMetricsJobCoreService modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
.getTestMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
if (modelIds.isEmpty()) { if (modelIds.isEmpty()) {
return; return;
} }
String localPath = "C:\\data\\upload\\test.csv"; for (ResponsePathDto modelInfo : modelIds) {
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
log.info("### localPath={}", localPath); String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); try (BufferedReader reader =
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
List<Object[]> batchArgs = new ArrayList<>(); CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
for (CSVRecord record : parser) { List<Object[]> batchArgs = new ArrayList<>();
String model = record.get("model"); for (CSVRecord record : parser) {
long TP = Long.parseLong(record.get("TP"));
long FP = Long.parseLong(record.get("FP"));
long FN = Long.parseLong(record.get("FN"));
float precision = Float.parseFloat(record.get("precision"));
float recall = Float.parseFloat(record.get("recall"));
float f1_score = Float.parseFloat(record.get("f1_score"));
float accuracy = Float.parseFloat(record.get("accuracy"));
float iou = Float.parseFloat(record.get("iou"));
long detection_count = Long.parseLong(record.get("detection_count"));
long gt_count = Long.parseLong(record.get("gt_count"));
batchArgs.add( String model = record.get("model");
new Object[] { long TP = Long.parseLong(record.get("TP"));
modelIds.getFirst(), long FP = Long.parseLong(record.get("FP"));
model, long FN = Long.parseLong(record.get("FN"));
TP, float precision = Float.parseFloat(record.get("precision"));
FP, float recall = Float.parseFloat(record.get("recall"));
FN, float f1_score = Float.parseFloat(record.get("f1_score"));
precision, float accuracy = Float.parseFloat(record.get("accuracy"));
recall, float iou = Float.parseFloat(record.get("iou"));
f1_score, long detection_count = Long.parseLong(record.get("detection_count"));
accuracy, long gt_count = Long.parseLong(record.get("gt_count"));
iou,
detection_count, batchArgs.add(
gt_count new Object[] {
}); modelInfo.getModelId(),
model,
TP,
FP,
FN,
precision,
recall,
f1_score,
accuracy,
iou,
detection_count,
gt_count
});
}
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
} }
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs); modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
} catch (IOException e) {
throw new RuntimeException(e);
} }
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step2");
} }
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@@ -35,96 +36,97 @@ public class ModelTrainMetricsJobService {
return "local".equalsIgnoreCase(profile); return "local".equalsIgnoreCase(profile);
} }
// @Scheduled(cron = "0 0/10 * * * *") // @Scheduled(cron = "0 * * * * *")
public void findTrainValidMetricCsvFiles() { public void findTrainValidMetricCsvFiles() {
if (isLocalProfile()) { // if (isLocalProfile()) {
return; // return;
} // }
List<Long> modelIds = List<ResponsePathDto> modelIds =
modelTrainMetricsJobCoreService modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
.getTrainMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
if (modelIds.isEmpty()) { if (modelIds.isEmpty()) {
return; return;
} }
String localPath = "C:\\data\\upload\\train.csv"; for (ResponsePathDto modelInfo : modelIds) {
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
log.info("### localPath={}", localPath); String trainPath = modelInfo.getResponsePath() + "/metrics/train.csv";
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); try (BufferedReader reader =
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
List<Object[]> batchArgs = new ArrayList<>(); CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
for (CSVRecord record : parser) { List<Object[]> batchArgs = new ArrayList<>();
int epoch = Integer.parseInt(record.get("Epoch")); for (CSVRecord record : parser) {
long iteration = Long.parseLong(record.get("Iteration"));
double Loss = Double.parseDouble(record.get("Loss"));
double LR = Double.parseDouble(record.get("LR"));
float time = Float.parseFloat(record.get("Time"));
batchArgs.add(new Object[] {modelIds.getFirst(), epoch, iteration, Loss, LR, time}); int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기
long iteration = Long.parseLong(record.get("Iteration"));
double Loss = Double.parseDouble(record.get("Loss"));
double LR = Double.parseDouble(record.get("LR"));
float time = Float.parseFloat(record.get("Time"));
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
}
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
} }
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs); String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
} catch (IOException e) { CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
throw new RuntimeException(e);
}
String validationPath = "C:\\data\\upload\\val.csv"; List<Object[]> batchArgs = new ArrayList<>();
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
log.info("### validationPath={}", validationPath); for (CSVRecord record : parser) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>(); int epoch = Integer.parseInt(record.get("Epoch"));
float aAcc = Float.parseFloat(record.get("aAcc"));
float mFscore = Float.parseFloat(record.get("mFscore"));
float mPrecision = Float.parseFloat(record.get("mPrecision"));
float mRecall = Float.parseFloat(record.get("mRecall"));
float mIoU = Float.parseFloat(record.get("mIoU"));
float mAcc = Float.parseFloat(record.get("mAcc"));
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
float changed_precision = Float.parseFloat(record.get("changed_precision"));
float changed_recall = Float.parseFloat(record.get("changed_recall"));
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
for (CSVRecord record : parser) { batchArgs.add(
new Object[] {
modelInfo.getModelId(),
epoch,
aAcc,
mFscore,
mPrecision,
mRecall,
mIoU,
mAcc,
changed_fscore,
changed_precision,
changed_recall,
unchanged_fscore,
unchanged_precision,
unchanged_recall
});
}
int epoch = Integer.parseInt(record.get("Epoch")); modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
float aAcc = Float.parseFloat(record.get("aAcc"));
float mFscore = Float.parseFloat(record.get("mFscore"));
float mPrecision = Float.parseFloat(record.get("mPrecision"));
float mRecall = Float.parseFloat(record.get("mRecall"));
float mIoU = Float.parseFloat(record.get("mIoU"));
float mAcc = Float.parseFloat(record.get("mAcc"));
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
float changed_precision = Float.parseFloat(record.get("changed_precision"));
float changed_recall = Float.parseFloat(record.get("changed_recall"));
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
batchArgs.add( } catch (IOException e) {
new Object[] { throw new RuntimeException(e);
modelIds.getFirst(),
epoch,
aAcc,
mFscore,
mPrecision,
mRecall,
mIoU,
mAcc,
changed_fscore,
changed_precision,
changed_recall,
unchanged_fscore,
unchanged_precision,
unchanged_recall
});
} }
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs); modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step1");
} catch (IOException e) {
throw new RuntimeException(e);
} }
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1");
} }
} }