21 Commits

Author SHA1 Message Date
4d9c9a86b4 패키징 zip파일 만들기 커밋 2026-02-12 21:09:40 +09:00
5b682c1386 성공시 csv 파일 테이블에 저장 연결 2026-02-12 21:01:24 +09:00
8ada26448b 테스트 실행 경로 수정 2026-02-12 20:49:14 +09:00
5e0a771848 도커명 변경 2026-02-12 20:36:38 +09:00
e238f3ca88 도커 설정 추가 2026-02-12 20:24:57 +09:00
ad32ca18ca 임시파일생성 경로 수정 2026-02-12 20:03:03 +09:00
a10fccaae3 임시파일생성 경로 수정 2026-02-12 19:35:47 +09:00
c92426aefc Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 19:14:09 +09:00
6185a18a7c 모델목록 검색 조건 상태값 변경 2026-02-12 19:13:56 +09:00
1fb10830b9 임시파일생성 경로 수정 2026-02-12 19:11:51 +09:00
ae0d30e5da return 형식 수정 2026-02-12 18:55:42 +09:00
ed48f697a4 업로드 시 같은 uid로 업로드하지 못하게 조건 추가 2026-02-12 18:44:04 +09:00
6c865d26fd 임시파일생성 소프트링크에서 하드링크로 변경 2026-02-12 18:18:44 +09:00
16e156b5b4 문제되는 하이퍼파라미터 주석처리 2026-02-12 17:52:42 +09:00
6a939118ff 임시폴더생성 api 추가 2026-02-12 17:43:41 +09:00
0c0ae16c2b 임시폴더생성 api 추가 2026-02-12 17:23:34 +09:00
953f95aed6 임시폴더생성 api 추가 2026-02-12 17:14:26 +09:00
85633c8bab 임시폴더생성 api 추가 2026-02-12 17:03:21 +09:00
8b3940b446 Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-12 16:59:44 +09:00
201cfefb6b 임시폴더생성 api 추가 2026-02-12 16:59:39 +09:00
9958b0999a csv 읽는 경로 수정하기, 변수명 수정 2026-02-12 16:58:28 +09:00
27 changed files with 469 additions and 70 deletions

View File

@@ -1,6 +1,11 @@
# Stage 1: Build stage (gradle build는 Jenkins에서 이미 수행)
FROM eclipse-temurin:21-jre-jammy
# docker CLI 설치 (컨테이너에서 호스트 Docker 제어용) 260212 추가
RUN apt-get update && \
apt-get install -y --no-install-recommends docker.io ca-certificates && \
rm -rf /var/lib/apt/lists/*
# 작업 디렉토리 설정
WORKDIR /app

View File

@@ -15,6 +15,7 @@ services:
- /mnt/nfs_share/model_output:/app/model-outputs
- /mnt/nfs_share/train_dataset:/app/train-dataset
- /home/kcomu/data:/home/kcomu/data
- /var/run/docker.sock:/var/run/docker.sock
networks:
- kamco-cds
restart: unless-stopped

View File

@@ -217,7 +217,7 @@ public class DatasetApiController {
public ApiResponseDto<ApiResponseDto.ResponseObj> insertDataset(
@RequestBody @Valid DatasetDto.AddReq addReq) {
return ApiResponseDto.ok(datasetService.insertDataset(addReq));
return ApiResponseDto.okObject(datasetService.insertDataset(addReq));
}
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")

View File

@@ -208,6 +208,13 @@ public class DatasetService {
Long datasetUid = null; // master id 값, 등록하면서 가져올 예정
try {
// 같은 uid 로 등록한 파일이 있는지 확인
Long existsCnt =
datasetCoreService.findDatasetByUidExistsCnt(addReq.getFileName().replace(".zip", ""));
if (existsCnt > 0) {
return new ResponseObj(ApiResponseCode.DUPLICATE_DATA, "이미 등록된 회차 데이터 파일입니다. 확인 부탁드립니다.");
}
// 압축 해제
FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath());

View File

@@ -60,7 +60,7 @@ public class ModelTrainMngDto {
}
}
public String getStep2StatusNAme() {
public String getStep2StatusName() {
if (this.step2Status == null || this.step2Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()

View File

@@ -13,7 +13,6 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@@ -102,9 +101,9 @@ public class ModelTrainMngService {
try {
// 데이터셋 심볼링크 생성
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
String tmpUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(path.toString());
updateReq.setRequestPath(tmpUid);
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
throw new RuntimeException(e);

View File

@@ -7,7 +7,6 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Slf4j
@Service
@@ -17,50 +16,129 @@ public class TmpDatasetService {
@Value("${train.docker.requestDir}")
private String requestDir;
@Transactional(readOnly = true)
public Path buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
// 환경에 맞게 yml로 빼는 걸 추천
Path BASE = Paths.get(requestDir);
Path tmp = BASE.resolve("tmp").resolve(uid);
log.info("========== buildTmpDatasetHardlink START ==========");
log.info("uid={}", uid);
log.info("datasetUids={}", datasetUids);
log.info("requestDir(raw)={}", requestDir);
// mkdir -p "$TMP"/train/{input1,input2,label} ...
Path BASE = toPath(requestDir);
Path tmp = Path.of("/home/kcomu/data", "tmp", uid);
log.info("BASE={}", BASE);
log.info("BASE exists? {}", Files.isDirectory(BASE));
log.info("tmp={}", tmp);
long noDir = 0, scannedDirs = 0, regularFiles = 0, hardlinksMade = 0;
// tmp 디렉토리 준비
for (String type : List.of("train", "val")) {
for (String part : List.of("input1", "input2", "label")) {
Files.createDirectories(tmp.resolve(type).resolve(part));
Path dir = tmp.resolve(type).resolve(part);
Files.createDirectories(dir);
log.info("createDirectories: {}", dir);
}
}
// 하드링크는 "같은 파일시스템"에서만 가능하므로 BASE/tmp가 같은 FS인지 미리 확인(권장)
try {
var baseStore = Files.getFileStore(BASE);
var tmpStore = Files.getFileStore(tmp.getParent()); // BASE/tmp
if (!baseStore.name().equals(tmpStore.name()) || !baseStore.type().equals(tmpStore.type())) {
throw new IOException(
"Hardlink requires same filesystem. baseStore="
+ baseStore.name()
+ "("
+ baseStore.type()
+ "), tmpStore="
+ tmpStore.name()
+ "("
+ tmpStore.type()
+ ")");
}
} catch (Exception e) {
// FileStore 비교가 환경마다 애매할 수 있어서, 여기서는 경고만 주고 실제 createLink에서 최종 판단하게 둘 수도 있음.
log.warn("FileStore check skipped/failed (will rely on createLink): {}", e.toString());
}
for (String id : datasetUids) {
Path srcRoot = BASE.resolve(id);
log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot));
for (String type : List.of("train", "val")) {
for (String part : List.of("input1", "input2", "label")) {
Path srcDir = srcRoot.resolve(type).resolve(part);
// zsh NULL_GLOB: 폴더가 없으면 그냥 continue
if (!Files.isDirectory(srcDir)) continue;
if (!Files.isDirectory(srcDir)) {
log.warn("SKIP (not directory): {}", srcDir);
noDir++;
continue;
}
scannedDirs++;
log.info("SCAN dir={}", srcDir);
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
for (Path f : stream) {
if (!Files.isRegularFile(f)) continue;
if (!Files.isRegularFile(f)) {
log.debug("skip non-regular file: {}", f);
continue;
}
regularFiles++;
String dstName = id + "__" + f.getFileName();
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
// 이미 있으면 스킵(원하면 덮어쓰기 로직으로 바꿀 수 있음)
if (Files.exists(dst)) continue;
// dst가 남아있으면 삭제(심볼릭링크든 파일이든)
if (Files.exists(dst) || Files.isSymbolicLink(dst)) {
Files.delete(dst);
log.debug("deleted existing: {}", dst);
}
// ln -s "$f" "$dst" 와 동일
Files.createSymbolicLink(dst, f.toAbsolutePath());
try {
// 하드링크 생성 (dst가 새 파일로 생기지만 inode는 f와 동일)
Files.createLink(dst, f);
hardlinksMade++;
log.debug("created hardlink: {} => {}", dst, f);
} catch (IOException e) {
// 여기서 바로 실패시키면 “tmp는 만들었는데 내용은 0개” 같은 상태를 방지할 수 있음
log.error("FAILED create hardlink: {} => {}", dst, f, e);
throw e;
}
}
}
}
}
}
if (hardlinksMade == 0) {
throw new IOException(
"No hardlinks created. regularFiles="
+ regularFiles
+ ", scannedDirs="
+ scannedDirs
+ ", noDir="
+ noDir);
}
log.info("tmp dataset created: {}", tmp);
return tmp;
log.info(
"summary: scannedDirs={}, noDir={}, regularFiles={}, hardlinksMade={}",
scannedDirs,
noDir,
regularFiles,
hardlinksMade);
return uid;
}
private static Path toPath(String p) {
if (p.startsWith("~/")) {
return Paths.get(System.getProperty("user.home")).resolve(p.substring(2)).normalize();
}
return Paths.get(p).toAbsolutePath().normalize();
}
}

View File

@@ -246,4 +246,8 @@ public class DatasetCoreService
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
datasetObjRepository.insertDatasetValObj(objRegDto);
}
public Long findDatasetByUidExistsCnt(String uid) {
return datasetRepository.findDatasetByUidExistsCnt(uid);
}
}

View File

@@ -1,6 +1,8 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List;
import lombok.RequiredArgsConstructor;
@@ -26,4 +28,12 @@ public class ModelTestMetricsJobCoreService {
public void insertModelMetricsTest(List<Object[]> batchArgs) {
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
}
public ModelMetricJsonDto getTestMetricPackingInfo(Long modelId) {
return modelTestMetricsJobRepository.getTestMetricPackingInfo(modelId);
}
public ModelTestFileName findModelTestFileNames(Long modelId) {
return modelTestMetricsJobRepository.findModelTestFileNames(modelId);
}
}

View File

@@ -5,6 +5,7 @@ import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
@@ -104,4 +105,16 @@ public class ModelTrainJobCoreService {
job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now());
}
@Transactional
public void updateEpoch(String containerName, Integer epoch) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findByContainerName(containerName)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
job.setCurrentEpoch(epoch);
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
}
}

View File

@@ -24,4 +24,6 @@ public interface DatasetRepositoryCustom {
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
List<String> findDatasetUid(List<Long> datasetIds);
Long findDatasetByUidExistsCnt(String uid);
}

View File

@@ -247,4 +247,13 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
public List<String> findDatasetUid(List<Long> datasetIds) {
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
}
@Override
public Long findDatasetByUidExistsCnt(String uid) {
return queryFactory
.select(dataset.id.count())
.from(dataset)
.where(dataset.uid.eq(uid))
.fetchOne();
}
}

View File

@@ -16,10 +16,10 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
@Override
public List<ModelDatasetMappEntity> findByModelUid(Long modelId) {
queryFactory
return queryFactory
.select(modelDatasetMappEntity)
.from(modelDatasetMappEntity)
.where(modelDatasetMappEntity.modelUid.eq(modelId));
return List.of();
.where(modelDatasetMappEntity.modelUid.eq(modelId))
.fetch();
}
}

View File

@@ -39,7 +39,11 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
BooleanBuilder builder = new BooleanBuilder();
if (req.getStatus() != null && !req.getStatus().isEmpty()) {
builder.and(modelMasterEntity.statusCd.eq(req.getStatus()));
builder.and(
modelMasterEntity
.step1State
.eq(req.getStatus())
.or(modelMasterEntity.step2State.eq(req.getStatus())));
}
if (req.getModelNo() != null && !req.getModelNo().isEmpty()) {

View File

@@ -1,5 +1,7 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List;
@@ -10,4 +12,8 @@ public interface ModelTestMetricsJobRepositoryCustom {
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
void insertModelMetricsTest(List<Object[]> batchArgs);
ModelMetricJsonDto getTestMetricPackingInfo(Long modelId);
ModelTestFileName findModelTestFileNames(Long modelId);
}

View File

@@ -1,9 +1,14 @@
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.Properties;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -42,7 +47,10 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
ResponsePathDto.class,
modelMasterEntity.id,
modelMasterEntity.responsePath,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity.step2EndDttm.isNotNull(),
@@ -67,4 +75,43 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
jdbcTemplate.batchUpdate(sql, batchArgs);
}
@Override
public ModelMetricJsonDto getTestMetricPackingInfo(Long modelId) {
return queryFactory
.select(
Projections.constructor(
ModelMetricJsonDto.class,
modelMasterEntity.modelNo,
modelMasterEntity.modelVer,
Projections.constructor(
Properties.class,
modelMetricsTestEntity.f1Score,
modelMetricsTestEntity.precisions,
modelMetricsTestEntity.recall,
modelMetricsTestEntity.iou,
modelMetricsTrainEntity.loss)))
.from(modelMetricsTestEntity)
.innerJoin(modelMasterEntity)
.on(modelMetricsTestEntity.model.id.eq(modelMasterEntity.id))
.innerJoin(modelMetricsTrainEntity)
.on(
modelMetricsTestEntity.model.eq(modelMetricsTrainEntity.model),
modelMasterEntity.bestEpoch.eq(modelMetricsTrainEntity.epoch))
.where(modelMetricsTestEntity.model.id.eq(modelId))
.fetchOne();
}
@Override
public ModelTestFileName findModelTestFileNames(Long modelId) {
return queryFactory
.select(
Projections.constructor(
ModelTestFileName.class, modelMetricsTestEntity.model1, modelMasterEntity.modelVer))
.from(modelMetricsTestEntity)
.innerJoin(modelMasterEntity)
.on(modelMetricsTestEntity.model.id.eq(modelMasterEntity.id))
.where(modelMetricsTestEntity.model.id.eq(modelId))
.fetchOne();
}
}

View File

@@ -7,4 +7,6 @@ public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
}

View File

@@ -40,4 +40,18 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
return Optional.ofNullable(job);
}
@Override
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory
.selectFrom(j)
.where(j.containerName.eq(containerName))
.orderBy(j.id.desc())
.fetchFirst();
return Optional.ofNullable(job);
}
}

View File

@@ -29,7 +29,10 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
ResponsePathDto.class,
modelMasterEntity.id,
modelMasterEntity.responsePath,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity.step1EndDttm.isNotNull(),

View File

@@ -1,6 +1,9 @@
package com.kamco.cd.training.train.dto;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.Properties;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
@@ -17,5 +20,40 @@ public class ModelTrainMetricsDto {
private Long modelId;
private String responsePath;
private UUID uuid;
}
@Getter
@AllArgsConstructor
public static class ModelMetricJsonDto {
@JsonProperty("cd_model_type")
private String cdModelType;
@JsonProperty("model_version")
private String modelVersion;
private Properties properties;
}
@Getter
@AllArgsConstructor
public static class Properties {
@JsonProperty("f1_score")
private Float f1Score;
private Float precision;
private Float recall;
private Float loss;
private Double iou;
}
@Getter
@AllArgsConstructor
public static class ModelTestFileName {
private String bestEpochFileName;
private String modelVersion;
}
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
@@ -9,14 +10,17 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Log4j2
@Service
@RequiredArgsConstructor
public class DockerTrainService {
// 실행할 Docker 이미지명
@@ -43,6 +47,8 @@ public class DockerTrainService {
@Value("${train.docker.ipcHost:true}")
private boolean ipcHost;
private final ModelTrainJobCoreService modelTrainJobCoreService;
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
@@ -56,12 +62,42 @@ public class DockerTrainService {
ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectErrorStream(true);
Process p = pb.start();
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
StringBuilder logBuilder = new StringBuilder();
Process p = pb.start();
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
log.info("[TRAIN-BOOT] docker run started. container={}", containerName);
try {
log.info("[TRAIN-BOOT] pid={}", p.pid()); // Java 9+
} catch (Throwable ignore) {
}
try {
// 바로 죽었는지 100ms만 체크
if (p.waitFor(100, TimeUnit.MILLISECONDS)) {
int exit = p.exitValue();
String earlyLogs;
synchronized (logBuilder) {
earlyLogs = logBuilder.toString();
}
log.error(
"[TRAIN-BOOT] docker run exited immediately. container={} exit={}",
containerName,
exit);
log.error("[TRAIN-BOOT] early logs:\n{}", earlyLogs);
} else {
log.info("[TRAIN-BOOT] docker run is still running. container={}", containerName);
}
} catch (Exception e) {
log.warn("[TRAIN-BOOT] early-exit check failed: {}", e.toString(), e);
}
Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]");
// 너무 잦은 업데이트 방지용
AtomicInteger lastEpoch = new AtomicInteger(0);
AtomicInteger lastIter = new AtomicInteger(0);
Thread logThread =
new Thread(
@@ -73,23 +109,40 @@ public class DockerTrainService {
String line;
while ((line = br.readLine()) != null) {
// 1) 로그 누적
synchronized (logBuilder) {
logBuilder.append(line).append('\n');
}
// 2) epoch 감지 + DB 업데이트
Matcher m = epochPattern.matcher(line);
if (m.find()) {
int currentEpoch = Integer.parseInt(m.group(1));
int totalEpoch = Integer.parseInt(m.group(2));
int epoch = Integer.parseInt(m.group(1));
int iter = Integer.parseInt(m.group(2));
int totalIter = Integer.parseInt(m.group(3));
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
// (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용
int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0;
// TODO 실행중인 에폭 저장 필요하면 만들어야함
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
// currentEpoch, totalEpoch);
// 쓰로틀링: 에폭 끝 or 10 iter마다
boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0);
// 중복 방지
if (shouldUpdate) {
int prevEpoch = lastEpoch.get();
int prevIter = lastIter.get();
if (epoch != prevEpoch || iter != prevIter) {
lastEpoch.set(epoch);
lastIter.set(iter);
log.info(
"[TRAIN] container={} epoch={} iter={}/{}",
containerName,
epoch,
iter,
totalIter);
modelTrainJobCoreService.updateEpoch(containerName, epoch);
}
}
}
}
} catch (Exception e) {
@@ -97,21 +150,6 @@ public class DockerTrainService {
}
},
"train-log-" + containerName);
// new Thread(
// () -> {
// try (BufferedReader br =
// new BufferedReader(
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
// String line;
// while ((line = br.readLine()) != null) {
// synchronized (log) {
// log.append(line).append('\n');
// }
// }
// } catch (Exception ignored) {
// }
// },
// "train-log-" + containerName);
logThread.setDaemon(true);
logThread.start();
@@ -169,7 +207,7 @@ public class DockerTrainService {
// 컨테이너 이름 지정
c.add("--name");
c.add(containerName + "-" + req.getUuid().substring(0, 8));
c.add(containerName);
// 실행 종료 시 자동 삭제
c.add("--rm");
@@ -206,7 +244,7 @@ public class DockerTrainService {
// 요청/결과 디렉토리 볼륨 마운트
c.add("-v");
c.add(requestDir + ":/data");
c.add("/home/kcomu/data" + "/tmp:/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");
@@ -264,12 +302,12 @@ public class DockerTrainService {
// ===== Augmentation =====
addArg(c, "--rot-prob", req.getRotProb());
addArg(c, "--rot-degree", req.getRotDegree());
// addArg(c, "--rot-degree", req.getRotDegree()); // TODO AI 수정되면 주석 해제
addArg(c, "--flip-prob", req.getFlipProb());
addArg(c, "--exchange-prob", req.getExchangeProb());
addArg(c, "--brightness-delta", req.getBrightnessDelta());
addArg(c, "--contrast-range", req.getContrastRange());
addArg(c, "--saturation-range", req.getSaturationRange());
// addArg(c, "--contrast-range", req.getContrastRange()); // TODO AI 수정되면 주석 해제
// addArg(c, "--saturation-range", req.getSaturationRange()); // TODO AI 수정되면 주석 해제
addArg(c, "--hue-delta", req.getHueDelta());
addArg(c, "--resume-from", req.getResumeFrom());
@@ -377,7 +415,7 @@ public class DockerTrainService {
c.add("docker");
c.add("run");
c.add("--name");
c.add(containerName + "=" + req.getUuid().substring(0, 8));
c.add(containerName);
c.add("--rm");
c.add("--gpus");
@@ -386,7 +424,7 @@ public class DockerTrainService {
c.add("--shm-size=" + shmSize);
c.add("-v");
c.add(requestDir + ":/data");
c.add("/home/kcomu/data" + "/tmp:/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");

View File

@@ -1,14 +1,26 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.csv.CSVFormat;
@@ -27,6 +39,13 @@ public class ModelTestMetricsJobService {
@Value("${spring.profiles.active}")
private String profile;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
@Value("${file.pt-path}")
private String ptPathDir;
/**
* 실행중인 profile
*
@@ -36,8 +55,8 @@ public class ModelTestMetricsJobService {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
public void findTestValidMetricCsvFiles() {
// @Scheduled(cron = "0 * * * * *")
public void findTestValidMetricCsvFiles() throws IOException {
// if (isLocalProfile()) {
// return;
// }
@@ -51,7 +70,7 @@ public class ModelTestMetricsJobService {
for (ResponsePathDto modelInfo : modelIds) {
String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
@@ -96,7 +115,97 @@ public class ModelTestMetricsJobService {
throw new RuntimeException(e);
}
// 패키징할 파일 만들기
ModelMetricJsonDto jsonDto =
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
try {
writeJsonFile(
jsonDto,
Paths.get(
responseDir
+ "/"
+ modelInfo.getUuid()
+ "/"
+ jsonDto.getModelVersion()
+ ".json"));
} catch (IOException e) {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
ModelTestFileName fileInfo =
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
Path zipPath =
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
Set<String> targetNames =
Set.of(
"model_config.py",
fileInfo.getBestEpochFileName() + ".pth",
fileInfo.getModelVersion() + ".json");
List<Path> files = new ArrayList<>();
try (Stream<Path> s = Files.list(responsePath)) {
files.addAll(
s.filter(Files::isRegularFile)
.filter(p -> targetNames.contains(p.getFileName().toString()))
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
files.addAll(
s.filter(Files::isRegularFile)
.limit(1) // yolov8_6th-6m.pt 파일 1개만
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
zipFiles(files, zipPath);
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
}
}
private void writeJsonFile(Object data, Path outputPath) throws IOException {
Path parent = outputPath.getParent();
if (parent != null) {
Files.createDirectories(parent);
}
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
try (OutputStream os =
Files.newOutputStream(
outputPath, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING)) {
objectMapper.writeValue(os, data);
}
}
private void zipFiles(List<Path> files, Path zipPath) throws IOException {
Path parent = zipPath.getParent();
if (parent != null) {
Files.createDirectories(parent);
}
try (ZipOutputStream zos = new ZipOutputStream(Files.newOutputStream(zipPath))) {
for (Path file : files) {
ZipEntry entry = new ZipEntry(file.getFileName().toString());
zos.putNextEntry(entry);
Files.copy(file, zos);
zos.closeEntry();
}
}
}
}

View File

@@ -55,7 +55,7 @@ public class ModelTrainMetricsJobService {
for (ResponsePathDto modelInfo : modelIds) {
String trainPath = responseDir + "{uuid}/metrics/train.csv"; // TODO
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
@@ -80,7 +80,7 @@ public class ModelTrainMetricsJobService {
throw new RuntimeException(e);
}
String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {

View File

@@ -226,9 +226,9 @@ public class TrainJobService {
try {
// 데이터셋 심볼링크 생성
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(path.toString());
updateReq.setRequestPath(pathUid);
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
throw new RuntimeException(e);

View File

@@ -22,6 +22,8 @@ public class TrainJobWorker {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
private final ModelTestMetricsJobService modelTestMetricsJobService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
@@ -90,8 +92,10 @@ public class TrainJobWorker {
if (isEval) {
modelTrainMngCoreService.markStep2Success(modelId);
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
} else {
modelTrainMngCoreService.markStep1Success(modelId);
modelTestMetricsJobService.findTestValidMetricCsvFiles();
}
} else {

View File

@@ -58,9 +58,12 @@ file:
dataset-dir: /home/kcomu/data/request/
dataset-tmp-dir: ${file.dataset-dir}tmp/
pt-path: /home/kcomu/data/response/v6-cls-checkpoints/
pt-FileName: yolov8_6th-6m.pt
train:
docker:
image: "kamco-cd-train:love_latest"
image: "kamco-cd-train:latest"
requestDir: "/home/kcomu/data/request"
responseDir: "/home/kcomu/data/response"
containerPrefix: "kamco-cd-train"

View File

@@ -44,9 +44,12 @@ file:
dataset-dir: /home/kcomu/data/request/
dataset-tmp-dir: ${file.dataset-dir}tmp/
pt-path: /home/kcomu/data/response/v6-cls-checkpoints/
pt-FileName: yolov8_6th-6m.pt
train:
docker:
image: "kamco-cd-train:love_latest"
image: "kamco-cd-train:latest"
requestDir: "/home/kcomu/data/request"
responseDir: "/home/kcomu/data/response"
containerPrefix: "kamco-cd-train"