상태변경 추가
This commit is contained in:
@@ -26,6 +26,10 @@ public class ModelTestMetricsJobCoreService {
|
|||||||
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
|
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId) {
|
||||||
|
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelId(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
||||||
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
|
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,17 @@ public class ModelTrainJobCoreService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** 실행 시작 처리 수정 */
|
||||||
|
@Transactional
|
||||||
|
public void updateJobStatus(Long jobId, String jobStatus) {
|
||||||
|
ModelTrainJobEntity job =
|
||||||
|
modelTrainJobRepository
|
||||||
|
.findById(jobId)
|
||||||
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
|
job.setStatusCd(jobStatus);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 성공 처리
|
* 성공 처리
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ public class ModelTrainMetricsJobCoreService {
|
|||||||
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
|
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId) {
|
||||||
|
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelId(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
||||||
modelTrainMetricsJobRepository.insertModelMetricsTrain(batchArgs);
|
modelTrainMetricsJobRepository.insertModelMetricsTrain(batchArgs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ public interface ModelTestMetricsJobRepositoryCustom {
|
|||||||
|
|
||||||
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
|
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
|
||||||
|
|
||||||
|
ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId);
|
||||||
|
|
||||||
void insertModelMetricsTest(List<Object[]> batchArgs);
|
void insertModelMetricsTest(List<Object[]> batchArgs);
|
||||||
|
|
||||||
ModelMetricJsonDto getTestMetricPackingInfo(Long modelId);
|
ModelMetricJsonDto getTestMetricPackingInfo(Long modelId);
|
||||||
|
|||||||
@@ -63,6 +63,25 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
|||||||
.fetch();
|
.fetch();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId) {
|
||||||
|
return queryFactory
|
||||||
|
.select(
|
||||||
|
Projections.constructor(
|
||||||
|
ResponsePathDto.class,
|
||||||
|
modelMasterEntity.id,
|
||||||
|
modelMasterEntity.responsePath,
|
||||||
|
modelMasterEntity.uuid))
|
||||||
|
.from(modelMasterEntity)
|
||||||
|
.where(
|
||||||
|
modelMasterEntity.id.eq(modelId),
|
||||||
|
modelMasterEntity
|
||||||
|
.step2MetricSaveYn
|
||||||
|
.isNull()
|
||||||
|
.or(modelMasterEntity.step2MetricSaveYn.isFalse()))
|
||||||
|
.fetchOne();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
||||||
// AS-IS
|
// AS-IS
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ public interface ModelTrainMetricsJobRepositoryCustom {
|
|||||||
|
|
||||||
List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
|
List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
|
||||||
|
|
||||||
|
ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId);
|
||||||
|
|
||||||
void insertModelMetricsTrain(List<Object[]> batchArgs);
|
void insertModelMetricsTrain(List<Object[]> batchArgs);
|
||||||
|
|
||||||
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
||||||
|
|||||||
@@ -44,6 +44,25 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
|||||||
.fetch();
|
.fetch();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId) {
|
||||||
|
return queryFactory
|
||||||
|
.select(
|
||||||
|
Projections.constructor(
|
||||||
|
ResponsePathDto.class,
|
||||||
|
modelMasterEntity.id,
|
||||||
|
modelMasterEntity.responsePath,
|
||||||
|
modelMasterEntity.uuid))
|
||||||
|
.from(modelMasterEntity)
|
||||||
|
.where(
|
||||||
|
modelMasterEntity.id.eq(modelId),
|
||||||
|
modelMasterEntity
|
||||||
|
.step1MetricSaveYn
|
||||||
|
.isNull()
|
||||||
|
.or(modelMasterEntity.step1MetricSaveYn.isFalse()))
|
||||||
|
.fetchOne();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
||||||
String sql =
|
String sql =
|
||||||
|
|||||||
@@ -213,4 +213,27 @@ public class TrainApiController {
|
|||||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
return ApiResponseDto.ok(dataSetCountersService.getCount(modelId));
|
return ApiResponseDto.ok(dataSetCountersService.getCount(modelId));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "학습 상태 확인", description = "학습 상태 확인")
|
||||||
|
@ApiResponses(
|
||||||
|
value = {
|
||||||
|
@ApiResponse(
|
||||||
|
responseCode = "200",
|
||||||
|
description = "학습 상태 변경",
|
||||||
|
content =
|
||||||
|
@Content(
|
||||||
|
mediaType = "application/json",
|
||||||
|
schema = @Schema(implementation = String.class))),
|
||||||
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||||
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
|
})
|
||||||
|
@PostMapping(path = "/status/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE)
|
||||||
|
public ApiResponseDto<String> status(
|
||||||
|
@Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49")
|
||||||
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
|
trainJobService.status(uuid, modelId);
|
||||||
|
return ApiResponseDto.ok("ok");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package com.kamco.cd.training.train.dto;
|
||||||
|
|
||||||
|
public record DockerInspectState(boolean exists, boolean running, Integer exitCode, String status) {
|
||||||
|
|
||||||
|
public static DockerInspectState missing() {
|
||||||
|
return new DockerInspectState(false, false, null, "missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package com.kamco.cd.training.train.dto;
|
||||||
|
|
||||||
|
public class OutputResult {
|
||||||
|
|
||||||
|
private final boolean completed;
|
||||||
|
private final String reason;
|
||||||
|
|
||||||
|
public OutputResult(boolean completed, String reason) {
|
||||||
|
this.completed = completed;
|
||||||
|
this.reason = reason;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean completed() {
|
||||||
|
return completed;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String reason() {
|
||||||
|
return reason;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,25 +1,16 @@
|
|||||||
package com.kamco.cd.training.train.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
|
import com.kamco.cd.training.train.dto.DockerInspectState;
|
||||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
import java.io.BufferedReader;
|
import com.kamco.cd.training.train.dto.OutputResult;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStreamReader;
|
|
||||||
import java.nio.charset.StandardCharsets;
|
|
||||||
import java.nio.file.DirectoryStream;
|
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.nio.file.Path;
|
|
||||||
import java.nio.file.Paths;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.stream.Stream;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.log4j.Log4j2;
|
import lombok.extern.log4j.Log4j2;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||||
import org.springframework.context.annotation.Profile;
|
import org.springframework.context.annotation.Profile;
|
||||||
import org.springframework.context.event.EventListener;
|
import org.springframework.context.event.EventListener;
|
||||||
@@ -44,14 +35,7 @@ public class JobRecoveryOnStartupService {
|
|||||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
|
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
|
||||||
|
private final TrainUtilService trainUtilService;
|
||||||
/**
|
|
||||||
* Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response
|
|
||||||
*
|
|
||||||
* <p>컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다.
|
|
||||||
*/
|
|
||||||
@Value("${train.docker.response_dir}")
|
|
||||||
private String responseDir;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 스프링 부팅 완료 시점(빈 생성/초기화 모두 끝난 뒤)에 복구 로직 실행.
|
* 스프링 부팅 완료 시점(빈 생성/초기화 모두 끝난 뒤)에 복구 로직 실행.
|
||||||
@@ -77,7 +61,7 @@ public class JobRecoveryOnStartupService {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// 2-1) docker inspect로 컨테이너 상태 조회
|
// 2-1) docker inspect로 컨테이너 상태 조회
|
||||||
DockerInspectState state = inspectContainer(containerName);
|
DockerInspectState state = trainUtilService.inspectContainer(containerName);
|
||||||
|
|
||||||
// 3) 컨테이너가 "없음"
|
// 3) 컨테이너가 "없음"
|
||||||
// - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음
|
// - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음
|
||||||
@@ -88,7 +72,7 @@ public class JobRecoveryOnStartupService {
|
|||||||
containerName);
|
containerName);
|
||||||
|
|
||||||
// 3-1) 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정"
|
// 3-1) 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정"
|
||||||
OutputResult out = probeOutputs(job);
|
OutputResult out = trainUtilService.probeOutputs(job);
|
||||||
|
|
||||||
// 3-2) 산출물이 충분하면 성공 처리
|
// 3-2) 산출물이 충분하면 성공 처리
|
||||||
if (out.completed()) {
|
if (out.completed()) {
|
||||||
@@ -109,11 +93,9 @@ public class JobRecoveryOnStartupService {
|
|||||||
job.getId(),
|
job.getId(),
|
||||||
out.reason());
|
out.reason());
|
||||||
|
|
||||||
Integer modelId = job.getModelId() == null ? null : Math.toIntExact(job.getModelId());
|
|
||||||
|
|
||||||
// PAUSED/STOP
|
// PAUSED/STOP
|
||||||
modelTrainJobCoreService.markPaused(
|
modelTrainJobCoreService.markPaused(
|
||||||
job.getId(), modelId, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
|
job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
|
||||||
|
|
||||||
// 모델도 에러가 아니라 STOP으로
|
// 모델도 에러가 아니라 STOP으로
|
||||||
markStepStopByJobType(
|
markStepStopByJobType(
|
||||||
@@ -152,7 +134,7 @@ public class JobRecoveryOnStartupService {
|
|||||||
// ============================================================
|
// ============================================================
|
||||||
// 2) kill 후 실제로 죽었는지 확인
|
// 2) kill 후 실제로 죽었는지 확인
|
||||||
// ============================================================
|
// ============================================================
|
||||||
DockerInspectState after = inspectContainer(containerName);
|
DockerInspectState after = trainUtilService.inspectContainer(containerName);
|
||||||
if (after.exists() && after.running()) {
|
if (after.exists() && after.running()) {
|
||||||
throw new IOException("docker kill returned 0 but container still running");
|
throw new IOException("docker kill returned 0 but container still running");
|
||||||
}
|
}
|
||||||
@@ -162,10 +144,8 @@ public class JobRecoveryOnStartupService {
|
|||||||
// ============================================================
|
// ============================================================
|
||||||
// 3) job 상태를 PAUSED로 변경 (서버 재기동으로 강제 중단)
|
// 3) job 상태를 PAUSED로 변경 (서버 재기동으로 강제 중단)
|
||||||
// ============================================================
|
// ============================================================
|
||||||
Integer modelId = job.getModelId() == null ? null : Math.toIntExact(job.getModelId());
|
|
||||||
|
|
||||||
modelTrainJobCoreService.markPaused(
|
modelTrainJobCoreService.markPaused(job.getId(), -1, "AUTO_KILLED_ON_SERVER_RESTART");
|
||||||
job.getId(), modelId, "AUTO_KILLED_ON_SERVER_RESTART");
|
|
||||||
|
|
||||||
log.info("job = {}", job);
|
log.info("job = {}", job);
|
||||||
markStepStopByJobType(job, "AUTO_KILLED_ON_SERVER_RESTART");
|
markStepStopByJobType(job, "AUTO_KILLED_ON_SERVER_RESTART");
|
||||||
@@ -264,301 +244,4 @@ public class JobRecoveryOnStartupService {
|
|||||||
modelTrainMngCoreService.markError(job.getModelId(), msg);
|
modelTrainMngCoreService.markError(job.getModelId(), msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* docker inspect를 사용해서 컨테이너 상태를 조회합니다.
|
|
||||||
*
|
|
||||||
* <p>사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}}
|
|
||||||
*
|
|
||||||
* <p>예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137"
|
|
||||||
*
|
|
||||||
* <p>주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초
|
|
||||||
* 타임아웃을 둠
|
|
||||||
*/
|
|
||||||
private DockerInspectState inspectContainer(String containerName)
|
|
||||||
throws IOException, InterruptedException {
|
|
||||||
|
|
||||||
ProcessBuilder pb =
|
|
||||||
new ProcessBuilder(
|
|
||||||
"docker",
|
|
||||||
"inspect",
|
|
||||||
"-f",
|
|
||||||
"{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}",
|
|
||||||
containerName);
|
|
||||||
|
|
||||||
// stderr를 stdout으로 합쳐서 한 스트림으로 읽기(에러 메시지도 함께 받음)
|
|
||||||
pb.redirectErrorStream(true);
|
|
||||||
|
|
||||||
Process p = pb.start();
|
|
||||||
|
|
||||||
// inspect 출력은 1줄이면 충분하므로 readLine()만 수행
|
|
||||||
String output;
|
|
||||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
|
|
||||||
output = br.readLine();
|
|
||||||
}
|
|
||||||
|
|
||||||
// 무한대기 방지: 5초 내에 종료되지 않으면 강제 종료
|
|
||||||
boolean finished = p.waitFor(5, TimeUnit.SECONDS);
|
|
||||||
if (!finished) {
|
|
||||||
p.destroyForcibly();
|
|
||||||
throw new IOException("docker inspect timeout");
|
|
||||||
}
|
|
||||||
|
|
||||||
// docker inspect 자체의 프로세스 exit code
|
|
||||||
int code = p.exitValue();
|
|
||||||
|
|
||||||
// 실패(코드 !=0) 또는 출력이 없으면 "컨테이너 없음"으로 간주
|
|
||||||
if (code != 0 || output == null || output.isBlank()) {
|
|
||||||
return DockerInspectState.missing();
|
|
||||||
}
|
|
||||||
|
|
||||||
// "status running exitCode" 형태로 split
|
|
||||||
String[] parts = output.trim().split("\\s+");
|
|
||||||
|
|
||||||
// status: running/exited/dead 등
|
|
||||||
String status = parts.length > 0 ? parts[0] : "unknown";
|
|
||||||
|
|
||||||
// running: true/false
|
|
||||||
boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]);
|
|
||||||
|
|
||||||
// exitCode: 정수 파싱(파싱 실패하면 null)
|
|
||||||
Integer exitCode = null;
|
|
||||||
if (parts.length > 2) {
|
|
||||||
try {
|
|
||||||
exitCode = Integer.parseInt(parts[2]);
|
|
||||||
} catch (Exception ignore) {
|
|
||||||
// ignore
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return new DockerInspectState(true, running, exitCode, status);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* docker inspect 결과를 담는 레코드.
|
|
||||||
*
|
|
||||||
* <p>exists: - true : docker inspect 성공 (컨테이너 존재) - false : 컨테이너 없음(또는 inspect 실패를 missing으로 간주)
|
|
||||||
*/
|
|
||||||
private record DockerInspectState(
|
|
||||||
boolean exists, boolean running, Integer exitCode, String status) {
|
|
||||||
static DockerInspectState missing() {
|
|
||||||
return new DockerInspectState(false, false, null, "missing");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================================
|
|
||||||
// 컨테이너가 "없을 때" 파일 기반으로 완료/미완료를 판정하는 로직
|
|
||||||
// ============================================================================================
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다.
|
|
||||||
*
|
|
||||||
* <p>판정 규칙(보수적으로 설계): 1) total_epoch가 paramsJson에 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수
|
|
||||||
* >= total_epoch 이어야 함 3) *.pth 파일이 total_epoch 이상 존재하거나, best*.pth(또는 *best*.pth)가 존재해야 함
|
|
||||||
*
|
|
||||||
* <p>왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를
|
|
||||||
* 실패로 오판할 수 있음.
|
|
||||||
*/
|
|
||||||
private OutputResult probeOutputs(ModelTrainJobDto job) {
|
|
||||||
try {
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"[RECOVERY] probeOutputs start. jobId={}, modelId={}", job.getId(), job.getModelId());
|
|
||||||
|
|
||||||
// 1) 출력 디렉토리 확인
|
|
||||||
Path outDir = resolveOutputDir(job);
|
|
||||||
if (outDir == null || !Files.isDirectory(outDir)) {
|
|
||||||
log.warn("[RECOVERY] output directory missing. jobId={}, path={}", job.getId(), outDir);
|
|
||||||
return new OutputResult(false, "output-dir-missing");
|
|
||||||
}
|
|
||||||
|
|
||||||
log.info("[RECOVERY] output directory found. jobId={}, path={}", job.getId(), outDir);
|
|
||||||
|
|
||||||
// 2) totalEpoch 확인
|
|
||||||
Integer totalEpoch = extractTotalEpoch(job).orElse(null);
|
|
||||||
if (totalEpoch == null || totalEpoch <= 0) {
|
|
||||||
log.warn(
|
|
||||||
"[RECOVERY] totalEpoch missing or invalid. jobId={}, totalEpoch={}",
|
|
||||||
job.getId(),
|
|
||||||
totalEpoch);
|
|
||||||
return new OutputResult(false, "total-epoch-missing");
|
|
||||||
}
|
|
||||||
|
|
||||||
Integer valInterval = extractValInterval(job).orElse(null);
|
|
||||||
if (valInterval == null || valInterval <= 0) {
|
|
||||||
log.warn(
|
|
||||||
"[RECOVERY] valInterval missing or invalid. jobId={}, valInterval={}",
|
|
||||||
job.getId(),
|
|
||||||
valInterval);
|
|
||||||
return new OutputResult(false, "val-interval-missing");
|
|
||||||
}
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"[RECOVERY] totalEpoch={}. valInterval={}. jobId={}",
|
|
||||||
totalEpoch,
|
|
||||||
valInterval,
|
|
||||||
job.getId());
|
|
||||||
|
|
||||||
// 3) val.csv 존재 확인
|
|
||||||
Path valCsv = outDir.resolve("val.csv");
|
|
||||||
if (!Files.exists(valCsv)) {
|
|
||||||
log.warn("[RECOVERY] val.csv missing. jobId={}, path={}", job.getId(), valCsv);
|
|
||||||
return new OutputResult(false, "val.csv-missing");
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4) val.csv 라인 수 확인
|
|
||||||
long lines = countNonHeaderLines(valCsv);
|
|
||||||
|
|
||||||
// expected = 실제 val 실행 횟수
|
|
||||||
int expectedLines = totalEpoch / valInterval;
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"[RECOVERY] val.csv lines counted. jobId={}, lines={}, expected={}",
|
|
||||||
job.getId(),
|
|
||||||
lines,
|
|
||||||
expectedLines);
|
|
||||||
|
|
||||||
// 5) 완료 판정
|
|
||||||
if (lines >= expectedLines) {
|
|
||||||
log.info("[RECOVERY] outputs look COMPLETE. jobId={}", job.getId());
|
|
||||||
return new OutputResult(true, "ok");
|
|
||||||
}
|
|
||||||
|
|
||||||
log.warn(
|
|
||||||
"[RECOVERY] val.csv line mismatch. jobId={}, lines={}, expected={}",
|
|
||||||
job.getId(),
|
|
||||||
lines,
|
|
||||||
expectedLines);
|
|
||||||
|
|
||||||
return new OutputResult(
|
|
||||||
false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch);
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
|
|
||||||
log.error("[RECOVERY] probeOutputs error. jobId={}", job.getId(), e);
|
|
||||||
|
|
||||||
return new OutputResult(false, "probe-error");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* responseDir 아래에서 job 산출물 디렉토리를 찾습니다.
|
|
||||||
*
|
|
||||||
* <p>가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다.
|
|
||||||
*
|
|
||||||
* <p>현재 기본 탐색 순서: 1) {responseDir}/{jobId} 2) {responseDir}/{modelId} 3)
|
|
||||||
* {responseDir}/{containerName} 4) 마지막 fallback: responseDir 자체
|
|
||||||
*
|
|
||||||
* <p>추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId)
|
|
||||||
* 형태를 1순위로 두세요.
|
|
||||||
*/
|
|
||||||
private Path resolveOutputDir(ModelTrainJobDto job) {
|
|
||||||
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId());
|
|
||||||
|
|
||||||
Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics");
|
|
||||||
|
|
||||||
return Files.isDirectory(base) ? base : null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* paramsJson에서 total_epoch 값을 추출합니다.
|
|
||||||
*
|
|
||||||
* <p>키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase)
|
|
||||||
*
|
|
||||||
* <p>예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...}
|
|
||||||
*/
|
|
||||||
private Optional<Integer> extractTotalEpoch(ModelTrainJobDto job) {
|
|
||||||
Map<String, Object> params = job.getParamsJson();
|
|
||||||
if (params == null) return Optional.empty();
|
|
||||||
|
|
||||||
Object v = params.get("total_epoch");
|
|
||||||
if (v == null) v = params.get("totalEpoch");
|
|
||||||
if (v == null) return Optional.empty();
|
|
||||||
|
|
||||||
try {
|
|
||||||
return Optional.of(Integer.parseInt(String.valueOf(v)));
|
|
||||||
} catch (Exception ignore) {
|
|
||||||
return Optional.empty();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다.
|
|
||||||
*
|
|
||||||
* <p>가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록)
|
|
||||||
*
|
|
||||||
* <p>주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외
|
|
||||||
*/
|
|
||||||
private long countNonHeaderLines(Path csv) throws IOException {
|
|
||||||
try (Stream<String> lines = Files.lines(csv, StandardCharsets.UTF_8)) {
|
|
||||||
return lines.skip(1).filter(s -> s != null && !s.isBlank()).count();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 디렉토리에서 glob 패턴에 맞는 파일 수를 셉니다.
|
|
||||||
*
|
|
||||||
* <p>예: - "*.pth" - "best*.pth"
|
|
||||||
*/
|
|
||||||
private long countFilesByGlob(Path dir, String glob) throws IOException {
|
|
||||||
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
|
|
||||||
long cnt = 0;
|
|
||||||
for (Path p : ds) {
|
|
||||||
if (Files.isRegularFile(p)) cnt++;
|
|
||||||
}
|
|
||||||
return cnt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 디렉토리에서 glob 패턴에 맞는 파일이 "하나라도" 존재하는지 체크합니다. */
|
|
||||||
private boolean existsByGlob(Path dir, String glob) throws IOException {
|
|
||||||
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
|
|
||||||
return ds.iterator().hasNext();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================================
|
|
||||||
// probeOutputs() 결과 객체
|
|
||||||
// ============================================================================================
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 컨테이너가 없을 때(responseDir 기반) 완료 여부 판정 결과.
|
|
||||||
*
|
|
||||||
* <p>completed: - true : 산출물이 완료로 보임(성공 처리 가능) - false : 산출물이 부족/불명확(실패 또는 유예 판단)
|
|
||||||
*
|
|
||||||
* <p>reason: - 실패/미완료 사유(로그/DB 메시지로 남기기 용도)
|
|
||||||
*/
|
|
||||||
private static final class OutputResult {
|
|
||||||
|
|
||||||
private final boolean completed;
|
|
||||||
private final String reason;
|
|
||||||
|
|
||||||
private OutputResult(boolean completed, String reason) {
|
|
||||||
this.completed = completed;
|
|
||||||
this.reason = reason;
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean completed() {
|
|
||||||
return completed;
|
|
||||||
}
|
|
||||||
|
|
||||||
String reason() {
|
|
||||||
return reason;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** paramsJson에서 valInterval 추출 */
|
|
||||||
private Optional<Integer> extractValInterval(ModelTrainJobDto job) {
|
|
||||||
Map<String, Object> params = job.getParamsJson();
|
|
||||||
if (params == null) return Optional.empty();
|
|
||||||
|
|
||||||
Object v = params.get("valInterval");
|
|
||||||
if (v == null) return Optional.empty();
|
|
||||||
|
|
||||||
try {
|
|
||||||
return Optional.of(Integer.parseInt(String.valueOf(v)));
|
|
||||||
} catch (Exception ignore) {
|
|
||||||
return Optional.empty();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,119 +59,133 @@ public class ModelTestMetricsJobService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (ResponsePathDto modelInfo : modelIds) {
|
for (ResponsePathDto modelInfo : modelIds) {
|
||||||
|
createFile(modelInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
|
/** 단건 결과 csv 파일 정보 등록 */
|
||||||
try (BufferedReader reader =
|
public void testValidMetricCsvFiles(Long modelId) {
|
||||||
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
|
|
||||||
|
|
||||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
ResponsePathDto model = modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelId(modelId);
|
||||||
|
|
||||||
List<Object[]> batchArgs = new ArrayList<>();
|
if (model == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
for (CSVRecord record : parser) {
|
createFile(model);
|
||||||
|
}
|
||||||
|
|
||||||
String model = record.get("model");
|
/**
|
||||||
long TP = Long.parseLong(record.get("TP"));
|
* 베스트 에폭 zip파일 생성, 테스트결과 db등록
|
||||||
long FP = Long.parseLong(record.get("FP"));
|
*
|
||||||
long FN = Long.parseLong(record.get("FN"));
|
* @param modelInfo
|
||||||
float precision = Float.parseFloat(record.get("precision"));
|
*/
|
||||||
float recall = Float.parseFloat(record.get("recall"));
|
private void createFile(ResponsePathDto modelInfo) {
|
||||||
float f1_score = Float.parseFloat(record.get("f1_score"));
|
|
||||||
float accuracy = Float.parseFloat(record.get("accuracy"));
|
|
||||||
float iou = Float.parseFloat(record.get("iou"));
|
|
||||||
long detection_count = Long.parseLong(record.get("detection_count"));
|
|
||||||
long gt_count = Long.parseLong(record.get("gt_count"));
|
|
||||||
|
|
||||||
batchArgs.add(
|
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
|
||||||
new Object[] {
|
try (BufferedReader reader =
|
||||||
modelInfo.getModelId(),
|
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
|
||||||
model,
|
|
||||||
TP,
|
|
||||||
FP,
|
|
||||||
FN,
|
|
||||||
precision,
|
|
||||||
recall,
|
|
||||||
f1_score,
|
|
||||||
accuracy,
|
|
||||||
iou,
|
|
||||||
detection_count,
|
|
||||||
gt_count
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
|
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||||
|
|
||||||
// test.csv 파일 읽어서 저장한 여부로만 사용하기
|
List<Object[]> batchArgs = new ArrayList<>();
|
||||||
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(
|
|
||||||
modelInfo.getModelId(), "step2");
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
for (CSVRecord record : parser) {
|
||||||
throw new RuntimeException(e);
|
|
||||||
|
String model = record.get("model");
|
||||||
|
long TP = Long.parseLong(record.get("TP"));
|
||||||
|
long FP = Long.parseLong(record.get("FP"));
|
||||||
|
long FN = Long.parseLong(record.get("FN"));
|
||||||
|
float precision = Float.parseFloat(record.get("precision"));
|
||||||
|
float recall = Float.parseFloat(record.get("recall"));
|
||||||
|
float f1_score = Float.parseFloat(record.get("f1_score"));
|
||||||
|
float accuracy = Float.parseFloat(record.get("accuracy"));
|
||||||
|
float iou = Float.parseFloat(record.get("iou"));
|
||||||
|
long detection_count = Long.parseLong(record.get("detection_count"));
|
||||||
|
long gt_count = Long.parseLong(record.get("gt_count"));
|
||||||
|
|
||||||
|
batchArgs.add(
|
||||||
|
new Object[] {
|
||||||
|
modelInfo.getModelId(),
|
||||||
|
model,
|
||||||
|
TP,
|
||||||
|
FP,
|
||||||
|
FN,
|
||||||
|
precision,
|
||||||
|
recall,
|
||||||
|
f1_score,
|
||||||
|
accuracy,
|
||||||
|
iou,
|
||||||
|
detection_count,
|
||||||
|
gt_count
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// 패키징할 파일 만들기
|
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
|
||||||
modelTestMetricsJobCoreService.updatePackingStart(
|
|
||||||
modelInfo.getModelId(), ZonedDateTime.now());
|
|
||||||
|
|
||||||
ModelMetricJsonDto jsonDto =
|
// test.csv 파일 읽어서 저장한 여부로만 사용하기
|
||||||
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
|
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
|
||||||
try {
|
|
||||||
writeJsonFile(
|
|
||||||
jsonDto,
|
|
||||||
Paths.get(
|
|
||||||
responseDir
|
|
||||||
+ "/"
|
|
||||||
+ modelInfo.getUuid()
|
|
||||||
+ "/"
|
|
||||||
+ jsonDto.getModelVersion()
|
|
||||||
+ ".json"));
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
ModelTestFileName fileInfo =
|
// 패키징할 파일 만들기
|
||||||
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
|
modelTestMetricsJobCoreService.updatePackingStart(modelInfo.getModelId(), ZonedDateTime.now());
|
||||||
|
|
||||||
Path zipPath =
|
ModelMetricJsonDto jsonDto =
|
||||||
|
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
|
||||||
|
try {
|
||||||
|
writeJsonFile(
|
||||||
|
jsonDto,
|
||||||
Paths.get(
|
Paths.get(
|
||||||
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
|
responseDir + "/" + modelInfo.getUuid() + "/" + jsonDto.getModelVersion() + ".json"));
|
||||||
Set<String> targetNames =
|
} catch (IOException e) {
|
||||||
Set.of(
|
throw new RuntimeException(e);
|
||||||
"model_config.py",
|
}
|
||||||
fileInfo.getBestEpochFileName() + ".pth",
|
|
||||||
fileInfo.getModelVersion() + ".json");
|
|
||||||
|
|
||||||
List<Path> files = new ArrayList<>();
|
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
|
||||||
try (Stream<Path> s = Files.list(responsePath)) {
|
|
||||||
files.addAll(
|
|
||||||
s.filter(Files::isRegularFile)
|
|
||||||
.filter(p -> targetNames.contains(p.getFileName().toString()))
|
|
||||||
.collect(Collectors.toList()));
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
|
ModelTestFileName fileInfo =
|
||||||
files.addAll(
|
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
|
||||||
s.filter(Files::isRegularFile)
|
|
||||||
.limit(1) // yolov8_6th-6m.pt 파일 1개만
|
|
||||||
.collect(Collectors.toList()));
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
Path zipPath =
|
||||||
zipFiles(files, zipPath);
|
Paths.get(
|
||||||
|
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
|
||||||
|
Set<String> targetNames =
|
||||||
|
Set.of(
|
||||||
|
"model_config.py",
|
||||||
|
fileInfo.getBestEpochFileName() + ".pth",
|
||||||
|
fileInfo.getModelVersion() + ".json");
|
||||||
|
|
||||||
modelTestMetricsJobCoreService.updatePackingEnd(
|
List<Path> files = new ArrayList<>();
|
||||||
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
|
try (Stream<Path> s = Files.list(responsePath)) {
|
||||||
} catch (IOException e) {
|
files.addAll(
|
||||||
modelTestMetricsJobCoreService.updatePackingEnd(
|
s.filter(Files::isRegularFile)
|
||||||
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
|
.filter(p -> targetNames.contains(p.getFileName().toString()))
|
||||||
throw new RuntimeException(e);
|
.collect(Collectors.toList()));
|
||||||
}
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
|
||||||
|
files.addAll(
|
||||||
|
s.filter(Files::isRegularFile)
|
||||||
|
.limit(1) // yolov8_6th-6m.pt 파일 1개만
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
zipFiles(files, zipPath);
|
||||||
|
|
||||||
|
modelTestMetricsJobCoreService.updatePackingEnd(
|
||||||
|
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
|
||||||
|
} catch (IOException e) {
|
||||||
|
modelTestMetricsJobCoreService.updatePackingEnd(
|
||||||
|
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
|
||||||
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,115 +48,135 @@ public class ModelTrainMetricsJobService {
|
|||||||
|
|
||||||
for (ResponsePathDto modelInfo : modelIds) {
|
for (ResponsePathDto modelInfo : modelIds) {
|
||||||
|
|
||||||
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
|
createFile(modelInfo);
|
||||||
try (BufferedReader reader =
|
|
||||||
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
|
|
||||||
|
|
||||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
|
||||||
|
|
||||||
List<Object[]> batchArgs = new ArrayList<>();
|
|
||||||
|
|
||||||
for (CSVRecord record : parser) {
|
|
||||||
|
|
||||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
|
||||||
long iteration = Long.parseLong(record.get("Iteration"));
|
|
||||||
double Loss = Double.parseDouble(record.get("Loss"));
|
|
||||||
double LR = Double.parseDouble(record.get("LR"));
|
|
||||||
float time = Float.parseFloat(record.get("Time"));
|
|
||||||
|
|
||||||
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
|
|
||||||
}
|
|
||||||
|
|
||||||
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
|
|
||||||
try (BufferedReader reader =
|
|
||||||
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
|
||||||
|
|
||||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
|
||||||
|
|
||||||
List<Object[]> batchArgs = new ArrayList<>();
|
|
||||||
|
|
||||||
for (CSVRecord record : parser) {
|
|
||||||
|
|
||||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
|
||||||
|
|
||||||
Float aAcc = parseFloatSafe(record.get("aAcc"));
|
|
||||||
Float mFscore = parseFloatSafe(record.get("mFscore"));
|
|
||||||
Float mPrecision = parseFloatSafe(record.get("mPrecision"));
|
|
||||||
Float mRecall = parseFloatSafe(record.get("mRecall"));
|
|
||||||
Float mIoU = parseFloatSafe(record.get("mIoU"));
|
|
||||||
Float mAcc = parseFloatSafe(record.get("mAcc"));
|
|
||||||
|
|
||||||
Float changed_fscore = parseFloatSafe(record.get("changed_fscore"));
|
|
||||||
Float changed_precision = parseFloatSafe(record.get("changed_precision"));
|
|
||||||
Float changed_recall = parseFloatSafe(record.get("changed_recall"));
|
|
||||||
|
|
||||||
Float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore"));
|
|
||||||
Float unchanged_precision = parseFloatSafe(record.get("unchanged_precision"));
|
|
||||||
Float unchanged_recall = parseFloatSafe(record.get("unchanged_recall"));
|
|
||||||
|
|
||||||
batchArgs.add(
|
|
||||||
new Object[] {
|
|
||||||
modelInfo.getModelId(),
|
|
||||||
epoch,
|
|
||||||
aAcc,
|
|
||||||
mFscore,
|
|
||||||
mPrecision,
|
|
||||||
mRecall,
|
|
||||||
mIoU,
|
|
||||||
mAcc,
|
|
||||||
changed_fscore,
|
|
||||||
changed_precision,
|
|
||||||
changed_recall,
|
|
||||||
unchanged_fscore,
|
|
||||||
unchanged_precision,
|
|
||||||
unchanged_recall
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
|
|
||||||
Integer epoch = null;
|
|
||||||
boolean exists;
|
|
||||||
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
|
|
||||||
|
|
||||||
try (Stream<Path> s = Files.list(responsePath)) {
|
|
||||||
epoch =
|
|
||||||
s.filter(Files::isRegularFile)
|
|
||||||
.map(
|
|
||||||
p -> {
|
|
||||||
Matcher matcher = pattern.matcher(p.getFileName().toString());
|
|
||||||
if (matcher.matches()) {
|
|
||||||
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
})
|
|
||||||
.filter(Objects::nonNull)
|
|
||||||
.findFirst()
|
|
||||||
.orElse(null);
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
|
|
||||||
modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch);
|
|
||||||
|
|
||||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
|
|
||||||
modelInfo.getModelId(), "step1");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** 단건 결과 csv 파일 정보 등록 */
|
||||||
|
public void trainValidMetricCsvFile(Long modelId) {
|
||||||
|
|
||||||
|
ResponsePathDto modelInfo =
|
||||||
|
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelId(modelId);
|
||||||
|
|
||||||
|
if (modelInfo == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
createFile(modelInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습 csv 파일 db 등록
|
||||||
|
*
|
||||||
|
* @param modelInfo
|
||||||
|
*/
|
||||||
|
private void createFile(ResponsePathDto modelInfo) {
|
||||||
|
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
|
||||||
|
try (BufferedReader reader =
|
||||||
|
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
|
||||||
|
|
||||||
|
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||||
|
|
||||||
|
List<Object[]> batchArgs = new ArrayList<>();
|
||||||
|
|
||||||
|
for (CSVRecord record : parser) {
|
||||||
|
|
||||||
|
int epoch = Integer.parseInt(record.get("Epoch"));
|
||||||
|
long iteration = Long.parseLong(record.get("Iteration"));
|
||||||
|
double Loss = Double.parseDouble(record.get("Loss"));
|
||||||
|
double LR = Double.parseDouble(record.get("LR"));
|
||||||
|
float time = Float.parseFloat(record.get("Time"));
|
||||||
|
|
||||||
|
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
|
||||||
|
}
|
||||||
|
|
||||||
|
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
|
||||||
|
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
|
||||||
|
try (BufferedReader reader =
|
||||||
|
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
||||||
|
|
||||||
|
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||||
|
|
||||||
|
List<Object[]> batchArgs = new ArrayList<>();
|
||||||
|
|
||||||
|
for (CSVRecord record : parser) {
|
||||||
|
|
||||||
|
int epoch = Integer.parseInt(record.get("Epoch"));
|
||||||
|
|
||||||
|
Float aAcc = parseFloatSafe(record.get("aAcc"));
|
||||||
|
Float mFscore = parseFloatSafe(record.get("mFscore"));
|
||||||
|
Float mPrecision = parseFloatSafe(record.get("mPrecision"));
|
||||||
|
Float mRecall = parseFloatSafe(record.get("mRecall"));
|
||||||
|
Float mIoU = parseFloatSafe(record.get("mIoU"));
|
||||||
|
Float mAcc = parseFloatSafe(record.get("mAcc"));
|
||||||
|
|
||||||
|
Float changed_fscore = parseFloatSafe(record.get("changed_fscore"));
|
||||||
|
Float changed_precision = parseFloatSafe(record.get("changed_precision"));
|
||||||
|
Float changed_recall = parseFloatSafe(record.get("changed_recall"));
|
||||||
|
|
||||||
|
Float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore"));
|
||||||
|
Float unchanged_precision = parseFloatSafe(record.get("unchanged_precision"));
|
||||||
|
Float unchanged_recall = parseFloatSafe(record.get("unchanged_recall"));
|
||||||
|
|
||||||
|
batchArgs.add(
|
||||||
|
new Object[] {
|
||||||
|
modelInfo.getModelId(),
|
||||||
|
epoch,
|
||||||
|
aAcc,
|
||||||
|
mFscore,
|
||||||
|
mPrecision,
|
||||||
|
mRecall,
|
||||||
|
mIoU,
|
||||||
|
mAcc,
|
||||||
|
changed_fscore,
|
||||||
|
changed_precision,
|
||||||
|
changed_recall,
|
||||||
|
unchanged_fscore,
|
||||||
|
unchanged_precision,
|
||||||
|
unchanged_recall
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
|
||||||
|
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
|
||||||
|
Integer epoch = null;
|
||||||
|
boolean exists;
|
||||||
|
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
|
||||||
|
|
||||||
|
try (Stream<Path> s = Files.list(responsePath)) {
|
||||||
|
epoch =
|
||||||
|
s.filter(Files::isRegularFile)
|
||||||
|
.map(
|
||||||
|
p -> {
|
||||||
|
Matcher matcher = pattern.matcher(p.getFileName().toString());
|
||||||
|
if (matcher.matches()) {
|
||||||
|
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
})
|
||||||
|
.filter(Objects::nonNull)
|
||||||
|
.findFirst()
|
||||||
|
.orElse(null);
|
||||||
|
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
|
||||||
|
modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch);
|
||||||
|
|
||||||
|
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step1");
|
||||||
|
}
|
||||||
|
|
||||||
private Float parseFloatSafe(String value) {
|
private Float parseFloatSafe(String value) {
|
||||||
try {
|
try {
|
||||||
if (value == null) return null;
|
if (value == null) return null;
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
package com.kamco.cd.training.train.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.kamco.cd.training.common.enums.JobStatusType;
|
||||||
|
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
|
import com.kamco.cd.training.train.dto.DockerInspectState;
|
||||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||||
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
|
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
|
||||||
|
import com.kamco.cd.training.train.dto.OutputResult;
|
||||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
@@ -16,6 +20,7 @@ import java.nio.file.Paths;
|
|||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.NoSuchElementException;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.log4j.Log4j2;
|
import lombok.extern.log4j.Log4j2;
|
||||||
@@ -33,11 +38,14 @@ public class TrainJobService {
|
|||||||
|
|
||||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
|
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
|
||||||
|
private final ModelTestMetricsJobService modelTestMetricsJobService;
|
||||||
private final DockerTrainService dockerTrainService;
|
private final DockerTrainService dockerTrainService;
|
||||||
private final ObjectMapper objectMapper;
|
private final ObjectMapper objectMapper;
|
||||||
private final ApplicationEventPublisher eventPublisher;
|
private final ApplicationEventPublisher eventPublisher;
|
||||||
private final TmpDatasetService tmpDatasetService;
|
private final TmpDatasetService tmpDatasetService;
|
||||||
private final DataSetCountersService dataSetCounters;
|
private final DataSetCountersService dataSetCounters;
|
||||||
|
private final TrainUtilService trainUtilService;
|
||||||
|
|
||||||
// 학습 결과가 저장될 호스트 디렉토리
|
// 학습 결과가 저장될 호스트 디렉토리
|
||||||
@Value("${train.docker.response_dir}")
|
@Value("${train.docker.response_dir}")
|
||||||
@@ -309,4 +317,133 @@ public class TrainJobService {
|
|||||||
}
|
}
|
||||||
return modelUuid;
|
return modelUuid;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 동작되고잇는 pid을 확인하고 pid 목록에 있으면 진행중, 없으면 종료(결과 csv, zip 파일 확인하여 완료 여부 체크)
|
||||||
|
*
|
||||||
|
* @param uuid
|
||||||
|
* @param modelId
|
||||||
|
*/
|
||||||
|
public void status(UUID uuid, Long modelId) {
|
||||||
|
|
||||||
|
ModelTrainJobDto job =
|
||||||
|
modelTrainJobCoreService
|
||||||
|
.findLatestByModelId(modelId)
|
||||||
|
.orElseThrow(() -> new NoSuchElementException("job not found"));
|
||||||
|
|
||||||
|
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
|
||||||
|
// TODO 실행중 상태인것만 변경해야하면 주석 해제
|
||||||
|
// if(job.getStatusCd().equals(JobStatusType.RUNNING.getId())) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
String containerName = job.getContainerName();
|
||||||
|
|
||||||
|
try {
|
||||||
|
// docker inspect로 컨테이너 상태 조회
|
||||||
|
DockerInspectState state = trainUtilService.inspectContainer(containerName);
|
||||||
|
|
||||||
|
// 컨테이너가 "없음"
|
||||||
|
// - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음
|
||||||
|
// - 즉 "컨테이너 없음"이 무조건 실패는 아님
|
||||||
|
if (!state.exists()) {
|
||||||
|
log.warn("container missing. try file-based reconcile. container={}", containerName);
|
||||||
|
|
||||||
|
// 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정"
|
||||||
|
OutputResult out = trainUtilService.probeOutputs(job);
|
||||||
|
|
||||||
|
// 산출물이 충분하면 성공 처리
|
||||||
|
if (out.completed()) {
|
||||||
|
|
||||||
|
// 테스트 완료인지 zip파일로 확인
|
||||||
|
if (trainUtilService.existsZipFile(uuid)) {
|
||||||
|
// 테스트 완료일때
|
||||||
|
log.info("outputs look completed. mark SUCCESS. jobId={}", job.getId());
|
||||||
|
|
||||||
|
// job 완료처리
|
||||||
|
modelTrainJobCoreService.markSuccess(job.getId(), 0);
|
||||||
|
|
||||||
|
// 학습 완료가 아니면 완료 업데이트
|
||||||
|
if (!model.getStep1Status().equals(TrainStatusType.COMPLETED.getId())) {
|
||||||
|
|
||||||
|
// model 상태 변경 (학습)
|
||||||
|
modelTrainMngCoreService.markStep1Success(job.getModelId());
|
||||||
|
// 학습 결과 csv 파일 정보 등록
|
||||||
|
modelTrainMetricsJobService.trainValidMetricCsvFile(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 테스트 완료가 아니면 완료 업데이트
|
||||||
|
if (!model.getStep2Status().equals(TrainStatusType.COMPLETED.getId())) {
|
||||||
|
// model 상태 변경 (테스트)
|
||||||
|
modelTrainMngCoreService.markStep2Success(job.getModelId());
|
||||||
|
// 테스트 결과 csv 파일 정보 등록
|
||||||
|
modelTestMetricsJobService.testValidMetricCsvFiles(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// 학습 완료일때
|
||||||
|
log.info("outputs look completed. mark SUCCESS. jobId={}", job.getId());
|
||||||
|
modelTrainJobCoreService.markSuccess(job.getId(), 0);
|
||||||
|
// 학습 완료가 아니면 완료 업데이트
|
||||||
|
if (!model.getStep1Status().equals(TrainStatusType.COMPLETED.getId())) {
|
||||||
|
|
||||||
|
// model 상태 변경 (학습)
|
||||||
|
modelTrainMngCoreService.markStep1Success(job.getModelId());
|
||||||
|
// 학습 결과 csv 파일 정보 등록
|
||||||
|
modelTrainMetricsJobService.trainValidMetricCsvFile(modelId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
// 산출물이 부족하면 중단처리
|
||||||
|
// 산출물이 부족하면 "중단/보류"로 처리
|
||||||
|
// 운영자가 재시작 할 수 있게 한다.
|
||||||
|
log.warn(
|
||||||
|
"outputs incomplete. mark PAUSED/STOP for restart. jobId={} reason={}",
|
||||||
|
job.getId(),
|
||||||
|
out.reason());
|
||||||
|
|
||||||
|
// PAUSED/STOP
|
||||||
|
modelTrainJobCoreService.markPaused(
|
||||||
|
job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
|
||||||
|
|
||||||
|
// STOP으로 변경
|
||||||
|
markStepStopByJobType(
|
||||||
|
job, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE: " + out.reason());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 컨테이너가 있으면 진행중처리
|
||||||
|
Map<String, Object> params = job.getParamsJson();
|
||||||
|
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
|
||||||
|
if (isEval) {
|
||||||
|
// 테스트 진행중 상태로 업데이트
|
||||||
|
modelTrainMngCoreService.markStep2InProgress(job.getModelId(), job.getId());
|
||||||
|
} else {
|
||||||
|
// 학습 진행중 상태로 업데이트
|
||||||
|
modelTrainMngCoreService.markStep1InProgress(job.getModelId(), job.getId());
|
||||||
|
}
|
||||||
|
// job 테이블 진행중으로 업데이트
|
||||||
|
modelTrainJobCoreService.updateJobStatus(job.getId(), JobStatusType.RUNNING.getId());
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("container inspect failed. container={}", containerName, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* jobType에 따라 학습 관리 테이블의 "에러 단계"를 업데이트.
|
||||||
|
*
|
||||||
|
* <p>예: - jobType == "EVAL" → step2(평가 단계) 에러 - 그 외 → step1 혹은 전체 에러
|
||||||
|
*/
|
||||||
|
private void markStepStopByJobType(ModelTrainJobDto job, String msg) {
|
||||||
|
Map<String, Object> params = job.getParamsJson();
|
||||||
|
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
|
||||||
|
if (isEval) {
|
||||||
|
modelTrainMngCoreService.markStep2Stop(job.getModelId(), msg);
|
||||||
|
} else {
|
||||||
|
modelTrainMngCoreService.markStep1Stop(job.getModelId(), msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,228 @@
|
|||||||
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
|
import com.kamco.cd.training.train.dto.DockerInspectState;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
|
import com.kamco.cd.training.train.dto.OutputResult;
|
||||||
|
import java.io.BufferedReader;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.nio.file.*;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class TrainUtilService {
|
||||||
|
|
||||||
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response
|
||||||
|
*
|
||||||
|
* <p>컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다.
|
||||||
|
*/
|
||||||
|
@Value("${train.docker.response_dir}")
|
||||||
|
private String responseDir;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* docker inspect를 사용해서 컨테이너 상태를 조회합니다.
|
||||||
|
*
|
||||||
|
* <p>사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}}
|
||||||
|
*
|
||||||
|
* <p>예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137"
|
||||||
|
*
|
||||||
|
* <p>주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초
|
||||||
|
* 타임아웃을 둠
|
||||||
|
*/
|
||||||
|
public DockerInspectState inspectContainer(String containerName)
|
||||||
|
throws IOException, InterruptedException {
|
||||||
|
|
||||||
|
ProcessBuilder pb =
|
||||||
|
new ProcessBuilder(
|
||||||
|
"docker",
|
||||||
|
"inspect",
|
||||||
|
"-f",
|
||||||
|
"{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}",
|
||||||
|
containerName);
|
||||||
|
|
||||||
|
pb.redirectErrorStream(true);
|
||||||
|
|
||||||
|
Process p = pb.start();
|
||||||
|
|
||||||
|
String output;
|
||||||
|
try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
|
||||||
|
output = br.readLine();
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean finished = p.waitFor(5, TimeUnit.SECONDS);
|
||||||
|
if (!finished) {
|
||||||
|
p.destroyForcibly();
|
||||||
|
throw new IOException("docker inspect timeout");
|
||||||
|
}
|
||||||
|
|
||||||
|
int code = p.exitValue();
|
||||||
|
|
||||||
|
if (code != 0 || output == null || output.isBlank()) {
|
||||||
|
return DockerInspectState.missing();
|
||||||
|
}
|
||||||
|
|
||||||
|
String[] parts = output.trim().split("\\s+");
|
||||||
|
|
||||||
|
String status = parts.length > 0 ? parts[0] : "unknown";
|
||||||
|
boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]);
|
||||||
|
|
||||||
|
Integer exitCode = null;
|
||||||
|
if (parts.length > 2) {
|
||||||
|
try {
|
||||||
|
exitCode = Integer.parseInt(parts[2]);
|
||||||
|
} catch (Exception ignore) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new DockerInspectState(true, running, exitCode, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다.
|
||||||
|
*
|
||||||
|
* <p>판정 규칙(보수적으로 설계): 1) total_epoch가 paramsJson에 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수
|
||||||
|
* >= total_epoch 이어야 함 3) *.pth 파일이 total_epoch 이상 존재하거나, best*.pth(또는 *best*.pth)가 존재해야 함
|
||||||
|
*
|
||||||
|
* <p>왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를
|
||||||
|
* 실패로 오판할 수 있음.
|
||||||
|
*/
|
||||||
|
public OutputResult probeOutputs(ModelTrainJobDto job) {
|
||||||
|
|
||||||
|
try {
|
||||||
|
Path outDir = resolveOutputDir(job);
|
||||||
|
if (outDir == null || !Files.isDirectory(outDir)) {
|
||||||
|
return new OutputResult(false, "output-dir-missing");
|
||||||
|
}
|
||||||
|
|
||||||
|
Integer totalEpoch = extractTotalEpoch(job).orElse(null);
|
||||||
|
if (totalEpoch == null || totalEpoch <= 0) {
|
||||||
|
return new OutputResult(false, "total-epoch-missing");
|
||||||
|
}
|
||||||
|
|
||||||
|
Integer valInterval = extractValInterval(job).orElse(null);
|
||||||
|
if (valInterval == null || valInterval <= 0) {
|
||||||
|
return new OutputResult(false, "val-interval-missing");
|
||||||
|
}
|
||||||
|
|
||||||
|
Path valCsv = outDir.resolve("val.csv");
|
||||||
|
if (!Files.exists(valCsv)) {
|
||||||
|
return new OutputResult(false, "val.csv-missing");
|
||||||
|
}
|
||||||
|
|
||||||
|
long lines = countNonHeaderLines(valCsv);
|
||||||
|
int expectedLines = totalEpoch / valInterval;
|
||||||
|
|
||||||
|
if (lines >= expectedLines) {
|
||||||
|
return new OutputResult(true, "ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
return new OutputResult(false, "val.csv-lines-mismatch");
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
return new OutputResult(false, "probe-error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 테스트 완료후 zip 파일 있는지 확인
|
||||||
|
*
|
||||||
|
* @param uuid
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public boolean existsZipFile(UUID uuid) {
|
||||||
|
Path path = Paths.get(responseDir, uuid.toString());
|
||||||
|
|
||||||
|
if (!Files.isDirectory(path)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
String pattern = "*" + uuid + "*.zip";
|
||||||
|
|
||||||
|
try (DirectoryStream<Path> stream = Files.newDirectoryStream(path, pattern)) {
|
||||||
|
return stream.iterator().hasNext();
|
||||||
|
} catch (IOException e) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* responseDir 아래에서 job 산출물 디렉토리를 찾습니다.
|
||||||
|
*
|
||||||
|
* <p>가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다.
|
||||||
|
*
|
||||||
|
* <p>현재 기본 탐색 순서: 1) {responseDir}/{jobId} 2) {responseDir}/{modelId} 3)
|
||||||
|
* {responseDir}/{containerName} 4) 마지막 fallback: responseDir 자체
|
||||||
|
*
|
||||||
|
* <p>추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId)
|
||||||
|
* 형태를 1순위로 두세요.
|
||||||
|
*/
|
||||||
|
private Path resolveOutputDir(ModelTrainJobDto job) {
|
||||||
|
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId());
|
||||||
|
|
||||||
|
Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics");
|
||||||
|
|
||||||
|
return Files.isDirectory(base) ? base : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* paramsJson에서 total_epoch 값을 추출합니다.
|
||||||
|
*
|
||||||
|
* <p>키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase)
|
||||||
|
*
|
||||||
|
* <p>예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...}
|
||||||
|
*/
|
||||||
|
private Optional<Integer> extractTotalEpoch(ModelTrainJobDto job) {
|
||||||
|
Map<String, Object> params = job.getParamsJson();
|
||||||
|
if (params == null) return Optional.empty();
|
||||||
|
|
||||||
|
Object v = params.get("total_epoch");
|
||||||
|
if (v == null) v = params.get("totalEpoch");
|
||||||
|
|
||||||
|
try {
|
||||||
|
return v == null ? Optional.empty() : Optional.of(Integer.parseInt(String.valueOf(v)));
|
||||||
|
} catch (Exception e) {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** paramsJson에서 valInterval 추출 */
|
||||||
|
private Optional<Integer> extractValInterval(ModelTrainJobDto job) {
|
||||||
|
Map<String, Object> params = job.getParamsJson();
|
||||||
|
if (params == null) return Optional.empty();
|
||||||
|
|
||||||
|
Object v = params.get("valInterval");
|
||||||
|
|
||||||
|
try {
|
||||||
|
return v == null ? Optional.empty() : Optional.of(Integer.parseInt(String.valueOf(v)));
|
||||||
|
} catch (Exception e) {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다.
|
||||||
|
*
|
||||||
|
* <p>가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록)
|
||||||
|
*
|
||||||
|
* <p>주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외
|
||||||
|
*/
|
||||||
|
private long countNonHeaderLines(Path csv) throws IOException {
|
||||||
|
try (Stream<String> lines = Files.lines(csv, StandardCharsets.UTF_8)) {
|
||||||
|
return lines.skip(1).filter(s -> s != null && !s.isBlank()).count();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user