feat/training_260202 #49
@@ -155,6 +155,17 @@ public class ModelTrainMngDto {
|
||||
ModelConfig modelConfig;
|
||||
}
|
||||
|
||||
@Schema(name = "addReq", description = "모델학습 관리 등록 파라미터")
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class UpdateReq {
|
||||
|
||||
private String requestPath;
|
||||
private String responsePath;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public static class TrainingDataset {
|
||||
|
||||
@@ -13,6 +13,8 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.service.TrainJobService;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -31,6 +33,7 @@ public class ModelTrainMngService {
|
||||
private final TrainJobService trainJobService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final HyperParamCoreService hyperParamCoreService;
|
||||
private final TmpDatasetService tmpDatasetService;
|
||||
|
||||
/**
|
||||
* 모델학습 조회
|
||||
@@ -90,6 +93,22 @@ public class ModelTrainMngService {
|
||||
// 모델 config 저장
|
||||
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
||||
|
||||
UUID tmpUuid = UUID.randomUUID();
|
||||
String raw = tmpUuid.toString().replace("-", "");
|
||||
|
||||
List<String> uids =
|
||||
modelTrainMngCoreService.findDatasetUid(req.getTrainingDataset().getDatasetList());
|
||||
|
||||
try {
|
||||
// 데이터셋 심볼링크 생성
|
||||
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||
updateReq.setRequestPath(path.toString());
|
||||
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// 저장 다 끝난 뒤에 job enqueue
|
||||
if (Boolean.TRUE.equals(req.getIsStart())) {
|
||||
trainJobService.enqueue(modelId); // job 저장 + 이벤트 발행(실행은 AFTER_COMMIT)
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package com.kamco.cd.training.model.service;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.*;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class TmpDatasetService {
|
||||
|
||||
@Value("${train.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
// 환경에 맞게 yml로 빼는 걸 추천
|
||||
private final Path BASE = Paths.get(requestDir);
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public Path buildTmpDatasetSymlink(String uid, List<String> uids) throws IOException {
|
||||
Path tmp = BASE.resolve("tmp").resolve(uid);
|
||||
|
||||
// mkdir -p "$TMP"/train/{input1,input2,label} ...
|
||||
for (String type : List.of("train", "val")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
Files.createDirectories(tmp.resolve(type).resolve(part));
|
||||
}
|
||||
}
|
||||
|
||||
for (String id : uids) {
|
||||
Path srcRoot = BASE.resolve(id);
|
||||
|
||||
for (String type : List.of("train", "val")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
|
||||
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||
|
||||
// zsh NULL_GLOB: 폴더가 없으면 그냥 continue
|
||||
if (!Files.isDirectory(srcDir)) continue;
|
||||
|
||||
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
|
||||
for (Path f : stream) {
|
||||
if (!Files.isRegularFile(f)) continue;
|
||||
|
||||
String dstName = id + "__" + f.getFileName();
|
||||
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
|
||||
|
||||
// 이미 있으면 스킵(원하면 덮어쓰기 로직으로 바꿀 수 있음)
|
||||
if (Files.exists(dst)) continue;
|
||||
|
||||
// ln -s "$f" "$dst" 와 동일
|
||||
Files.createSymbolicLink(dst, f.toAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.info("tmp dataset created: {}", tmp);
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
@@ -149,6 +149,26 @@ public class ModelTrainMngCoreService {
|
||||
modelDatasetRepository.save(datasetEntity);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습모델 수정
|
||||
*
|
||||
* @param modelId
|
||||
* @param req
|
||||
*/
|
||||
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
|
||||
entity.setRequestPath(req.getRequestPath());
|
||||
}
|
||||
|
||||
if (req.getResponsePath() != null && !req.getResponsePath().isEmpty()) {
|
||||
entity.setRequestPath(req.getResponsePath());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델 데이터셋 mapping 테이블 저장
|
||||
*
|
||||
@@ -467,4 +487,14 @@ public class ModelTrainMngCoreService {
|
||||
|
||||
entity.setBestEpoch(epoch);
|
||||
}
|
||||
|
||||
/**
|
||||
* 데이터셋 uid 조회
|
||||
*
|
||||
* @param datasetIds
|
||||
* @return
|
||||
*/
|
||||
public List<String> findDatasetUid(List<Long> datasetIds) {
|
||||
return datasetRepository.findDatasetUid(datasetIds);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,4 +22,6 @@ public interface DatasetRepositoryCustom {
|
||||
Long getDatasetMaxStage(int compareYyyy, int targetYyyy);
|
||||
|
||||
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
|
||||
|
||||
List<String> findDatasetUid(List<Long> datasetIds);
|
||||
}
|
||||
|
||||
@@ -242,4 +242,9 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
|
||||
.where(dataset.uid.eq(mngRegDto.getUid()))
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> findDatasetUid(List<Long> datasetIds) {
|
||||
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
.select(
|
||||
Projections.constructor(
|
||||
TrainRunRequest.class,
|
||||
modelMasterEntity.uuid, // datasetFolder
|
||||
modelMasterEntity.requestPath, // datasetFolder
|
||||
modelMasterEntity.uuid, // outputFolder
|
||||
modelHyperParamEntity.inputSize,
|
||||
modelHyperParamEntity.cropSize,
|
||||
|
||||
@@ -221,7 +221,7 @@ public class DockerTrainService {
|
||||
c.add("/workspace/change-detection-code/train_wrapper.py");
|
||||
|
||||
// ===== 기본 파라미터 =====
|
||||
addArg(c, "--dataset-folder", "4BDBBDF99D04477A927CC9EBA760B845" /*req.getDatasetFolder()*/);
|
||||
addArg(c, "--dataset-folder", req.getDatasetFolder());
|
||||
addArg(c, "--output-folder", req.getOutputFolder());
|
||||
addArg(c, "--input-size", req.getInputSize());
|
||||
addArg(c, "--crop-size", req.getCropSize());
|
||||
@@ -281,8 +281,8 @@ public class DockerTrainService {
|
||||
if (value == null) return;
|
||||
String s = String.valueOf(value).trim();
|
||||
if (s.isEmpty()) return;
|
||||
c.add(key);
|
||||
c.add(s);
|
||||
|
||||
c.add(key + "=" + s);
|
||||
}
|
||||
|
||||
/** 컨테이너 강제 종료 및 제거 */
|
||||
|
||||
Reference in New Issue
Block a user