Merge pull request '임시파일생성 경로 수정' (#66) from feat/training_260202 into develop
Reviewed-on: #66
This commit was merged in pull request #66.
This commit is contained in:
@@ -5,7 +5,6 @@ import java.nio.file.*;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@@ -13,8 +12,8 @@ import org.springframework.stereotype.Service;
|
||||
@RequiredArgsConstructor
|
||||
public class TmpDatasetService {
|
||||
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
// @Value("${train.docker.requestDir}")
|
||||
private String requestDir = "/home/kcomu/data";
|
||||
|
||||
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -104,4 +105,16 @@ public class ModelTrainJobCoreService {
|
||||
job.setStatusCd("STOPPED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void updateEpoch(String containerName, Integer epoch) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findByContainerName(containerName)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
|
||||
|
||||
job.setCurrentEpoch(epoch);
|
||||
|
||||
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,4 +7,6 @@ public interface ModelTrainJobRepositoryCustom {
|
||||
int findMaxAttemptNo(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
|
||||
}
|
||||
|
||||
@@ -40,4 +40,18 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
queryFactory
|
||||
.selectFrom(j)
|
||||
.where(j.containerName.eq(containerName))
|
||||
.orderBy(j.id.desc())
|
||||
.fetchFirst();
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
@@ -9,14 +10,17 @@ import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Log4j2
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class DockerTrainService {
|
||||
|
||||
// 실행할 Docker 이미지명
|
||||
@@ -43,6 +47,8 @@ public class DockerTrainService {
|
||||
@Value("${train.docker.ipcHost:true}")
|
||||
private boolean ipcHost;
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||
|
||||
@@ -61,7 +67,11 @@ public class DockerTrainService {
|
||||
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||
StringBuilder logBuilder = new StringBuilder();
|
||||
|
||||
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
|
||||
Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]");
|
||||
|
||||
// 너무 잦은 업데이트 방지용
|
||||
AtomicInteger lastEpoch = new AtomicInteger(0);
|
||||
AtomicInteger lastIter = new AtomicInteger(0);
|
||||
|
||||
Thread logThread =
|
||||
new Thread(
|
||||
@@ -73,23 +83,40 @@ public class DockerTrainService {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
|
||||
// 1) 로그 누적
|
||||
synchronized (logBuilder) {
|
||||
logBuilder.append(line).append('\n');
|
||||
}
|
||||
|
||||
// 2) epoch 감지 + DB 업데이트
|
||||
Matcher m = epochPattern.matcher(line);
|
||||
if (m.find()) {
|
||||
int currentEpoch = Integer.parseInt(m.group(1));
|
||||
int totalEpoch = Integer.parseInt(m.group(2));
|
||||
int epoch = Integer.parseInt(m.group(1));
|
||||
int iter = Integer.parseInt(m.group(2));
|
||||
int totalIter = Integer.parseInt(m.group(3));
|
||||
|
||||
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
|
||||
// (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용
|
||||
int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0;
|
||||
|
||||
// TODO 실행중인 에폭 저장 필요하면 만들어야함
|
||||
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
|
||||
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
|
||||
// currentEpoch, totalEpoch);
|
||||
// 쓰로틀링: 에폭 끝 or 10 iter마다
|
||||
boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0);
|
||||
|
||||
// 중복 방지
|
||||
if (shouldUpdate) {
|
||||
int prevEpoch = lastEpoch.get();
|
||||
int prevIter = lastIter.get();
|
||||
if (epoch != prevEpoch || iter != prevIter) {
|
||||
lastEpoch.set(epoch);
|
||||
lastIter.set(iter);
|
||||
|
||||
log.info(
|
||||
"[TRAIN] container={} epoch={} iter={}/{}",
|
||||
containerName,
|
||||
epoch,
|
||||
iter,
|
||||
totalIter);
|
||||
|
||||
modelTrainJobCoreService.updateEpoch(containerName, epoch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
@@ -97,21 +124,6 @@ public class DockerTrainService {
|
||||
}
|
||||
},
|
||||
"train-log-" + containerName);
|
||||
// new Thread(
|
||||
// () -> {
|
||||
// try (BufferedReader br =
|
||||
// new BufferedReader(
|
||||
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
// String line;
|
||||
// while ((line = br.readLine()) != null) {
|
||||
// synchronized (log) {
|
||||
// log.append(line).append('\n');
|
||||
// }
|
||||
// }
|
||||
// } catch (Exception ignored) {
|
||||
// }
|
||||
// },
|
||||
// "train-log-" + containerName);
|
||||
|
||||
logThread.setDaemon(true);
|
||||
logThread.start();
|
||||
@@ -206,7 +218,7 @@ public class DockerTrainService {
|
||||
|
||||
// 요청/결과 디렉토리 볼륨 마운트
|
||||
c.add("-v");
|
||||
c.add(requestDir + "/tmp:/data");
|
||||
c.add("/home/kcomu/data" + "/tmp:/data");
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user