feat/training_260202 #49
@@ -92,9 +92,8 @@ public class ModelTrainMngApiController {
|
|||||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
})
|
})
|
||||||
@PostMapping
|
@PostMapping
|
||||||
public ApiResponseDto<String> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
|
public ApiResponseDto<UUID> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
|
||||||
modelTrainMngService.createModelTrain(req);
|
return ApiResponseDto.ok(modelTrainMngService.createModelTrain(req));
|
||||||
return ApiResponseDto.ok("ok");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Operation(summary = "모델학습 config 정보 조회", description = "모델학습 config 정보 조회 API")
|
@Operation(summary = "모델학습 config 정보 조회", description = "모델학습 config 정보 조회 API")
|
||||||
|
|||||||
@@ -155,6 +155,17 @@ public class ModelTrainMngDto {
|
|||||||
ModelConfig modelConfig;
|
ModelConfig modelConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Schema(name = "addReq", description = "모델학습 관리 등록 파라미터")
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public static class UpdateReq {
|
||||||
|
|
||||||
|
private String requestPath;
|
||||||
|
private String responsePath;
|
||||||
|
}
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
public static class TrainingDataset {
|
public static class TrainingDataset {
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
|||||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
import com.kamco.cd.training.train.service.TrainJobService;
|
import com.kamco.cd.training.train.service.TrainJobService;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
@@ -31,6 +33,7 @@ public class ModelTrainMngService {
|
|||||||
private final TrainJobService trainJobService;
|
private final TrainJobService trainJobService;
|
||||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
private final HyperParamCoreService hyperParamCoreService;
|
private final HyperParamCoreService hyperParamCoreService;
|
||||||
|
private final TmpDatasetService tmpDatasetService;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 모델학습 조회
|
* 모델학습 조회
|
||||||
@@ -59,7 +62,7 @@ public class ModelTrainMngService {
|
|||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void createModelTrain(ModelTrainMngDto.AddReq req) {
|
public UUID createModelTrain(ModelTrainMngDto.AddReq req) {
|
||||||
HyperParam hyperParam = req.getHyperParam();
|
HyperParam hyperParam = req.getHyperParam();
|
||||||
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
|
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
|
||||||
|
|
||||||
@@ -78,7 +81,10 @@ public class ModelTrainMngService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 모델학습 테이블 저장
|
// 모델학습 테이블 저장
|
||||||
Long modelId = modelTrainMngCoreService.saveModel(req);
|
ModelTrainMngDto.Basic modelDto = modelTrainMngCoreService.saveModel(req);
|
||||||
|
|
||||||
|
Long modelId = modelDto.getId();
|
||||||
|
UUID modelUuid = modelDto.getUuid();
|
||||||
|
|
||||||
// 모델학습 데이터셋 저장
|
// 모델학습 데이터셋 저장
|
||||||
modelTrainMngCoreService.saveModelDataset(modelId, req);
|
modelTrainMngCoreService.saveModelDataset(modelId, req);
|
||||||
@@ -90,10 +96,23 @@ public class ModelTrainMngService {
|
|||||||
// 모델 config 저장
|
// 모델 config 저장
|
||||||
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
||||||
|
|
||||||
// 저장 다 끝난 뒤에 job enqueue
|
UUID tmpUuid = UUID.randomUUID();
|
||||||
if (Boolean.TRUE.equals(req.getIsStart())) {
|
String raw = tmpUuid.toString().replace("-", "");
|
||||||
trainJobService.enqueue(modelId); // job 저장 + 이벤트 발행(실행은 AFTER_COMMIT)
|
|
||||||
|
List<String> uids =
|
||||||
|
modelTrainMngCoreService.findDatasetUid(req.getTrainingDataset().getDatasetList());
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 데이터셋 심볼링크 생성
|
||||||
|
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||||
|
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||||
|
updateReq.setRequestPath(path.toString());
|
||||||
|
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return modelUuid;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package com.kamco.cd.training.model.service;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.*;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class TmpDatasetService {
|
||||||
|
|
||||||
|
@Value("${train.requestDir}")
|
||||||
|
private String requestDir;
|
||||||
|
|
||||||
|
// 환경에 맞게 yml로 빼는 걸 추천
|
||||||
|
private final Path BASE = Paths.get(requestDir);
|
||||||
|
|
||||||
|
@Transactional(readOnly = true)
|
||||||
|
public Path buildTmpDatasetSymlink(String uid, List<String> uids) throws IOException {
|
||||||
|
Path tmp = BASE.resolve("tmp").resolve(uid);
|
||||||
|
|
||||||
|
// mkdir -p "$TMP"/train/{input1,input2,label} ...
|
||||||
|
for (String type : List.of("train", "val")) {
|
||||||
|
for (String part : List.of("input1", "input2", "label")) {
|
||||||
|
Files.createDirectories(tmp.resolve(type).resolve(part));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (String id : uids) {
|
||||||
|
Path srcRoot = BASE.resolve(id);
|
||||||
|
|
||||||
|
for (String type : List.of("train", "val")) {
|
||||||
|
for (String part : List.of("input1", "input2", "label")) {
|
||||||
|
|
||||||
|
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||||
|
|
||||||
|
// zsh NULL_GLOB: 폴더가 없으면 그냥 continue
|
||||||
|
if (!Files.isDirectory(srcDir)) continue;
|
||||||
|
|
||||||
|
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
|
||||||
|
for (Path f : stream) {
|
||||||
|
if (!Files.isRegularFile(f)) continue;
|
||||||
|
|
||||||
|
String dstName = id + "__" + f.getFileName();
|
||||||
|
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
|
||||||
|
|
||||||
|
// 이미 있으면 스킵(원하면 덮어쓰기 로직으로 바꿀 수 있음)
|
||||||
|
if (Files.exists(dst)) continue;
|
||||||
|
|
||||||
|
// ln -s "$f" "$dst" 와 동일
|
||||||
|
Files.createSymbolicLink(dst, f.toAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("tmp dataset created: {}", tmp);
|
||||||
|
return tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.kamco.cd.training.postgres.core;
|
package com.kamco.cd.training.postgres.core;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
|
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -18,7 +19,7 @@ public class ModelTestMetricsJobCoreService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test 로직 시작
|
// Test 로직 시작
|
||||||
public List<Long> getTestMetricSaveNotYetModelIds() {
|
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
|
||||||
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
|
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.kamco.cd.training.postgres.core;
|
package com.kamco.cd.training.postgres.core;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
|
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -12,7 +13,7 @@ public class ModelTrainMetricsJobCoreService {
|
|||||||
|
|
||||||
private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository;
|
private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository;
|
||||||
|
|
||||||
public List<Long> getTrainMetricSaveNotYetModelIds() {
|
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
|
||||||
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
|
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ public class ModelTrainMngCoreService {
|
|||||||
* @param addReq
|
* @param addReq
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Long saveModel(ModelTrainMngDto.AddReq addReq) {
|
public ModelTrainMngDto.Basic saveModel(ModelTrainMngDto.AddReq addReq) {
|
||||||
ModelMasterEntity entity = new ModelMasterEntity();
|
ModelMasterEntity entity = new ModelMasterEntity();
|
||||||
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
||||||
|
|
||||||
@@ -117,7 +117,12 @@ public class ModelTrainMngCoreService {
|
|||||||
|
|
||||||
entity.setCreatedUid(userUtil.getId());
|
entity.setCreatedUid(userUtil.getId());
|
||||||
ModelMasterEntity resultEntity = modelMngRepository.save(entity);
|
ModelMasterEntity resultEntity = modelMngRepository.save(entity);
|
||||||
return resultEntity.getId();
|
|
||||||
|
ModelTrainMngDto.Basic result = new ModelTrainMngDto.Basic();
|
||||||
|
result.setId(resultEntity.getId());
|
||||||
|
result.setUuid(resultEntity.getUuid());
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -149,6 +154,26 @@ public class ModelTrainMngCoreService {
|
|||||||
modelDatasetRepository.save(datasetEntity);
|
modelDatasetRepository.save(datasetEntity);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습모델 수정
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @param req
|
||||||
|
*/
|
||||||
|
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
|
||||||
|
ModelMasterEntity entity =
|
||||||
|
modelMngRepository
|
||||||
|
.findById(modelId)
|
||||||
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
|
||||||
|
entity.setRequestPath(req.getRequestPath());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.getResponsePath() != null && !req.getResponsePath().isEmpty()) {
|
||||||
|
entity.setRequestPath(req.getResponsePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 모델 데이터셋 mapping 테이블 저장
|
* 모델 데이터셋 mapping 테이블 저장
|
||||||
*
|
*
|
||||||
@@ -467,4 +492,14 @@ public class ModelTrainMngCoreService {
|
|||||||
|
|
||||||
entity.setBestEpoch(epoch);
|
entity.setBestEpoch(epoch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 데이터셋 uid 조회
|
||||||
|
*
|
||||||
|
* @param datasetIds
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public List<String> findDatasetUid(List<Long> datasetIds) {
|
||||||
|
return datasetRepository.findDatasetUid(datasetIds);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,6 +106,12 @@ public class ModelMasterEntity {
|
|||||||
@Column(name = "best_epoch")
|
@Column(name = "best_epoch")
|
||||||
private Integer bestEpoch;
|
private Integer bestEpoch;
|
||||||
|
|
||||||
|
@Column(name = "request_path")
|
||||||
|
private String requestPath;
|
||||||
|
|
||||||
|
@Column(name = "response_path")
|
||||||
|
private String responsePath;
|
||||||
|
|
||||||
public ModelTrainMngDto.Basic toDto() {
|
public ModelTrainMngDto.Basic toDto() {
|
||||||
return new ModelTrainMngDto.Basic(
|
return new ModelTrainMngDto.Basic(
|
||||||
this.id,
|
this.id,
|
||||||
|
|||||||
@@ -22,4 +22,6 @@ public interface DatasetRepositoryCustom {
|
|||||||
Long getDatasetMaxStage(int compareYyyy, int targetYyyy);
|
Long getDatasetMaxStage(int compareYyyy, int targetYyyy);
|
||||||
|
|
||||||
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
|
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
|
||||||
|
|
||||||
|
List<String> findDatasetUid(List<Long> datasetIds);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -242,4 +242,9 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
|
|||||||
.where(dataset.uid.eq(mngRegDto.getUid()))
|
.where(dataset.uid.eq(mngRegDto.getUid()))
|
||||||
.fetchOne();
|
.fetchOne();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> findDatasetUid(List<Long> datasetIds) {
|
||||||
|
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
|||||||
.select(
|
.select(
|
||||||
Projections.constructor(
|
Projections.constructor(
|
||||||
TrainRunRequest.class,
|
TrainRunRequest.class,
|
||||||
modelMasterEntity.uuid, // datasetFolder
|
modelMasterEntity.requestPath, // datasetFolder
|
||||||
modelMasterEntity.uuid, // outputFolder
|
modelMasterEntity.uuid, // outputFolder
|
||||||
modelHyperParamEntity.inputSize,
|
modelHyperParamEntity.inputSize,
|
||||||
modelHyperParamEntity.cropSize,
|
modelHyperParamEntity.cropSize,
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.train;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public interface ModelTestMetricsJobRepositoryCustom {
|
public interface ModelTestMetricsJobRepositoryCustom {
|
||||||
|
|
||||||
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
||||||
|
|
||||||
List<Long> getTestMetricSaveNotYetModelIds();
|
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
|
||||||
|
|
||||||
void insertModelMetricsTest(List<Object[]> batchArgs);
|
void insertModelMetricsTest(List<Object[]> batchArgs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
|
|||||||
|
|
||||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||||
|
import com.querydsl.core.types.Projections;
|
||||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
|
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
|
||||||
@@ -36,9 +38,11 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Long> getTestMetricSaveNotYetModelIds() {
|
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
|
||||||
return queryFactory
|
return queryFactory
|
||||||
.select(modelMasterEntity.id)
|
.select(
|
||||||
|
Projections.constructor(
|
||||||
|
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||||
.from(modelMasterEntity)
|
.from(modelMasterEntity)
|
||||||
.where(
|
.where(
|
||||||
modelMasterEntity.step2EndDttm.isNotNull(),
|
modelMasterEntity.step2EndDttm.isNotNull(),
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.train;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public interface ModelTrainMetricsJobRepositoryCustom {
|
public interface ModelTrainMetricsJobRepositoryCustom {
|
||||||
|
|
||||||
List<Long> getTrainMetricSaveNotYetModelIds();
|
List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
|
||||||
|
|
||||||
void insertModelMetricsTrain(List<Object[]> batchArgs);
|
void insertModelMetricsTrain(List<Object[]> batchArgs);
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMast
|
|||||||
|
|
||||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||||
|
import com.querydsl.core.types.Projections;
|
||||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
|
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
|
||||||
@@ -23,9 +25,11 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Long> getTrainMetricSaveNotYetModelIds() {
|
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
|
||||||
return queryFactory
|
return queryFactory
|
||||||
.select(modelMasterEntity.id)
|
.select(
|
||||||
|
Projections.constructor(
|
||||||
|
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||||
.from(modelMasterEntity)
|
.from(modelMasterEntity)
|
||||||
.where(
|
.where(
|
||||||
modelMasterEntity.step1EndDttm.isNotNull(),
|
modelMasterEntity.step1EndDttm.isNotNull(),
|
||||||
@@ -41,7 +45,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
|||||||
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
||||||
String sql =
|
String sql =
|
||||||
"""
|
"""
|
||||||
insert into tb_model_matrics_train
|
insert into tb_model_metrics_train
|
||||||
(model_id, epoch, iteration, loss, lr, duration_time)
|
(model_id, epoch, iteration, loss, lr, duration_time)
|
||||||
values (?, ?, ?, ?, ?, ?)
|
values (?, ?, ?, ?, ?, ?)
|
||||||
""";
|
""";
|
||||||
@@ -66,7 +70,7 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
|||||||
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
||||||
String sql =
|
String sql =
|
||||||
"""
|
"""
|
||||||
insert into tb_model_matrics_validation
|
insert into tb_model_metrics_validation
|
||||||
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
|
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
|
||||||
unchanged_fscore, unchanged_precision, unchanged_recall
|
unchanged_fscore, unchanged_precision, unchanged_recall
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package com.kamco.cd.training.train.dto;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
public class ModelTrainMetricsDto {
|
||||||
|
|
||||||
|
@Schema(name = "ResponsePathDto", description = "AI 결과 저장된 path 경로 정보")
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public static class ResponsePathDto {
|
||||||
|
|
||||||
|
private Long modelId;
|
||||||
|
private String responsePath;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -221,7 +221,7 @@ public class DockerTrainService {
|
|||||||
c.add("/workspace/change-detection-code/train_wrapper.py");
|
c.add("/workspace/change-detection-code/train_wrapper.py");
|
||||||
|
|
||||||
// ===== 기본 파라미터 =====
|
// ===== 기본 파라미터 =====
|
||||||
addArg(c, "--dataset-folder", "4BDBBDF99D04477A927CC9EBA760B845" /*req.getDatasetFolder()*/);
|
addArg(c, "--dataset-folder", req.getDatasetFolder());
|
||||||
addArg(c, "--output-folder", req.getOutputFolder());
|
addArg(c, "--output-folder", req.getOutputFolder());
|
||||||
addArg(c, "--input-size", req.getInputSize());
|
addArg(c, "--input-size", req.getInputSize());
|
||||||
addArg(c, "--crop-size", req.getCropSize());
|
addArg(c, "--crop-size", req.getCropSize());
|
||||||
@@ -281,8 +281,8 @@ public class DockerTrainService {
|
|||||||
if (value == null) return;
|
if (value == null) return;
|
||||||
String s = String.valueOf(value).trim();
|
String s = String.valueOf(value).trim();
|
||||||
if (s.isEmpty()) return;
|
if (s.isEmpty()) return;
|
||||||
c.add(key);
|
|
||||||
c.add(s);
|
c.add(key + "=" + s);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 컨테이너 강제 종료 및 제거 */
|
/** 컨테이너 강제 종료 및 제거 */
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.kamco.cd.training.train.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
@@ -35,66 +36,67 @@ public class ModelTestMetricsJobService {
|
|||||||
return "local".equalsIgnoreCase(profile);
|
return "local".equalsIgnoreCase(profile);
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Scheduled(cron = "0 0/10 * * * *")
|
// @Scheduled(cron = "0 * * * * *")
|
||||||
public void findTestValidMetricCsvFiles() {
|
public void findTestValidMetricCsvFiles() {
|
||||||
if (isLocalProfile()) {
|
// if (isLocalProfile()) {
|
||||||
return;
|
// return;
|
||||||
}
|
// }
|
||||||
|
|
||||||
List<Long> modelIds =
|
List<ResponsePathDto> modelIds =
|
||||||
modelTestMetricsJobCoreService
|
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
||||||
.getTestMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
|
|
||||||
|
|
||||||
if (modelIds.isEmpty()) {
|
if (modelIds.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
String localPath = "C:\\data\\upload\\test.csv";
|
for (ResponsePathDto modelInfo : modelIds) {
|
||||||
try (BufferedReader reader =
|
|
||||||
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
|
|
||||||
|
|
||||||
log.info("### localPath={}", localPath);
|
String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
|
||||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
try (BufferedReader reader =
|
||||||
|
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
|
||||||
|
|
||||||
List<Object[]> batchArgs = new ArrayList<>();
|
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||||
|
|
||||||
for (CSVRecord record : parser) {
|
List<Object[]> batchArgs = new ArrayList<>();
|
||||||
|
|
||||||
String model = record.get("model");
|
for (CSVRecord record : parser) {
|
||||||
long TP = Long.parseLong(record.get("TP"));
|
|
||||||
long FP = Long.parseLong(record.get("FP"));
|
|
||||||
long FN = Long.parseLong(record.get("FN"));
|
|
||||||
float precision = Float.parseFloat(record.get("precision"));
|
|
||||||
float recall = Float.parseFloat(record.get("recall"));
|
|
||||||
float f1_score = Float.parseFloat(record.get("f1_score"));
|
|
||||||
float accuracy = Float.parseFloat(record.get("accuracy"));
|
|
||||||
float iou = Float.parseFloat(record.get("iou"));
|
|
||||||
long detection_count = Long.parseLong(record.get("detection_count"));
|
|
||||||
long gt_count = Long.parseLong(record.get("gt_count"));
|
|
||||||
|
|
||||||
batchArgs.add(
|
String model = record.get("model");
|
||||||
new Object[] {
|
long TP = Long.parseLong(record.get("TP"));
|
||||||
modelIds.getFirst(),
|
long FP = Long.parseLong(record.get("FP"));
|
||||||
model,
|
long FN = Long.parseLong(record.get("FN"));
|
||||||
TP,
|
float precision = Float.parseFloat(record.get("precision"));
|
||||||
FP,
|
float recall = Float.parseFloat(record.get("recall"));
|
||||||
FN,
|
float f1_score = Float.parseFloat(record.get("f1_score"));
|
||||||
precision,
|
float accuracy = Float.parseFloat(record.get("accuracy"));
|
||||||
recall,
|
float iou = Float.parseFloat(record.get("iou"));
|
||||||
f1_score,
|
long detection_count = Long.parseLong(record.get("detection_count"));
|
||||||
accuracy,
|
long gt_count = Long.parseLong(record.get("gt_count"));
|
||||||
iou,
|
|
||||||
detection_count,
|
batchArgs.add(
|
||||||
gt_count
|
new Object[] {
|
||||||
});
|
modelInfo.getModelId(),
|
||||||
|
model,
|
||||||
|
TP,
|
||||||
|
FP,
|
||||||
|
FN,
|
||||||
|
precision,
|
||||||
|
recall,
|
||||||
|
f1_score,
|
||||||
|
accuracy,
|
||||||
|
iou,
|
||||||
|
detection_count,
|
||||||
|
gt_count
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
|
||||||
|
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
|
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step2");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.kamco.cd.training.train.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
@@ -35,96 +36,97 @@ public class ModelTrainMetricsJobService {
|
|||||||
return "local".equalsIgnoreCase(profile);
|
return "local".equalsIgnoreCase(profile);
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Scheduled(cron = "0 0/10 * * * *")
|
// @Scheduled(cron = "0 * * * * *")
|
||||||
public void findTrainValidMetricCsvFiles() {
|
public void findTrainValidMetricCsvFiles() {
|
||||||
if (isLocalProfile()) {
|
// if (isLocalProfile()) {
|
||||||
return;
|
// return;
|
||||||
}
|
// }
|
||||||
|
|
||||||
List<Long> modelIds =
|
List<ResponsePathDto> modelIds =
|
||||||
modelTrainMetricsJobCoreService
|
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
||||||
.getTrainMetricSaveNotYetModelIds(); // TODO: uid, uuid ? 가져오기로 해야함
|
|
||||||
|
|
||||||
if (modelIds.isEmpty()) {
|
if (modelIds.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
String localPath = "C:\\data\\upload\\train.csv";
|
for (ResponsePathDto modelInfo : modelIds) {
|
||||||
try (BufferedReader reader =
|
|
||||||
Files.newBufferedReader(Paths.get(localPath), StandardCharsets.UTF_8); ) {
|
|
||||||
|
|
||||||
log.info("### localPath={}", localPath);
|
String trainPath = modelInfo.getResponsePath() + "/metrics/train.csv";
|
||||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
try (BufferedReader reader =
|
||||||
|
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
|
||||||
|
|
||||||
List<Object[]> batchArgs = new ArrayList<>();
|
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||||
|
|
||||||
for (CSVRecord record : parser) {
|
List<Object[]> batchArgs = new ArrayList<>();
|
||||||
|
|
||||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
for (CSVRecord record : parser) {
|
||||||
long iteration = Long.parseLong(record.get("Iteration"));
|
|
||||||
double Loss = Double.parseDouble(record.get("Loss"));
|
|
||||||
double LR = Double.parseDouble(record.get("LR"));
|
|
||||||
float time = Float.parseFloat(record.get("Time"));
|
|
||||||
|
|
||||||
batchArgs.add(new Object[] {modelIds.getFirst(), epoch, iteration, Loss, LR, time});
|
int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기
|
||||||
|
long iteration = Long.parseLong(record.get("Iteration"));
|
||||||
|
double Loss = Double.parseDouble(record.get("Loss"));
|
||||||
|
double LR = Double.parseDouble(record.get("LR"));
|
||||||
|
float time = Float.parseFloat(record.get("Time"));
|
||||||
|
|
||||||
|
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
|
||||||
|
}
|
||||||
|
|
||||||
|
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
|
||||||
|
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
|
String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
|
||||||
|
try (BufferedReader reader =
|
||||||
|
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
||||||
|
|
||||||
} catch (IOException e) {
|
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
String validationPath = "C:\\data\\upload\\val.csv";
|
List<Object[]> batchArgs = new ArrayList<>();
|
||||||
try (BufferedReader reader =
|
|
||||||
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
|
||||||
|
|
||||||
log.info("### validationPath={}", validationPath);
|
for (CSVRecord record : parser) {
|
||||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
|
||||||
|
|
||||||
List<Object[]> batchArgs = new ArrayList<>();
|
int epoch = Integer.parseInt(record.get("Epoch"));
|
||||||
|
float aAcc = Float.parseFloat(record.get("aAcc"));
|
||||||
|
float mFscore = Float.parseFloat(record.get("mFscore"));
|
||||||
|
float mPrecision = Float.parseFloat(record.get("mPrecision"));
|
||||||
|
float mRecall = Float.parseFloat(record.get("mRecall"));
|
||||||
|
float mIoU = Float.parseFloat(record.get("mIoU"));
|
||||||
|
float mAcc = Float.parseFloat(record.get("mAcc"));
|
||||||
|
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
|
||||||
|
float changed_precision = Float.parseFloat(record.get("changed_precision"));
|
||||||
|
float changed_recall = Float.parseFloat(record.get("changed_recall"));
|
||||||
|
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
|
||||||
|
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
|
||||||
|
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
|
||||||
|
|
||||||
for (CSVRecord record : parser) {
|
batchArgs.add(
|
||||||
|
new Object[] {
|
||||||
|
modelInfo.getModelId(),
|
||||||
|
epoch,
|
||||||
|
aAcc,
|
||||||
|
mFscore,
|
||||||
|
mPrecision,
|
||||||
|
mRecall,
|
||||||
|
mIoU,
|
||||||
|
mAcc,
|
||||||
|
changed_fscore,
|
||||||
|
changed_precision,
|
||||||
|
changed_recall,
|
||||||
|
unchanged_fscore,
|
||||||
|
unchanged_precision,
|
||||||
|
unchanged_recall
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
|
||||||
float aAcc = Float.parseFloat(record.get("aAcc"));
|
|
||||||
float mFscore = Float.parseFloat(record.get("mFscore"));
|
|
||||||
float mPrecision = Float.parseFloat(record.get("mPrecision"));
|
|
||||||
float mRecall = Float.parseFloat(record.get("mRecall"));
|
|
||||||
float mIoU = Float.parseFloat(record.get("mIoU"));
|
|
||||||
float mAcc = Float.parseFloat(record.get("mAcc"));
|
|
||||||
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
|
|
||||||
float changed_precision = Float.parseFloat(record.get("changed_precision"));
|
|
||||||
float changed_recall = Float.parseFloat(record.get("changed_recall"));
|
|
||||||
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
|
|
||||||
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
|
|
||||||
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
|
|
||||||
|
|
||||||
batchArgs.add(
|
} catch (IOException e) {
|
||||||
new Object[] {
|
throw new RuntimeException(e);
|
||||||
modelIds.getFirst(),
|
|
||||||
epoch,
|
|
||||||
aAcc,
|
|
||||||
mFscore,
|
|
||||||
mPrecision,
|
|
||||||
mRecall,
|
|
||||||
mIoU,
|
|
||||||
mAcc,
|
|
||||||
changed_fscore,
|
|
||||||
changed_precision,
|
|
||||||
changed_recall,
|
|
||||||
unchanged_fscore,
|
|
||||||
unchanged_precision,
|
|
||||||
unchanged_recall
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
|
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
|
||||||
|
modelInfo.getModelId(), "step1");
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelIds.getFirst(), "step1");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user