Compare commits
14 Commits
feat/dean/
...
c92426aefc
| Author | SHA1 | Date | |
|---|---|---|---|
| c92426aefc | |||
| 6185a18a7c | |||
| 1fb10830b9 | |||
| ae0d30e5da | |||
| ed48f697a4 | |||
| 6c865d26fd | |||
| 16e156b5b4 | |||
| 6a939118ff | |||
| 0c0ae16c2b | |||
| 953f95aed6 | |||
| 85633c8bab | |||
| 8b3940b446 | |||
| 201cfefb6b | |||
| 9958b0999a |
@@ -217,7 +217,7 @@ public class DatasetApiController {
|
||||
public ApiResponseDto<ApiResponseDto.ResponseObj> insertDataset(
|
||||
@RequestBody @Valid DatasetDto.AddReq addReq) {
|
||||
|
||||
return ApiResponseDto.ok(datasetService.insertDataset(addReq));
|
||||
return ApiResponseDto.okObject(datasetService.insertDataset(addReq));
|
||||
}
|
||||
|
||||
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")
|
||||
|
||||
@@ -208,6 +208,13 @@ public class DatasetService {
|
||||
Long datasetUid = null; // master id 값, 등록하면서 가져올 예정
|
||||
|
||||
try {
|
||||
// 같은 uid 로 등록한 파일이 있는지 확인
|
||||
Long existsCnt =
|
||||
datasetCoreService.findDatasetByUidExistsCnt(addReq.getFileName().replace(".zip", ""));
|
||||
if (existsCnt > 0) {
|
||||
return new ResponseObj(ApiResponseCode.DUPLICATE_DATA, "이미 등록된 회차 데이터 파일입니다. 확인 부탁드립니다.");
|
||||
}
|
||||
|
||||
// 압축 해제
|
||||
FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath());
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ public class ModelTrainMngDto {
|
||||
}
|
||||
}
|
||||
|
||||
public String getStep2StatusNAme() {
|
||||
public String getStep2StatusName() {
|
||||
if (this.step2Status == null || this.step2Status.isBlank()) return null;
|
||||
try {
|
||||
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()
|
||||
|
||||
@@ -13,7 +13,6 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -102,9 +101,9 @@ public class ModelTrainMngService {
|
||||
|
||||
try {
|
||||
// 데이터셋 심볼링크 생성
|
||||
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
String tmpUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||
updateReq.setRequestPath(path.toString());
|
||||
updateReq.setRequestPath(tmpUid);
|
||||
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
||||
@@ -5,62 +5,139 @@ import java.nio.file.*;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class TmpDatasetService {
|
||||
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
// @Value("${train.docker.requestDir}")
|
||||
private String requestDir = "/home/kcomu/data";
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public Path buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
||||
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
||||
|
||||
// 환경에 맞게 yml로 빼는 걸 추천
|
||||
Path BASE = Paths.get(requestDir);
|
||||
log.info("========== buildTmpDatasetHardlink START ==========");
|
||||
log.info("uid={}", uid);
|
||||
log.info("datasetUids={}", datasetUids);
|
||||
log.info("requestDir(raw)={}", requestDir);
|
||||
|
||||
Path BASE = toPath(requestDir);
|
||||
Path tmp = BASE.resolve("tmp").resolve(uid);
|
||||
|
||||
// mkdir -p "$TMP"/train/{input1,input2,label} ...
|
||||
log.info("BASE={}", BASE);
|
||||
log.info("BASE exists? {}", Files.isDirectory(BASE));
|
||||
log.info("tmp={}", tmp);
|
||||
|
||||
long noDir = 0, scannedDirs = 0, regularFiles = 0, hardlinksMade = 0;
|
||||
|
||||
// tmp 디렉토리 준비
|
||||
for (String type : List.of("train", "val")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
Files.createDirectories(tmp.resolve(type).resolve(part));
|
||||
Path dir = tmp.resolve(type).resolve(part);
|
||||
Files.createDirectories(dir);
|
||||
log.info("createDirectories: {}", dir);
|
||||
}
|
||||
}
|
||||
|
||||
// 하드링크는 "같은 파일시스템"에서만 가능하므로 BASE/tmp가 같은 FS인지 미리 확인(권장)
|
||||
try {
|
||||
var baseStore = Files.getFileStore(BASE);
|
||||
var tmpStore = Files.getFileStore(tmp.getParent()); // BASE/tmp
|
||||
if (!baseStore.name().equals(tmpStore.name()) || !baseStore.type().equals(tmpStore.type())) {
|
||||
throw new IOException(
|
||||
"Hardlink requires same filesystem. baseStore="
|
||||
+ baseStore.name()
|
||||
+ "("
|
||||
+ baseStore.type()
|
||||
+ "), tmpStore="
|
||||
+ tmpStore.name()
|
||||
+ "("
|
||||
+ tmpStore.type()
|
||||
+ ")");
|
||||
}
|
||||
} catch (Exception e) {
|
||||
// FileStore 비교가 환경마다 애매할 수 있어서, 여기서는 경고만 주고 실제 createLink에서 최종 판단하게 둘 수도 있음.
|
||||
log.warn("FileStore check skipped/failed (will rely on createLink): {}", e.toString());
|
||||
}
|
||||
|
||||
for (String id : datasetUids) {
|
||||
Path srcRoot = BASE.resolve(id);
|
||||
log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot));
|
||||
|
||||
for (String type : List.of("train", "val")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
|
||||
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||
|
||||
// zsh NULL_GLOB: 폴더가 없으면 그냥 continue
|
||||
if (!Files.isDirectory(srcDir)) continue;
|
||||
if (!Files.isDirectory(srcDir)) {
|
||||
log.warn("SKIP (not directory): {}", srcDir);
|
||||
noDir++;
|
||||
continue;
|
||||
}
|
||||
|
||||
scannedDirs++;
|
||||
log.info("SCAN dir={}", srcDir);
|
||||
|
||||
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
|
||||
for (Path f : stream) {
|
||||
if (!Files.isRegularFile(f)) continue;
|
||||
if (!Files.isRegularFile(f)) {
|
||||
log.debug("skip non-regular file: {}", f);
|
||||
continue;
|
||||
}
|
||||
|
||||
regularFiles++;
|
||||
|
||||
String dstName = id + "__" + f.getFileName();
|
||||
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
|
||||
|
||||
// 이미 있으면 스킵(원하면 덮어쓰기 로직으로 바꿀 수 있음)
|
||||
if (Files.exists(dst)) continue;
|
||||
// dst가 남아있으면 삭제(심볼릭링크든 파일이든)
|
||||
if (Files.exists(dst) || Files.isSymbolicLink(dst)) {
|
||||
Files.delete(dst);
|
||||
log.debug("deleted existing: {}", dst);
|
||||
}
|
||||
|
||||
// ln -s "$f" "$dst" 와 동일
|
||||
Files.createSymbolicLink(dst, f.toAbsolutePath());
|
||||
try {
|
||||
// 하드링크 생성 (dst가 새 파일로 생기지만 inode는 f와 동일)
|
||||
Files.createLink(dst, f);
|
||||
hardlinksMade++;
|
||||
log.debug("created hardlink: {} => {}", dst, f);
|
||||
} catch (IOException e) {
|
||||
// 여기서 바로 실패시키면 “tmp는 만들었는데 내용은 0개” 같은 상태를 방지할 수 있음
|
||||
log.error("FAILED create hardlink: {} => {}", dst, f, e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hardlinksMade == 0) {
|
||||
throw new IOException(
|
||||
"No hardlinks created. regularFiles="
|
||||
+ regularFiles
|
||||
+ ", scannedDirs="
|
||||
+ scannedDirs
|
||||
+ ", noDir="
|
||||
+ noDir);
|
||||
}
|
||||
|
||||
log.info("tmp dataset created: {}", tmp);
|
||||
return tmp;
|
||||
log.info(
|
||||
"summary: scannedDirs={}, noDir={}, regularFiles={}, hardlinksMade={}",
|
||||
scannedDirs,
|
||||
noDir,
|
||||
regularFiles,
|
||||
hardlinksMade);
|
||||
|
||||
return uid;
|
||||
}
|
||||
|
||||
private static Path toPath(String p) {
|
||||
if (p.startsWith("~/")) {
|
||||
return Paths.get(System.getProperty("user.home")).resolve(p.substring(2)).normalize();
|
||||
}
|
||||
return Paths.get(p).toAbsolutePath().normalize();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -246,4 +246,8 @@ public class DatasetCoreService
|
||||
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
|
||||
datasetObjRepository.insertDatasetValObj(objRegDto);
|
||||
}
|
||||
|
||||
public Long findDatasetByUidExistsCnt(String uid) {
|
||||
return datasetRepository.findDatasetByUidExistsCnt(uid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -104,4 +105,16 @@ public class ModelTrainJobCoreService {
|
||||
job.setStatusCd("STOPPED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void updateEpoch(String containerName, Integer epoch) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findByContainerName(containerName)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
|
||||
|
||||
job.setCurrentEpoch(epoch);
|
||||
|
||||
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,4 +24,6 @@ public interface DatasetRepositoryCustom {
|
||||
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
|
||||
|
||||
List<String> findDatasetUid(List<Long> datasetIds);
|
||||
|
||||
Long findDatasetByUidExistsCnt(String uid);
|
||||
}
|
||||
|
||||
@@ -247,4 +247,13 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
|
||||
public List<String> findDatasetUid(List<Long> datasetIds) {
|
||||
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long findDatasetByUidExistsCnt(String uid) {
|
||||
return queryFactory
|
||||
.select(dataset.id.count())
|
||||
.from(dataset)
|
||||
.where(dataset.uid.eq(uid))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,10 +16,10 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
|
||||
|
||||
@Override
|
||||
public List<ModelDatasetMappEntity> findByModelUid(Long modelId) {
|
||||
queryFactory
|
||||
return queryFactory
|
||||
.select(modelDatasetMappEntity)
|
||||
.from(modelDatasetMappEntity)
|
||||
.where(modelDatasetMappEntity.modelUid.eq(modelId));
|
||||
return List.of();
|
||||
.where(modelDatasetMappEntity.modelUid.eq(modelId))
|
||||
.fetch();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,11 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
BooleanBuilder builder = new BooleanBuilder();
|
||||
|
||||
if (req.getStatus() != null && !req.getStatus().isEmpty()) {
|
||||
builder.and(modelMasterEntity.statusCd.eq(req.getStatus()));
|
||||
builder.and(
|
||||
modelMasterEntity
|
||||
.step1State
|
||||
.eq(req.getStatus())
|
||||
.or(modelMasterEntity.step2State.eq(req.getStatus())));
|
||||
}
|
||||
|
||||
if (req.getModelNo() != null && !req.getModelNo().isEmpty()) {
|
||||
|
||||
@@ -42,7 +42,10 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||
ResponsePathDto.class,
|
||||
modelMasterEntity.id,
|
||||
modelMasterEntity.responsePath,
|
||||
modelMasterEntity.uuid))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity.step2EndDttm.isNotNull(),
|
||||
|
||||
@@ -7,4 +7,6 @@ public interface ModelTrainJobRepositoryCustom {
|
||||
int findMaxAttemptNo(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
|
||||
}
|
||||
|
||||
@@ -40,4 +40,18 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
queryFactory
|
||||
.selectFrom(j)
|
||||
.where(j.containerName.eq(containerName))
|
||||
.orderBy(j.id.desc())
|
||||
.fetchFirst();
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,10 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||
ResponsePathDto.class,
|
||||
modelMasterEntity.id,
|
||||
modelMasterEntity.responsePath,
|
||||
modelMasterEntity.uuid))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity.step1EndDttm.isNotNull(),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import java.util.UUID;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -17,5 +18,6 @@ public class ModelTrainMetricsDto {
|
||||
|
||||
private Long modelId;
|
||||
private String responsePath;
|
||||
private UUID uuid;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
@@ -9,14 +10,17 @@ import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Log4j2
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class DockerTrainService {
|
||||
|
||||
// 실행할 Docker 이미지명
|
||||
@@ -43,6 +47,8 @@ public class DockerTrainService {
|
||||
@Value("${train.docker.ipcHost:true}")
|
||||
private boolean ipcHost;
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||
|
||||
@@ -61,7 +67,11 @@ public class DockerTrainService {
|
||||
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||
StringBuilder logBuilder = new StringBuilder();
|
||||
|
||||
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
|
||||
Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]");
|
||||
|
||||
// 너무 잦은 업데이트 방지용
|
||||
AtomicInteger lastEpoch = new AtomicInteger(0);
|
||||
AtomicInteger lastIter = new AtomicInteger(0);
|
||||
|
||||
Thread logThread =
|
||||
new Thread(
|
||||
@@ -73,23 +83,40 @@ public class DockerTrainService {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
|
||||
// 1) 로그 누적
|
||||
synchronized (logBuilder) {
|
||||
logBuilder.append(line).append('\n');
|
||||
}
|
||||
|
||||
// 2) epoch 감지 + DB 업데이트
|
||||
Matcher m = epochPattern.matcher(line);
|
||||
if (m.find()) {
|
||||
int currentEpoch = Integer.parseInt(m.group(1));
|
||||
int totalEpoch = Integer.parseInt(m.group(2));
|
||||
int epoch = Integer.parseInt(m.group(1));
|
||||
int iter = Integer.parseInt(m.group(2));
|
||||
int totalIter = Integer.parseInt(m.group(3));
|
||||
|
||||
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
|
||||
// (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용
|
||||
int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0;
|
||||
|
||||
// TODO 실행중인 에폭 저장 필요하면 만들어야함
|
||||
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
|
||||
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
|
||||
// currentEpoch, totalEpoch);
|
||||
// 쓰로틀링: 에폭 끝 or 10 iter마다
|
||||
boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0);
|
||||
|
||||
// 중복 방지
|
||||
if (shouldUpdate) {
|
||||
int prevEpoch = lastEpoch.get();
|
||||
int prevIter = lastIter.get();
|
||||
if (epoch != prevEpoch || iter != prevIter) {
|
||||
lastEpoch.set(epoch);
|
||||
lastIter.set(iter);
|
||||
|
||||
log.info(
|
||||
"[TRAIN] container={} epoch={} iter={}/{}",
|
||||
containerName,
|
||||
epoch,
|
||||
iter,
|
||||
totalIter);
|
||||
|
||||
modelTrainJobCoreService.updateEpoch(containerName, epoch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
@@ -97,21 +124,6 @@ public class DockerTrainService {
|
||||
}
|
||||
},
|
||||
"train-log-" + containerName);
|
||||
// new Thread(
|
||||
// () -> {
|
||||
// try (BufferedReader br =
|
||||
// new BufferedReader(
|
||||
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
// String line;
|
||||
// while ((line = br.readLine()) != null) {
|
||||
// synchronized (log) {
|
||||
// log.append(line).append('\n');
|
||||
// }
|
||||
// }
|
||||
// } catch (Exception ignored) {
|
||||
// }
|
||||
// },
|
||||
// "train-log-" + containerName);
|
||||
|
||||
logThread.setDaemon(true);
|
||||
logThread.start();
|
||||
@@ -169,7 +181,7 @@ public class DockerTrainService {
|
||||
|
||||
// 컨테이너 이름 지정
|
||||
c.add("--name");
|
||||
c.add(containerName + "-" + req.getUuid().substring(0, 8));
|
||||
c.add(containerName);
|
||||
|
||||
// 실행 종료 시 자동 삭제
|
||||
c.add("--rm");
|
||||
@@ -206,7 +218,7 @@ public class DockerTrainService {
|
||||
|
||||
// 요청/결과 디렉토리 볼륨 마운트
|
||||
c.add("-v");
|
||||
c.add(requestDir + ":/data");
|
||||
c.add("/home/kcomu/data" + "/tmp:/data");
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
@@ -264,12 +276,12 @@ public class DockerTrainService {
|
||||
|
||||
// ===== Augmentation =====
|
||||
addArg(c, "--rot-prob", req.getRotProb());
|
||||
addArg(c, "--rot-degree", req.getRotDegree());
|
||||
// addArg(c, "--rot-degree", req.getRotDegree()); // TODO AI 수정되면 주석 해제
|
||||
addArg(c, "--flip-prob", req.getFlipProb());
|
||||
addArg(c, "--exchange-prob", req.getExchangeProb());
|
||||
addArg(c, "--brightness-delta", req.getBrightnessDelta());
|
||||
addArg(c, "--contrast-range", req.getContrastRange());
|
||||
addArg(c, "--saturation-range", req.getSaturationRange());
|
||||
// addArg(c, "--contrast-range", req.getContrastRange()); // TODO AI 수정되면 주석 해제
|
||||
// addArg(c, "--saturation-range", req.getSaturationRange()); // TODO AI 수정되면 주석 해제
|
||||
addArg(c, "--hue-delta", req.getHueDelta());
|
||||
|
||||
addArg(c, "--resume-from", req.getResumeFrom());
|
||||
@@ -377,7 +389,7 @@ public class DockerTrainService {
|
||||
c.add("docker");
|
||||
c.add("run");
|
||||
c.add("--name");
|
||||
c.add(containerName + "=" + req.getUuid().substring(0, 8));
|
||||
c.add(containerName);
|
||||
c.add("--rm");
|
||||
|
||||
c.add("--gpus");
|
||||
|
||||
@@ -27,6 +27,10 @@ public class ModelTestMetricsJobService {
|
||||
@Value("${spring.profiles.active}")
|
||||
private String profile;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
/**
|
||||
* 실행중인 profile
|
||||
*
|
||||
@@ -51,7 +55,7 @@ public class ModelTestMetricsJobService {
|
||||
|
||||
for (ResponsePathDto modelInfo : modelIds) {
|
||||
|
||||
String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
|
||||
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ public class ModelTrainMetricsJobService {
|
||||
|
||||
for (ResponsePathDto modelInfo : modelIds) {
|
||||
|
||||
String trainPath = responseDir + "{uuid}/metrics/train.csv"; // TODO
|
||||
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
@@ -80,7 +80,7 @@ public class ModelTrainMetricsJobService {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
|
||||
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
|
||||
@@ -226,9 +226,9 @@ public class TrainJobService {
|
||||
|
||||
try {
|
||||
// 데이터셋 심볼링크 생성
|
||||
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||
updateReq.setRequestPath(path.toString());
|
||||
updateReq.setRequestPath(pathUid);
|
||||
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
||||
Reference in New Issue
Block a user