feat/training_260202 #49

Merged
teddy merged 4 commits from feat/training_260202 into develop 2026-02-12 15:44:11 +09:00
19 changed files with 317 additions and 137 deletions

View File

@@ -92,9 +92,8 @@ public class ModelTrainMngApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@PostMapping @PostMapping
public ApiResponseDto<String> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) { public ApiResponseDto<UUID> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
modelTrainMngService.createModelTrain(req); return ApiResponseDto.ok(modelTrainMngService.createModelTrain(req));
return ApiResponseDto.ok("ok");
} }
@Operation(summary = "모델학습 config 정보 조회", description = "모델학습 config 정보 조회 API") @Operation(summary = "모델학습 config 정보 조회", description = "모델학습 config 정보 조회 API")

View File

@@ -155,6 +155,17 @@ public class ModelTrainMngDto {
ModelConfig modelConfig; ModelConfig modelConfig;
} }
@Schema(name = "addReq", description = "모델학습 관리 등록 파라미터")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class UpdateReq {
private String requestPath;
private String responsePath;
}
@Getter @Getter
@Setter @Setter
public static class TrainingDataset { public static class TrainingDataset {

View File

@@ -13,6 +13,8 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService; import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.service.TrainJobService; import com.kamco.cd.training.train.service.TrainJobService;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -31,6 +33,7 @@ public class ModelTrainMngService {
private final TrainJobService trainJobService; private final TrainJobService trainJobService;
private final ModelTrainMngCoreService modelTrainMngCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService;
private final HyperParamCoreService hyperParamCoreService; private final HyperParamCoreService hyperParamCoreService;
private final TmpDatasetService tmpDatasetService;
/** /**
* 모델학습 조회 * 모델학습 조회
@@ -59,7 +62,7 @@ public class ModelTrainMngService {
* @return * @return
*/ */
@Transactional @Transactional
public void createModelTrain(ModelTrainMngDto.AddReq req) { public UUID createModelTrain(ModelTrainMngDto.AddReq req) {
HyperParam hyperParam = req.getHyperParam(); HyperParam hyperParam = req.getHyperParam();
HyperParamDto.Basic hyper = new HyperParamDto.Basic(); HyperParamDto.Basic hyper = new HyperParamDto.Basic();
@@ -78,7 +81,10 @@ public class ModelTrainMngService {
} }
// 모델학습 테이블 저장 // 모델학습 테이블 저장
Long modelId = modelTrainMngCoreService.saveModel(req); ModelTrainMngDto.Basic modelDto = modelTrainMngCoreService.saveModel(req);
Long modelId = modelDto.getId();
UUID modelUuid = modelDto.getUuid();
// 모델학습 데이터셋 저장 // 모델학습 데이터셋 저장
modelTrainMngCoreService.saveModelDataset(modelId, req); modelTrainMngCoreService.saveModelDataset(modelId, req);
@@ -90,10 +96,23 @@ public class ModelTrainMngService {
// 모델 config 저장 // 모델 config 저장
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig()); modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
// 저장 다 끝난 뒤에 job enqueue UUID tmpUuid = UUID.randomUUID();
if (Boolean.TRUE.equals(req.getIsStart())) { String raw = tmpUuid.toString().replace("-", "");
trainJobService.enqueue(modelId); // job 저장 + 이벤트 발행(실행은 AFTER_COMMIT)
List<String> uids =
modelTrainMngCoreService.findDatasetUid(req.getTrainingDataset().getDatasetList());
try {
// 데이터셋 심볼링크 생성
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(path.toString());
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
throw new RuntimeException(e);
} }
return modelUuid;
} }
/** /**

View File

@@ -0,0 +1,66 @@
package com.kamco.cd.training.model.service;
import java.io.IOException;
import java.nio.file.*;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Slf4j
@Service
@RequiredArgsConstructor
public class TmpDatasetService {
@Value("${train.requestDir}")
private String requestDir;
// 환경에 맞게 yml로 빼는 걸 추천
private final Path BASE = Paths.get(requestDir);
@Transactional(readOnly = true)
public Path buildTmpDatasetSymlink(String uid, List<String> uids) throws IOException {
Path tmp = BASE.resolve("tmp").resolve(uid);
// mkdir -p "$TMP"/train/{input1,input2,label} ...
for (String type : List.of("train", "val")) {
for (String part : List.of("input1", "input2", "label")) {
Files.createDirectories(tmp.resolve(type).resolve(part));
}
}
for (String id : uids) {
Path srcRoot = BASE.resolve(id);
for (String type : List.of("train", "val")) {
for (String part : List.of("input1", "input2", "label")) {
Path srcDir = srcRoot.resolve(type).resolve(part);
// zsh NULL_GLOB: 폴더가 없으면 그냥 continue
if (!Files.isDirectory(srcDir)) continue;
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
for (Path f : stream) {
if (!Files.isRegularFile(f)) continue;
String dstName = id + "__" + f.getFileName();
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
// 이미 있으면 스킵(원하면 덮어쓰기 로직으로 바꿀 수 있음)
if (Files.exists(dst)) continue;
// ln -s "$f" "$dst" 와 동일
Files.createSymbolicLink(dst, f.toAbsolutePath());
}
}
}
}
}
log.info("tmp dataset created: {}", tmp);
return tmp;
}
}

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -18,7 +19,7 @@ public class ModelTestMetricsJobCoreService {
} }
// Test 로직 시작 // Test 로직 시작
public List<Long> getTestMetricSaveNotYetModelIds() { public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds(); return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -12,7 +13,7 @@ public class ModelTrainMetricsJobCoreService {
private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository; private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository;
public List<Long> getTrainMetricSaveNotYetModelIds() { public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds(); return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
} }

View File

@@ -78,7 +78,7 @@ public class ModelTrainMngCoreService {
* @param addReq * @param addReq
* @return * @return
*/ */
public Long saveModel(ModelTrainMngDto.AddReq addReq) { public ModelTrainMngDto.Basic saveModel(ModelTrainMngDto.AddReq addReq) {
ModelMasterEntity entity = new ModelMasterEntity(); ModelMasterEntity entity = new ModelMasterEntity();
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity(); ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
@@ -117,7 +117,12 @@ public class ModelTrainMngCoreService {
entity.setCreatedUid(userUtil.getId()); entity.setCreatedUid(userUtil.getId());
ModelMasterEntity resultEntity = modelMngRepository.save(entity); ModelMasterEntity resultEntity = modelMngRepository.save(entity);
return resultEntity.getId();
ModelTrainMngDto.Basic result = new ModelTrainMngDto.Basic();
result.setId(resultEntity.getId());
result.setUuid(resultEntity.getUuid());
return result;
} }
/** /**
@@ -149,6 +154,26 @@ public class ModelTrainMngCoreService {
modelDatasetRepository.save(datasetEntity); modelDatasetRepository.save(datasetEntity);
} }
/**
* 학습모델 수정
*
* @param modelId
* @param req
*/
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
entity.setRequestPath(req.getRequestPath());
}
if (req.getResponsePath() != null && !req.getResponsePath().isEmpty()) {
entity.setRequestPath(req.getResponsePath());
}
}
/** /**
* 모델 데이터셋 mapping 테이블 저장 * 모델 데이터셋 mapping 테이블 저장
* *
@@ -467,4 +492,14 @@ public class ModelTrainMngCoreService {
entity.setBestEpoch(epoch); entity.setBestEpoch(epoch);
} }
/**
* 데이터셋 uid 조회
*
* @param datasetIds
* @return
*/
public List<String> findDatasetUid(List<Long> datasetIds) {
return datasetRepository.findDatasetUid(datasetIds);
}
} }

View File

@@ -106,6 +106,12 @@ public class ModelMasterEntity {
@Column(name = "best_epoch") @Column(name = "best_epoch")
private Integer bestEpoch; private Integer bestEpoch;
@Column(name = "request_path")
private String requestPath;
@Column(name = "response_path")
private String responsePath;
public ModelTrainMngDto.Basic toDto() { public ModelTrainMngDto.Basic toDto() {
return new ModelTrainMngDto.Basic( return new ModelTrainMngDto.Basic(
this.id, this.id,

View File

@@ -22,4 +22,6 @@ public interface DatasetRepositoryCustom {
Long getDatasetMaxStage(int compareYyyy, int targetYyyy); Long getDatasetMaxStage(int compareYyyy, int targetYyyy);
Long insertDatasetMngData(DatasetMngRegDto mngRegDto); Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
List<String> findDatasetUid(List<Long> datasetIds);
} }

View File

@@ -242,4 +242,9 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.where(dataset.uid.eq(mngRegDto.getUid())) .where(dataset.uid.eq(mngRegDto.getUid()))
.fetchOne(); .fetchOne();
} }
@Override
public List<String> findDatasetUid(List<Long> datasetIds) {
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
}
} }

View File

@@ -94,7 +94,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
.select( .select(
Projections.constructor( Projections.constructor(
TrainRunRequest.class, TrainRunRequest.class,
modelMasterEntity.uuid, // datasetFolder modelMasterEntity.requestPath, // datasetFolder
modelMasterEntity.uuid, // outputFolder modelMasterEntity.uuid, // outputFolder
modelHyperParamEntity.inputSize, modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize, modelHyperParamEntity.cropSize,

View File

@@ -1,12 +1,13 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List; import java.util.List;
public interface ModelTestMetricsJobRepositoryCustom { public interface ModelTestMetricsJobRepositoryCustom {
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo); void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
List<Long> getTestMetricSaveNotYetModelIds(); List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
void insertModelMetricsTest(List<Object[]> batchArgs); void insertModelMetricsTest(List<Object[]> batchArgs);
} }

View File

@@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity; import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List; import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport; import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
@@ -36,9 +38,11 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
} }
@Override @Override
public List<Long> getTestMetricSaveNotYetModelIds() { public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
return queryFactory return queryFactory
.select(modelMasterEntity.id) .select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
.from(modelMasterEntity) .from(modelMasterEntity)
.where( .where(
modelMasterEntity.step2EndDttm.isNotNull(), modelMasterEntity.step2EndDttm.isNotNull(),

View File

@@ -1,10 +1,11 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List; import java.util.List;
public interface ModelTrainMetricsJobRepositoryCustom { public interface ModelTrainMetricsJobRepositoryCustom {
List<Long> getTrainMetricSaveNotYetModelIds(); List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
void insertModelMetricsTrain(List<Object[]> batchArgs); void insertModelMetricsTrain(List<Object[]> batchArgs);

View File

@@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity; import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List; import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport; import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
@@ -23,9 +25,11 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
} }
@Override @Override
public List<Long> getTrainMetricSaveNotYetModelIds() { public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
return queryFactory return queryFactory
.select(modelMasterEntity.id) .select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
.from(modelMasterEntity) .from(modelMasterEntity)
.where( .where(
modelMasterEntity.step1EndDttm.isNotNull(), modelMasterEntity.step1EndDttm.isNotNull(),
@@ -41,7 +45,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
public void insertModelMetricsTrain(List<Object[]> batchArgs) { public void insertModelMetricsTrain(List<Object[]> batchArgs) {
String sql = String sql =
""" """
insert into tb_model_matrics_train insert into tb_model_metrics_train
(model_id, epoch, iteration, loss, lr, duration_time) (model_id, epoch, iteration, loss, lr, duration_time)
values (?, ?, ?, ?, ?, ?) values (?, ?, ?, ?, ?, ?)
"""; """;
@@ -66,7 +70,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
public void insertModelMetricsValidation(List<Object[]> batchArgs) { public void insertModelMetricsValidation(List<Object[]> batchArgs) {
String sql = String sql =
""" """
insert into tb_model_matrics_validation insert into tb_model_metrics_validation
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall, (model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
unchanged_fscore, unchanged_precision, unchanged_recall unchanged_fscore, unchanged_precision, unchanged_recall
) )

View File

@@ -0,0 +1,21 @@
package com.kamco.cd.training.train.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
public class ModelTrainMetricsDto {
@Schema(name = "ResponsePathDto", description = "AI 결과 저장된 path 경로 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ResponsePathDto {
private Long modelId;
private String responsePath;
}
}

View File

@@ -221,7 +221,7 @@ public class DockerTrainService {
c.add("/workspace/change-detection-code/train_wrapper.py"); c.add("/workspace/change-detection-code/train_wrapper.py");
// ===== 기본 파라미터 ===== // ===== 기본 파라미터 =====
addArg(c, "--dataset-folder", "4BDBBDF99D04477A927CC9EBA760B845" /*req.getDatasetFolder()*/); addArg(c, "--dataset-folder", req.getDatasetFolder());
addArg(c, "--output-folder", req.getOutputFolder()); addArg(c, "--output-folder", req.getOutputFolder());
addArg(c, "--input-size", req.getInputSize()); addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize()); addArg(c, "--crop-size", req.getCropSize());
@@ -281,8 +281,8 @@ public class DockerTrainService {
if (value == null) return; if (value == null) return;
String s = String.valueOf(value).trim(); String s = String.valueOf(value).trim();
if (s.isEmpty()) return; if (s.isEmpty()) return;
c.add(key);
c.add(s); c.add(key + "=" + s);
} }
/** 컨테이너 강제 종료 및 제거 */ /** 컨테이너 강제 종료 및 제거 */

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@@ -35,25 +36,25 @@ public class ModelTestMetricsJobService {
return "local".equalsIgnoreCase(profile); return "local".equalsIgnoreCase(profile);
} }
// @Scheduled(cron = "0 0/10 * * * *") // @Scheduled(cron = "0 * * * * *")
public void findTestValidMetricCsvFiles() { public void findTestValidMetricCsvFiles() {
if (isLocalProfile()) { // if (isLocalProfile()) {
return; // return;
} // }
List<Long> modelIds = List<ResponsePathDto> modelIds =
modelTestMetricsJobCoreService modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
.getTestMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
if (modelIds.isEmpty()) { if (modelIds.isEmpty()) {
return; return;
} }
String localPath = "C:\\data\\upload\\test.csv"; for (ResponsePathDto modelInfo : modelIds) {
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) { String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
log.info("### localPath={}", localPath);
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>(); List<Object[]> batchArgs = new ArrayList<>();
@@ -74,7 +75,7 @@ public class ModelTestMetricsJobService {
batchArgs.add( batchArgs.add(
new Object[] { new Object[] {
modelIds.getFirst(), modelInfo.getModelId(),
model, model,
TP, TP,
FP, FP,
@@ -95,6 +96,7 @@ public class ModelTestMetricsJobService {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step2"); modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
}
} }
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@@ -35,38 +36,38 @@ public class ModelTrainMetricsJobService {
return "local".equalsIgnoreCase(profile); return "local".equalsIgnoreCase(profile);
} }
// @Scheduled(cron = "0 0/10 * * * *") // @Scheduled(cron = "0 * * * * *")
public void findTrainValidMetricCsvFiles() { public void findTrainValidMetricCsvFiles() {
if (isLocalProfile()) { // if (isLocalProfile()) {
return; // return;
} // }
List<Long> modelIds = List<ResponsePathDto> modelIds =
modelTrainMetricsJobCoreService modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
.getTrainMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
if (modelIds.isEmpty()) { if (modelIds.isEmpty()) {
return; return;
} }
String localPath = "C:\\data\\upload\\train.csv"; for (ResponsePathDto modelInfo : modelIds) {
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) { String trainPath = modelInfo.getResponsePath() + "/metrics/train.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
log.info("### localPath={}", localPath);
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>(); List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) { for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch")); int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기
long iteration = Long.parseLong(record.get("Iteration")); long iteration = Long.parseLong(record.get("Iteration"));
double Loss = Double.parseDouble(record.get("Loss")); double Loss = Double.parseDouble(record.get("Loss"));
double LR = Double.parseDouble(record.get("LR")); double LR = Double.parseDouble(record.get("LR"));
float time = Float.parseFloat(record.get("Time")); float time = Float.parseFloat(record.get("Time"));
batchArgs.add(new Object[] {modelIds.getFirst(), epoch, iteration, Loss, LR, time}); batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
} }
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs); modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
@@ -75,11 +76,10 @@ public class ModelTrainMetricsJobService {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
String validationPath = "C:\\data\\upload\\val.csv"; String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
try (BufferedReader reader = try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) { Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
log.info("### validationPath={}", validationPath);
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader); CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>(); List<Object[]> batchArgs = new ArrayList<>();
@@ -102,7 +102,7 @@ public class ModelTrainMetricsJobService {
batchArgs.add( batchArgs.add(
new Object[] { new Object[] {
modelIds.getFirst(), modelInfo.getModelId(),
epoch, epoch,
aAcc, aAcc,
mFscore, mFscore,
@@ -125,6 +125,8 @@ public class ModelTrainMetricsJobService {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1"); modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step1");
}
} }
} }