diff --git a/src/main/java/com/kamco/cd/training/model/service/TmpDatasetService.java b/src/main/java/com/kamco/cd/training/model/service/TmpDatasetService.java index 918bc77..3726f22 100644 --- a/src/main/java/com/kamco/cd/training/model/service/TmpDatasetService.java +++ b/src/main/java/com/kamco/cd/training/model/service/TmpDatasetService.java @@ -5,7 +5,6 @@ import java.nio.file.*; import java.util.List; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @Slf4j @@ -13,8 +12,8 @@ import org.springframework.stereotype.Service; @RequiredArgsConstructor public class TmpDatasetService { - @Value("${train.docker.requestDir}") - private String requestDir; + // @Value("${train.docker.requestDir}") + private String requestDir = "/home/kcomu/data"; public String buildTmpDatasetSymlink(String uid, List datasetUids) throws IOException { diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java index 4a2ce5d..a20cfc9 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java @@ -5,6 +5,7 @@ import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository; import com.kamco.cd.training.train.dto.ModelTrainJobDto; import java.time.ZonedDateTime; import java.util.Map; +import java.util.Objects; import java.util.Optional; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; @@ -104,4 +105,16 @@ public class ModelTrainJobCoreService { job.setStatusCd("STOPPED"); job.setFinishedDttm(ZonedDateTime.now()); } + + @Transactional + public void updateEpoch(String containerName, Integer epoch) { + ModelTrainJobEntity job = + modelTrainJobRepository + .findByContainerName(containerName) + .orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName)); + + job.setCurrentEpoch(epoch); + + if (Objects.equals(job.getTotalEpoch(), epoch)) {} + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java index ee79523..580c989 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java @@ -7,4 +7,6 @@ public interface ModelTrainJobRepositoryCustom { int findMaxAttemptNo(Long modelId); Optional findLatestByModelId(Long modelId); + + Optional findByContainerName(String containerName); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java index cf74017..f7742ee 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java @@ -40,4 +40,18 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto return Optional.ofNullable(job); } + + @Override + public Optional findByContainerName(String containerName) { + QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity; + + ModelTrainJobEntity job = + queryFactory + .selectFrom(j) + .where(j.containerName.eq(containerName)) + .orderBy(j.id.desc()) + .fetchFirst(); + + return Optional.ofNullable(job); + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index f56fcea..e99daf5 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -1,5 +1,6 @@ package com.kamco.cd.training.train.service; +import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.train.dto.EvalRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunResult; @@ -9,14 +10,17 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Matcher; import java.util.regex.Pattern; +import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @Log4j2 @Service +@RequiredArgsConstructor public class DockerTrainService { // 실행할 Docker 이미지명 @@ -43,6 +47,8 @@ public class DockerTrainService { @Value("${train.docker.ipcHost:true}") private boolean ipcHost; + private final ModelTrainJobCoreService modelTrainJobCoreService; + /** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */ public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception { @@ -61,7 +67,11 @@ public class DockerTrainService { // 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게) StringBuilder logBuilder = new StringBuilder(); - Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b"); + Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]"); + + // 너무 잦은 업데이트 방지용 + AtomicInteger lastEpoch = new AtomicInteger(0); + AtomicInteger lastIter = new AtomicInteger(0); Thread logThread = new Thread( @@ -73,23 +83,40 @@ public class DockerTrainService { String line; while ((line = br.readLine()) != null) { - // 1) 로그 누적 synchronized (logBuilder) { logBuilder.append(line).append('\n'); } - // 2) epoch 감지 + DB 업데이트 Matcher m = epochPattern.matcher(line); if (m.find()) { - int currentEpoch = Integer.parseInt(m.group(1)); - int totalEpoch = Integer.parseInt(m.group(2)); + int epoch = Integer.parseInt(m.group(1)); + int iter = Integer.parseInt(m.group(2)); + int totalIter = Integer.parseInt(m.group(3)); - log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch); + // (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용 + int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0; - // TODO 실행중인 에폭 저장 필요하면 만들어야함 - // TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..? - // modelTrainMngCoreService.updateCurrentEpoch(modelId, - // currentEpoch, totalEpoch); + // 쓰로틀링: 에폭 끝 or 10 iter마다 + boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0); + + // 중복 방지 + if (shouldUpdate) { + int prevEpoch = lastEpoch.get(); + int prevIter = lastIter.get(); + if (epoch != prevEpoch || iter != prevIter) { + lastEpoch.set(epoch); + lastIter.set(iter); + + log.info( + "[TRAIN] container={} epoch={} iter={}/{}", + containerName, + epoch, + iter, + totalIter); + + modelTrainJobCoreService.updateEpoch(containerName, epoch); + } + } } } } catch (Exception e) { @@ -97,21 +124,6 @@ public class DockerTrainService { } }, "train-log-" + containerName); - // new Thread( - // () -> { - // try (BufferedReader br = - // new BufferedReader( - // new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { - // String line; - // while ((line = br.readLine()) != null) { - // synchronized (log) { - // log.append(line).append('\n'); - // } - // } - // } catch (Exception ignored) { - // } - // }, - // "train-log-" + containerName); logThread.setDaemon(true); logThread.start(); @@ -206,7 +218,7 @@ public class DockerTrainService { // 요청/결과 디렉토리 볼륨 마운트 c.add("-v"); - c.add(requestDir + "/tmp:/data"); + c.add("/home/kcomu/data" + "/tmp:/data"); c.add("-v"); c.add(responseDir + ":/checkpoints");