임시파일생성 경로 수정 #66
@@ -5,7 +5,6 @@ import java.nio.file.*;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -13,8 +12,8 @@ import org.springframework.stereotype.Service;
|
|||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class TmpDatasetService {
|
public class TmpDatasetService {
|
||||||
|
|
||||||
@Value("${train.docker.requestDir}")
|
// @Value("${train.docker.requestDir}")
|
||||||
private String requestDir;
|
private String requestDir = "/home/kcomu/data";
|
||||||
|
|
||||||
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
|||||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -104,4 +105,16 @@ public class ModelTrainJobCoreService {
|
|||||||
job.setStatusCd("STOPPED");
|
job.setStatusCd("STOPPED");
|
||||||
job.setFinishedDttm(ZonedDateTime.now());
|
job.setFinishedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public void updateEpoch(String containerName, Integer epoch) {
|
||||||
|
ModelTrainJobEntity job =
|
||||||
|
modelTrainJobRepository
|
||||||
|
.findByContainerName(containerName)
|
||||||
|
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
|
||||||
|
|
||||||
|
job.setCurrentEpoch(epoch);
|
||||||
|
|
||||||
|
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,4 +7,6 @@ public interface ModelTrainJobRepositoryCustom {
|
|||||||
int findMaxAttemptNo(Long modelId);
|
int findMaxAttemptNo(Long modelId);
|
||||||
|
|
||||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||||
|
|
||||||
|
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,4 +40,18 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
|||||||
|
|
||||||
return Optional.ofNullable(job);
|
return Optional.ofNullable(job);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
|
||||||
|
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||||
|
|
||||||
|
ModelTrainJobEntity job =
|
||||||
|
queryFactory
|
||||||
|
.selectFrom(j)
|
||||||
|
.where(j.containerName.eq(containerName))
|
||||||
|
.orderBy(j.id.desc())
|
||||||
|
.fetchFirst();
|
||||||
|
|
||||||
|
return Optional.ofNullable(job);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.kamco.cd.training.train.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||||
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||||
@@ -9,14 +10,17 @@ import java.nio.charset.StandardCharsets;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.regex.Matcher;
|
import java.util.regex.Matcher;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.log4j.Log4j2;
|
import lombok.extern.log4j.Log4j2;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@Log4j2
|
@Log4j2
|
||||||
@Service
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
public class DockerTrainService {
|
public class DockerTrainService {
|
||||||
|
|
||||||
// 실행할 Docker 이미지명
|
// 실행할 Docker 이미지명
|
||||||
@@ -43,6 +47,8 @@ public class DockerTrainService {
|
|||||||
@Value("${train.docker.ipcHost:true}")
|
@Value("${train.docker.ipcHost:true}")
|
||||||
private boolean ipcHost;
|
private boolean ipcHost;
|
||||||
|
|
||||||
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
|
|
||||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||||
|
|
||||||
@@ -61,7 +67,11 @@ public class DockerTrainService {
|
|||||||
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||||
StringBuilder logBuilder = new StringBuilder();
|
StringBuilder logBuilder = new StringBuilder();
|
||||||
|
|
||||||
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
|
Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]");
|
||||||
|
|
||||||
|
// 너무 잦은 업데이트 방지용
|
||||||
|
AtomicInteger lastEpoch = new AtomicInteger(0);
|
||||||
|
AtomicInteger lastIter = new AtomicInteger(0);
|
||||||
|
|
||||||
Thread logThread =
|
Thread logThread =
|
||||||
new Thread(
|
new Thread(
|
||||||
@@ -73,23 +83,40 @@ public class DockerTrainService {
|
|||||||
String line;
|
String line;
|
||||||
while ((line = br.readLine()) != null) {
|
while ((line = br.readLine()) != null) {
|
||||||
|
|
||||||
// 1) 로그 누적
|
|
||||||
synchronized (logBuilder) {
|
synchronized (logBuilder) {
|
||||||
logBuilder.append(line).append('\n');
|
logBuilder.append(line).append('\n');
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) epoch 감지 + DB 업데이트
|
|
||||||
Matcher m = epochPattern.matcher(line);
|
Matcher m = epochPattern.matcher(line);
|
||||||
if (m.find()) {
|
if (m.find()) {
|
||||||
int currentEpoch = Integer.parseInt(m.group(1));
|
int epoch = Integer.parseInt(m.group(1));
|
||||||
int totalEpoch = Integer.parseInt(m.group(2));
|
int iter = Integer.parseInt(m.group(2));
|
||||||
|
int totalIter = Integer.parseInt(m.group(3));
|
||||||
|
|
||||||
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
|
// (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용
|
||||||
|
int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0;
|
||||||
|
|
||||||
// TODO 실행중인 에폭 저장 필요하면 만들어야함
|
// 쓰로틀링: 에폭 끝 or 10 iter마다
|
||||||
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
|
boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0);
|
||||||
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
|
|
||||||
// currentEpoch, totalEpoch);
|
// 중복 방지
|
||||||
|
if (shouldUpdate) {
|
||||||
|
int prevEpoch = lastEpoch.get();
|
||||||
|
int prevIter = lastIter.get();
|
||||||
|
if (epoch != prevEpoch || iter != prevIter) {
|
||||||
|
lastEpoch.set(epoch);
|
||||||
|
lastIter.set(iter);
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"[TRAIN] container={} epoch={} iter={}/{}",
|
||||||
|
containerName,
|
||||||
|
epoch,
|
||||||
|
iter,
|
||||||
|
totalIter);
|
||||||
|
|
||||||
|
modelTrainJobCoreService.updateEpoch(containerName, epoch);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -97,21 +124,6 @@ public class DockerTrainService {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"train-log-" + containerName);
|
"train-log-" + containerName);
|
||||||
// new Thread(
|
|
||||||
// () -> {
|
|
||||||
// try (BufferedReader br =
|
|
||||||
// new BufferedReader(
|
|
||||||
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
|
||||||
// String line;
|
|
||||||
// while ((line = br.readLine()) != null) {
|
|
||||||
// synchronized (log) {
|
|
||||||
// log.append(line).append('\n');
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// } catch (Exception ignored) {
|
|
||||||
// }
|
|
||||||
// },
|
|
||||||
// "train-log-" + containerName);
|
|
||||||
|
|
||||||
logThread.setDaemon(true);
|
logThread.setDaemon(true);
|
||||||
logThread.start();
|
logThread.start();
|
||||||
@@ -206,7 +218,7 @@ public class DockerTrainService {
|
|||||||
|
|
||||||
// 요청/결과 디렉토리 볼륨 마운트
|
// 요청/결과 디렉토리 볼륨 마운트
|
||||||
c.add("-v");
|
c.add("-v");
|
||||||
c.add(requestDir + "/tmp:/data");
|
c.add("/home/kcomu/data" + "/tmp:/data");
|
||||||
c.add("-v");
|
c.add("-v");
|
||||||
c.add(responseDir + ":/checkpoints");
|
c.add(responseDir + ":/checkpoints");
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user