임시파일생성 경로 수정 #66

Merged
teddy merged 1 commits from feat/training_260202 into develop 2026-02-12 19:12:38 +09:00
5 changed files with 69 additions and 29 deletions
Showing only changes of commit 1fb10830b9 - Show all commits

View File

@@ -5,7 +5,6 @@ import java.nio.file.*;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@Slf4j @Slf4j
@@ -13,8 +12,8 @@ import org.springframework.stereotype.Service;
@RequiredArgsConstructor @RequiredArgsConstructor
public class TmpDatasetService { public class TmpDatasetService {
@Value("${train.docker.requestDir}") // @Value("${train.docker.requestDir}")
private String requestDir; private String requestDir = "/home/kcomu/data";
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException { public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {

View File

@@ -5,6 +5,7 @@ import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto; import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -104,4 +105,16 @@ public class ModelTrainJobCoreService {
job.setStatusCd("STOPPED"); job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now()); job.setFinishedDttm(ZonedDateTime.now());
} }
@Transactional
public void updateEpoch(String containerName, Integer epoch) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findByContainerName(containerName)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
job.setCurrentEpoch(epoch);
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
}
} }

View File

@@ -7,4 +7,6 @@ public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId); int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId); Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
} }

View File

@@ -40,4 +40,18 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
return Optional.ofNullable(job); return Optional.ofNullable(job);
} }
@Override
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory
.selectFrom(j)
.where(j.containerName.eq(containerName))
.orderBy(j.id.desc())
.fetchFirst();
return Optional.ofNullable(job);
}
} }

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.train.dto.EvalRunRequest; import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult; import com.kamco.cd.training.train.dto.TrainRunResult;
@@ -9,14 +10,17 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2; import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@Log4j2 @Log4j2
@Service @Service
@RequiredArgsConstructor
public class DockerTrainService { public class DockerTrainService {
// 실행할 Docker 이미지명 // 실행할 Docker 이미지명
@@ -43,6 +47,8 @@ public class DockerTrainService {
@Value("${train.docker.ipcHost:true}") @Value("${train.docker.ipcHost:true}")
private boolean ipcHost; private boolean ipcHost;
private final ModelTrainJobCoreService modelTrainJobCoreService;
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */ /** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception { public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
@@ -61,7 +67,11 @@ public class DockerTrainService {
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게) // 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
StringBuilder logBuilder = new StringBuilder(); StringBuilder logBuilder = new StringBuilder();
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b"); Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]");
// 너무 잦은 업데이트 방지용
AtomicInteger lastEpoch = new AtomicInteger(0);
AtomicInteger lastIter = new AtomicInteger(0);
Thread logThread = Thread logThread =
new Thread( new Thread(
@@ -73,23 +83,40 @@ public class DockerTrainService {
String line; String line;
while ((line = br.readLine()) != null) { while ((line = br.readLine()) != null) {
// 1) 로그 누적
synchronized (logBuilder) { synchronized (logBuilder) {
logBuilder.append(line).append('\n'); logBuilder.append(line).append('\n');
} }
// 2) epoch 감지 + DB 업데이트
Matcher m = epochPattern.matcher(line); Matcher m = epochPattern.matcher(line);
if (m.find()) { if (m.find()) {
int currentEpoch = Integer.parseInt(m.group(1)); int epoch = Integer.parseInt(m.group(1));
int totalEpoch = Integer.parseInt(m.group(2)); int iter = Integer.parseInt(m.group(2));
int totalIter = Integer.parseInt(m.group(3));
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch); // (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용
int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0;
// TODO 실행중인 에폭 저장 필요하면 만들어야함 // 쓰로틀링: 에폭 끝 or 10 iter마다
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..? boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0);
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
// currentEpoch, totalEpoch); // 중복 방지
if (shouldUpdate) {
int prevEpoch = lastEpoch.get();
int prevIter = lastIter.get();
if (epoch != prevEpoch || iter != prevIter) {
lastEpoch.set(epoch);
lastIter.set(iter);
log.info(
"[TRAIN] container={} epoch={} iter={}/{}",
containerName,
epoch,
iter,
totalIter);
modelTrainJobCoreService.updateEpoch(containerName, epoch);
}
}
} }
} }
} catch (Exception e) { } catch (Exception e) {
@@ -97,21 +124,6 @@ public class DockerTrainService {
} }
}, },
"train-log-" + containerName); "train-log-" + containerName);
// new Thread(
// () -> {
// try (BufferedReader br =
// new BufferedReader(
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
// String line;
// while ((line = br.readLine()) != null) {
// synchronized (log) {
// log.append(line).append('\n');
// }
// }
// } catch (Exception ignored) {
// }
// },
// "train-log-" + containerName);
logThread.setDaemon(true); logThread.setDaemon(true);
logThread.start(); logThread.start();
@@ -206,7 +218,7 @@ public class DockerTrainService {
// 요청/결과 디렉토리 볼륨 마운트 // 요청/결과 디렉토리 볼륨 마운트
c.add("-v"); c.add("-v");
c.add(requestDir + "/tmp:/data"); c.add("/home/kcomu/data" + "/tmp:/data");
c.add("-v"); c.add("-v");
c.add(responseDir + ":/checkpoints"); c.add(responseDir + ":/checkpoints");