Merge pull request '상태변경 추가' (#191) from feat/training_260324 into develop

Reviewed-on: #191
This commit was merged in pull request #191.
This commit is contained in:
2026-04-06 21:33:29 +09:00
21 changed files with 2180 additions and 525 deletions

View File

@@ -137,6 +137,6 @@ public class SecurityConfig {
/** 완전 제외(필터 자체를 안 탐) */
@Bean
public WebSecurityCustomizer webSecurityCustomizer() {
return (web) -> web.ignoring().requestMatchers("/api/mapsheet/**");
return (web) -> web.ignoring().requestMatchers("/api/mapsheet/**", "/api/file-manager/**");
}
}

View File

@@ -0,0 +1,125 @@
package com.kamco.cd.training.filemanager;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.filemanager.dto.FileManagerDto;
import com.kamco.cd.training.filemanager.service.FileManagerService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.ExampleObject;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@Slf4j
@Tag(name = "파일 관리", description = "/data 디렉토리 파일 관리 API")
@RestController
@RequestMapping("/api/file-manager")
@RequiredArgsConstructor
public class FileManagerApiController {
private final FileManagerService fileManagerService;
@Operation(
summary = "파일 목록 조회",
description = "/data 디렉토리 내 파일 및 디렉토리 목록을 조회합니다. recursive=true로 설정하면 하위 디렉토리까지 조회합니다.")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = FileManagerDto.ListFilesRes.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청 (유효하지 않은 경로)", content = @Content),
@ApiResponse(responseCode = "404", description = "디렉토리를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/files")
public ApiResponseDto<FileManagerDto.ListFilesRes> listFiles(
@Parameter(description = "조회할 디렉토리 경로 (기본값: /data)", example = "/data/request")
@RequestParam(required = false)
String directoryPath,
@Parameter(description = "하위 디렉토리 포함 여부", example = "false")
@RequestParam(required = false, defaultValue = "false")
Boolean recursive) {
FileManagerDto.ListFilesReq request =
FileManagerDto.ListFilesReq.builder()
.directoryPath(directoryPath)
.recursive(recursive)
.build();
FileManagerDto.ListFilesRes response = fileManagerService.listFiles(request);
return ApiResponseDto.ok(response);
}
@Operation(
summary = "파일/디렉토리 삭제",
description = "지정된 파일 또는 디렉토리를 삭제합니다. recursive=true로 설정하면 디렉토리 내 모든 파일을 삭제합니다.",
requestBody =
@io.swagger.v3.oas.annotations.parameters.RequestBody(
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = FileManagerDto.DeleteFileReq.class),
examples = {
@ExampleObject(
name = "단일 파일 삭제",
value =
"""
{
"filePaths": ["/data/request/old_file.zip"],
"recursive": false
}
"""),
@ExampleObject(
name = "여러 파일 삭제",
value =
"""
{
"filePaths": ["/data/file1.txt", "/data/file2.txt"],
"recursive": false
}
"""),
@ExampleObject(
name = "디렉토리 전체 삭제",
value =
"""
{
"filePaths": ["/data/old_folder"],
"recursive": true
}
""")
})))
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "삭제 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = FileManagerDto.DeleteFileRes.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청 (유효하지 않은 경로)", content = @Content),
@ApiResponse(responseCode = "404", description = "파일을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@DeleteMapping("/files")
public ApiResponseDto<FileManagerDto.DeleteFileRes> deleteFiles(
@RequestBody FileManagerDto.DeleteFileReq request) {
FileManagerDto.DeleteFileRes response = fileManagerService.deleteFiles(request);
return ApiResponseDto.ok(response);
}
}

View File

@@ -0,0 +1,134 @@
package com.kamco.cd.training.filemanager.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDateTime;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
public class FileManagerDto {
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "파일 정보")
public static class FileInfo {
@Schema(description = "파일명", example = "dataset.zip")
private String fileName;
@Schema(description = "파일 전체 경로", example = "/data/request/dataset.zip")
private String filePath;
@Schema(description = "파일 크기 (bytes)", example = "1024000")
private Long fileSize;
@Schema(description = "파일인지 디렉토리인지 여부", example = "true")
private Boolean isFile;
@Schema(description = "디렉토리인지 여부", example = "false")
private Boolean isDirectory;
@Schema(description = "마지막 수정 시간", example = "2026-04-06T15:30:00")
private LocalDateTime lastModified;
@Schema(description = "읽기 권한", example = "true")
private Boolean readable;
@Schema(description = "쓰기 권한", example = "true")
private Boolean writable;
}
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "디렉토리 목록 조회 요청")
public static class ListFilesReq {
@Schema(description = "조회할 디렉토리 경로 (기본값: /data)", example = "/data/request")
private String directoryPath;
@Schema(description = "하위 디렉토리 포함 여부", example = "false")
private Boolean recursive;
}
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "디렉토리 목록 조회 응답")
public static class ListFilesRes {
@Schema(description = "조회된 디렉토리 경로", example = "/data/request")
private String directoryPath;
@Schema(description = "파일 목록")
private List<FileInfo> files;
@Schema(description = "총 파일 개수", example = "10")
private Integer totalCount;
@Schema(description = "총 파일 크기 (bytes)", example = "10240000")
private Long totalSize;
}
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "파일 삭제 요청")
public static class DeleteFileReq {
@Schema(
description = "삭제할 파일 또는 디렉토리 경로 목록",
example = "[\"/data/request/old_file.zip\", \"/data/tmp/test_folder\"]")
private List<String> filePaths;
@Schema(description = "디렉토리일 경우 하위 파일 포함 삭제 여부", example = "true")
private Boolean recursive;
}
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "파일 삭제 응답")
public static class DeleteFileRes {
@Schema(description = "삭제 성공한 파일 경로 목록")
private List<String> deletedFiles;
@Schema(description = "삭제 실패한 파일 경로 목록")
private List<String> failedFiles;
@Schema(description = "삭제 성공 개수", example = "5")
private Integer successCount;
@Schema(description = "삭제 실패 개수", example = "0")
private Integer failureCount;
@Schema(description = "전체 메시지", example = "5개 파일 삭제 성공")
private String message;
}
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "디스크 사용량 조회 응답")
public static class DiskUsageRes {
@Schema(description = "디렉토리 경로", example = "/data")
private String directoryPath;
@Schema(description = "총 용량 (bytes)", example = "1000000000000")
private Long totalSpace;
@Schema(description = "사용 가능 용량 (bytes)", example = "500000000000")
private Long usableSpace;
@Schema(description = "사용 중인 용량 (bytes)", example = "500000000000")
private Long usedSpace;
@Schema(description = "사용률 (%)", example = "50.0")
private Double usagePercentage;
}
}

View File

@@ -0,0 +1,234 @@
package com.kamco.cd.training.filemanager.service;
import com.kamco.cd.training.filemanager.dto.FileManagerDto;
import java.io.IOException;
import java.nio.file.FileVisitResult;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.attribute.BasicFileAttributes;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Slf4j
@Service
@RequiredArgsConstructor
public class FileManagerService {
private static final String BASE_DATA_PATH = "/data";
private static final long MAX_PATH_LENGTH = 500;
/**
* 디렉토리 내 파일 목록 조회
*
* @param request 조회 요청 정보
* @return 파일 목록 응답
*/
public FileManagerDto.ListFilesRes listFiles(FileManagerDto.ListFilesReq request) {
String targetPath =
request.getDirectoryPath() != null ? request.getDirectoryPath() : BASE_DATA_PATH;
boolean recursive = request.getRecursive() != null && request.getRecursive();
validatePath(targetPath);
Path directory = Paths.get(targetPath);
if (!Files.exists(directory)) {
throw new IllegalArgumentException("디렉토리가 존재하지 않습니다: " + targetPath);
}
if (!Files.isDirectory(directory)) {
throw new IllegalArgumentException("디렉토리 경로가 아닙니다: " + targetPath);
}
List<FileManagerDto.FileInfo> files = new ArrayList<>();
long totalSize = 0;
try {
if (recursive) {
// 재귀적으로 모든 하위 파일 조회
Files.walkFileTree(
directory,
new SimpleFileVisitor<>() {
@Override
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) {
files.add(createFileInfo(file));
return FileVisitResult.CONTINUE;
}
@Override
public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) {
if (!dir.equals(directory)) {
files.add(createFileInfo(dir));
}
return FileVisitResult.CONTINUE;
}
});
} else {
// 현재 디렉토리의 파일만 조회
try (Stream<Path> stream = Files.list(directory)) {
stream.forEach(path -> files.add(createFileInfo(path)));
}
}
// 총 파일 크기 계산
for (FileManagerDto.FileInfo file : files) {
if (file.getIsFile() && file.getFileSize() != null) {
totalSize += file.getFileSize();
}
}
} catch (IOException e) {
log.error("파일 목록 조회 중 오류 발생: {}", targetPath, e);
throw new RuntimeException("파일 목록 조회에 실패했습니다: " + e.getMessage());
}
return FileManagerDto.ListFilesRes.builder()
.directoryPath(targetPath)
.files(files)
.totalCount(files.size())
.totalSize(totalSize)
.build();
}
/**
* 파일 또는 디렉토리 삭제
*
* @param request 삭제 요청 정보
* @return 삭제 결과
*/
public FileManagerDto.DeleteFileRes deleteFiles(FileManagerDto.DeleteFileReq request) {
List<String> deletedFiles = new ArrayList<>();
List<String> failedFiles = new ArrayList<>();
boolean recursive = request.getRecursive() != null && request.getRecursive();
for (String filePath : request.getFilePaths()) {
try {
validatePath(filePath);
Path path = Paths.get(filePath);
if (!Files.exists(path)) {
log.warn("삭제하려는 파일이 존재하지 않습니다: {}", filePath);
failedFiles.add(filePath + " (파일이 존재하지 않음)");
continue;
}
if (Files.isDirectory(path)) {
if (recursive) {
// 디렉토리 및 하위 파일 모두 삭제
deleteDirectoryRecursively(path);
deletedFiles.add(filePath);
} else {
// 빈 디렉토리만 삭제
if (isDirectoryEmpty(path)) {
Files.delete(path);
deletedFiles.add(filePath);
} else {
failedFiles.add(filePath + " (디렉토리가 비어있지 않음)");
}
}
} else {
// 파일 삭제
Files.delete(path);
deletedFiles.add(filePath);
}
log.info("파일 삭제 성공: {}", filePath);
} catch (Exception e) {
log.error("파일 삭제 실패: {}", filePath, e);
failedFiles.add(filePath + " (" + e.getMessage() + ")");
}
}
String message =
String.format("%d개 파일 삭제 성공, %d개 파일 삭제 실패", deletedFiles.size(), failedFiles.size());
return FileManagerDto.DeleteFileRes.builder()
.deletedFiles(deletedFiles)
.failedFiles(failedFiles)
.successCount(deletedFiles.size())
.failureCount(failedFiles.size())
.message(message)
.build();
}
/** FileInfo 객체 생성 */
private FileManagerDto.FileInfo createFileInfo(Path path) {
try {
BasicFileAttributes attrs = Files.readAttributes(path, BasicFileAttributes.class);
return FileManagerDto.FileInfo.builder()
.fileName(path.getFileName().toString())
.filePath(path.toString())
.fileSize(attrs.isRegularFile() ? attrs.size() : null)
.isFile(attrs.isRegularFile())
.isDirectory(attrs.isDirectory())
.lastModified(
LocalDateTime.ofInstant(
Instant.ofEpochMilli(attrs.lastModifiedTime().toMillis()),
ZoneId.systemDefault()))
.readable(Files.isReadable(path))
.writable(Files.isWritable(path))
.build();
} catch (IOException e) {
log.warn("파일 정보 조회 실패: {}", path, e);
return FileManagerDto.FileInfo.builder()
.fileName(path.getFileName().toString())
.filePath(path.toString())
.build();
}
}
/** 디렉토리 재귀 삭제 */
private void deleteDirectoryRecursively(Path directory) throws IOException {
Files.walkFileTree(
directory,
new SimpleFileVisitor<>() {
@Override
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs)
throws IOException {
Files.delete(file);
return FileVisitResult.CONTINUE;
}
@Override
public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
Files.delete(dir);
return FileVisitResult.CONTINUE;
}
});
}
/** 디렉토리가 비어있는지 확인 */
private boolean isDirectoryEmpty(Path directory) throws IOException {
try (Stream<Path> stream = Files.list(directory)) {
return stream.findFirst().isEmpty();
}
}
/** 경로 검증 (보안) */
private void validatePath(String path) {
if (path == null || path.trim().isEmpty()) {
throw new IllegalArgumentException("경로가 비어있습니다");
}
if (path.length() > MAX_PATH_LENGTH) {
throw new IllegalArgumentException("경로가 너무 깁니다");
}
// 경로 순회 공격 방지 - 상대경로 패턴만 제한
if (path.contains("..")) {
throw new IllegalArgumentException("상대 경로(..)는 사용할 수 없습니다");
}
}
}

View File

@@ -82,6 +82,10 @@ public class ModelTrainMngService {
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델 없음");
}
if (model.getRequestPath() == null) {
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "임시파일 경로 없음");
}
// ===== 2. 경로 생성 =====
Path tmpBase = Path.of(symbolicDir).toAbsolutePath().normalize();
Path tmp = tmpBase.resolve(model.getRequestPath()).normalize();

View File

@@ -26,6 +26,10 @@ public class ModelTestMetricsJobCoreService {
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
}
public ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId) {
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelId(modelId);
}
public void insertModelMetricsTest(List<Object[]> batchArgs) {
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
}

View File

@@ -80,6 +80,17 @@ public class ModelTrainJobCoreService {
}
}
/** 실행 시작 처리 수정 */
@Transactional
public void updateJobStatus(Long jobId, String jobStatus) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd(jobStatus);
}
/**
* 성공 처리
*

View File

@@ -17,6 +17,10 @@ public class ModelTrainMetricsJobCoreService {
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
}
public ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId) {
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelId(modelId);
}
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
modelTrainMetricsJobRepository.insertModelMetricsTrain(batchArgs);
}

View File

@@ -12,6 +12,8 @@ public interface ModelTestMetricsJobRepositoryCustom {
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId);
void insertModelMetricsTest(List<Object[]> batchArgs);
ModelMetricJsonDto getTestMetricPackingInfo(Long modelId);

View File

@@ -63,6 +63,25 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
.fetch();
}
@Override
public ResponsePathDto getTestMetricSaveNotYetModelId(Long modelId) {
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class,
modelMasterEntity.id,
modelMasterEntity.responsePath,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity.id.eq(modelId),
modelMasterEntity
.step2MetricSaveYn
.isNull()
.or(modelMasterEntity.step2MetricSaveYn.isFalse()))
.fetchOne();
}
@Override
public void insertModelMetricsTest(List<Object[]> batchArgs) {
// AS-IS

View File

@@ -7,6 +7,8 @@ public interface ModelTrainMetricsJobRepositoryCustom {
List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId);
void insertModelMetricsTrain(List<Object[]> batchArgs);
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);

View File

@@ -44,6 +44,25 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
.fetch();
}
@Override
public ResponsePathDto getTrainMetricSaveNotYetModelId(Long modelId) {
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class,
modelMasterEntity.id,
modelMasterEntity.responsePath,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity.id.eq(modelId),
modelMasterEntity
.step1MetricSaveYn
.isNull()
.or(modelMasterEntity.step1MetricSaveYn.isFalse()))
.fetchOne();
}
@Override
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
String sql =

View File

@@ -213,4 +213,27 @@ public class TrainApiController {
Long modelId = trainJobService.getModelIdByUuid(uuid);
return ApiResponseDto.ok(dataSetCountersService.getCount(modelId));
}
@Operation(summary = "학습 상태 확인", description = "학습 상태 확인")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "학습 상태 변경",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping(path = "/status/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE)
public ApiResponseDto<String> status(
@Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.status(uuid, modelId);
return ApiResponseDto.ok("ok");
}
}

View File

@@ -0,0 +1,8 @@
package com.kamco.cd.training.train.dto;
public record DockerInspectState(boolean exists, boolean running, Integer exitCode, String status) {
public static DockerInspectState missing() {
return new DockerInspectState(false, false, null, "missing");
}
}

View File

@@ -0,0 +1,20 @@
package com.kamco.cd.training.train.dto;
public class OutputResult {
private final boolean completed;
private final String reason;
public OutputResult(boolean completed, String reason) {
this.completed = completed;
this.reason = reason;
}
public boolean completed() {
return completed;
}
public String reason() {
return reason;
}
}

View File

@@ -1,25 +1,16 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.DockerInspectState;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.io.BufferedReader;
import com.kamco.cd.training.train.dto.OutputResult;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.annotation.Profile;
import org.springframework.context.event.EventListener;
@@ -44,14 +35,7 @@ public class JobRecoveryOnStartupService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
/**
* Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response
*
* <p>컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다.
*/
@Value("${train.docker.response_dir}")
private String responseDir;
private final TrainUtilService trainUtilService;
/**
* 스프링 부팅 완료 시점(빈 생성/초기화 모두 끝난 뒤)에 복구 로직 실행.
@@ -77,7 +61,7 @@ public class JobRecoveryOnStartupService {
try {
// 2-1) docker inspect로 컨테이너 상태 조회
DockerInspectState state = inspectContainer(containerName);
DockerInspectState state = trainUtilService.inspectContainer(containerName);
// 3) 컨테이너가 "없음"
// - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음
@@ -88,7 +72,7 @@ public class JobRecoveryOnStartupService {
containerName);
// 3-1) 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정"
OutputResult out = probeOutputs(job);
OutputResult out = trainUtilService.probeOutputs(job);
// 3-2) 산출물이 충분하면 성공 처리
if (out.completed()) {
@@ -109,11 +93,9 @@ public class JobRecoveryOnStartupService {
job.getId(),
out.reason());
Integer modelId = job.getModelId() == null ? null : Math.toIntExact(job.getModelId());
// PAUSED/STOP
modelTrainJobCoreService.markPaused(
job.getId(), modelId, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
// 모델도 에러가 아니라 STOP으로
markStepStopByJobType(
@@ -152,7 +134,7 @@ public class JobRecoveryOnStartupService {
// ============================================================
// 2) kill 후 실제로 죽었는지 확인
// ============================================================
DockerInspectState after = inspectContainer(containerName);
DockerInspectState after = trainUtilService.inspectContainer(containerName);
if (after.exists() && after.running()) {
throw new IOException("docker kill returned 0 but container still running");
}
@@ -162,10 +144,8 @@ public class JobRecoveryOnStartupService {
// ============================================================
// 3) job 상태를 PAUSED로 변경 (서버 재기동으로 강제 중단)
// ============================================================
Integer modelId = job.getModelId() == null ? null : Math.toIntExact(job.getModelId());
modelTrainJobCoreService.markPaused(
job.getId(), modelId, "AUTO_KILLED_ON_SERVER_RESTART");
modelTrainJobCoreService.markPaused(job.getId(), -1, "AUTO_KILLED_ON_SERVER_RESTART");
log.info("job = {}", job);
markStepStopByJobType(job, "AUTO_KILLED_ON_SERVER_RESTART");
@@ -264,301 +244,4 @@ public class JobRecoveryOnStartupService {
modelTrainMngCoreService.markError(job.getModelId(), msg);
}
}
/**
* docker inspect를 사용해서 컨테이너 상태를 조회합니다.
*
* <p>사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}}
*
* <p>예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137"
*
* <p>주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초
* 타임아웃을 둠
*/
private DockerInspectState inspectContainer(String containerName)
throws IOException, InterruptedException {
ProcessBuilder pb =
new ProcessBuilder(
"docker",
"inspect",
"-f",
"{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}",
containerName);
// stderr를 stdout으로 합쳐서 한 스트림으로 읽기(에러 메시지도 함께 받음)
pb.redirectErrorStream(true);
Process p = pb.start();
// inspect 출력은 1줄이면 충분하므로 readLine()만 수행
String output;
try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
output = br.readLine();
}
// 무한대기 방지: 5초 내에 종료되지 않으면 강제 종료
boolean finished = p.waitFor(5, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
throw new IOException("docker inspect timeout");
}
// docker inspect 자체의 프로세스 exit code
int code = p.exitValue();
// 실패(코드 !=0) 또는 출력이 없으면 "컨테이너 없음"으로 간주
if (code != 0 || output == null || output.isBlank()) {
return DockerInspectState.missing();
}
// "status running exitCode" 형태로 split
String[] parts = output.trim().split("\\s+");
// status: running/exited/dead 등
String status = parts.length > 0 ? parts[0] : "unknown";
// running: true/false
boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]);
// exitCode: 정수 파싱(파싱 실패하면 null)
Integer exitCode = null;
if (parts.length > 2) {
try {
exitCode = Integer.parseInt(parts[2]);
} catch (Exception ignore) {
// ignore
}
}
return new DockerInspectState(true, running, exitCode, status);
}
/**
* docker inspect 결과를 담는 레코드.
*
* <p>exists: - true : docker inspect 성공 (컨테이너 존재) - false : 컨테이너 없음(또는 inspect 실패를 missing으로 간주)
*/
private record DockerInspectState(
boolean exists, boolean running, Integer exitCode, String status) {
static DockerInspectState missing() {
return new DockerInspectState(false, false, null, "missing");
}
}
// ============================================================================================
// 컨테이너가 "없을 때" 파일 기반으로 완료/미완료를 판정하는 로직
// ============================================================================================
/**
* 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다.
*
* <p>판정 규칙(보수적으로 설계): 1) total_epoch가 paramsJson에 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수
* >= total_epoch 이어야 함 3) *.pth 파일이 total_epoch 이상 존재하거나, best*.pth(또는 *best*.pth)가 존재해야 함
*
* <p>왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를
* 실패로 오판할 수 있음.
*/
private OutputResult probeOutputs(ModelTrainJobDto job) {
try {
log.info(
"[RECOVERY] probeOutputs start. jobId={}, modelId={}", job.getId(), job.getModelId());
// 1) 출력 디렉토리 확인
Path outDir = resolveOutputDir(job);
if (outDir == null || !Files.isDirectory(outDir)) {
log.warn("[RECOVERY] output directory missing. jobId={}, path={}", job.getId(), outDir);
return new OutputResult(false, "output-dir-missing");
}
log.info("[RECOVERY] output directory found. jobId={}, path={}", job.getId(), outDir);
// 2) totalEpoch 확인
Integer totalEpoch = extractTotalEpoch(job).orElse(null);
if (totalEpoch == null || totalEpoch <= 0) {
log.warn(
"[RECOVERY] totalEpoch missing or invalid. jobId={}, totalEpoch={}",
job.getId(),
totalEpoch);
return new OutputResult(false, "total-epoch-missing");
}
Integer valInterval = extractValInterval(job).orElse(null);
if (valInterval == null || valInterval <= 0) {
log.warn(
"[RECOVERY] valInterval missing or invalid. jobId={}, valInterval={}",
job.getId(),
valInterval);
return new OutputResult(false, "val-interval-missing");
}
log.info(
"[RECOVERY] totalEpoch={}. valInterval={}. jobId={}",
totalEpoch,
valInterval,
job.getId());
// 3) val.csv 존재 확인
Path valCsv = outDir.resolve("val.csv");
if (!Files.exists(valCsv)) {
log.warn("[RECOVERY] val.csv missing. jobId={}, path={}", job.getId(), valCsv);
return new OutputResult(false, "val.csv-missing");
}
// 4) val.csv 라인 수 확인
long lines = countNonHeaderLines(valCsv);
// expected = 실제 val 실행 횟수
int expectedLines = totalEpoch / valInterval;
log.info(
"[RECOVERY] val.csv lines counted. jobId={}, lines={}, expected={}",
job.getId(),
lines,
expectedLines);
// 5) 완료 판정
if (lines >= expectedLines) {
log.info("[RECOVERY] outputs look COMPLETE. jobId={}", job.getId());
return new OutputResult(true, "ok");
}
log.warn(
"[RECOVERY] val.csv line mismatch. jobId={}, lines={}, expected={}",
job.getId(),
lines,
expectedLines);
return new OutputResult(
false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch);
} catch (Exception e) {
log.error("[RECOVERY] probeOutputs error. jobId={}", job.getId(), e);
return new OutputResult(false, "probe-error");
}
}
/**
* responseDir 아래에서 job 산출물 디렉토리를 찾습니다.
*
* <p>가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다.
*
* <p>현재 기본 탐색 순서: 1) {responseDir}/{jobId} 2) {responseDir}/{modelId} 3)
* {responseDir}/{containerName} 4) 마지막 fallback: responseDir 자체
*
* <p>추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId)
* 형태를 1순위로 두세요.
*/
private Path resolveOutputDir(ModelTrainJobDto job) {
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId());
Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics");
return Files.isDirectory(base) ? base : null;
}
/**
* paramsJson에서 total_epoch 값을 추출합니다.
*
* <p>키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase)
*
* <p>예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...}
*/
private Optional<Integer> extractTotalEpoch(ModelTrainJobDto job) {
Map<String, Object> params = job.getParamsJson();
if (params == null) return Optional.empty();
Object v = params.get("total_epoch");
if (v == null) v = params.get("totalEpoch");
if (v == null) return Optional.empty();
try {
return Optional.of(Integer.parseInt(String.valueOf(v)));
} catch (Exception ignore) {
return Optional.empty();
}
}
/**
* CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다.
*
* <p>가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록)
*
* <p>주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외
*/
private long countNonHeaderLines(Path csv) throws IOException {
try (Stream<String> lines = Files.lines(csv, StandardCharsets.UTF_8)) {
return lines.skip(1).filter(s -> s != null && !s.isBlank()).count();
}
}
/**
* 디렉토리에서 glob 패턴에 맞는 파일 수를 셉니다.
*
* <p>예: - "*.pth" - "best*.pth"
*/
private long countFilesByGlob(Path dir, String glob) throws IOException {
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
long cnt = 0;
for (Path p : ds) {
if (Files.isRegularFile(p)) cnt++;
}
return cnt;
}
}
/** 디렉토리에서 glob 패턴에 맞는 파일이 "하나라도" 존재하는지 체크합니다. */
private boolean existsByGlob(Path dir, String glob) throws IOException {
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
return ds.iterator().hasNext();
}
}
// ============================================================================================
// probeOutputs() 결과 객체
// ============================================================================================
/**
* 컨테이너가 없을 때(responseDir 기반) 완료 여부 판정 결과.
*
* <p>completed: - true : 산출물이 완료로 보임(성공 처리 가능) - false : 산출물이 부족/불명확(실패 또는 유예 판단)
*
* <p>reason: - 실패/미완료 사유(로그/DB 메시지로 남기기 용도)
*/
private static final class OutputResult {
private final boolean completed;
private final String reason;
private OutputResult(boolean completed, String reason) {
this.completed = completed;
this.reason = reason;
}
boolean completed() {
return completed;
}
String reason() {
return reason;
}
}
/** paramsJson에서 valInterval 추출 */
private Optional<Integer> extractValInterval(ModelTrainJobDto job) {
Map<String, Object> params = job.getParamsJson();
if (params == null) return Optional.empty();
Object v = params.get("valInterval");
if (v == null) return Optional.empty();
try {
return Optional.of(Integer.parseInt(String.valueOf(v)));
} catch (Exception ignore) {
return Optional.empty();
}
}
}

View File

@@ -59,119 +59,133 @@ public class ModelTestMetricsJobService {
}
for (ResponsePathDto modelInfo : modelIds) {
createFile(modelInfo);
}
}
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
/** 단건 결과 csv 파일 정보 등록 */
public void testValidMetricCsvFiles(Long modelId) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
ResponsePathDto model = modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelId(modelId);
List<Object[]> batchArgs = new ArrayList<>();
if (model == null) {
return;
}
for (CSVRecord record : parser) {
createFile(model);
}
String model = record.get("model");
long TP = Long.parseLong(record.get("TP"));
long FP = Long.parseLong(record.get("FP"));
long FN = Long.parseLong(record.get("FN"));
float precision = Float.parseFloat(record.get("precision"));
float recall = Float.parseFloat(record.get("recall"));
float f1_score = Float.parseFloat(record.get("f1_score"));
float accuracy = Float.parseFloat(record.get("accuracy"));
float iou = Float.parseFloat(record.get("iou"));
long detection_count = Long.parseLong(record.get("detection_count"));
long gt_count = Long.parseLong(record.get("gt_count"));
/**
* 베스트 에폭 zip파일 생성, 테스트결과 db등록
*
* @param modelInfo
*/
private void createFile(ResponsePathDto modelInfo) {
batchArgs.add(
new Object[] {
modelInfo.getModelId(),
model,
TP,
FP,
FN,
precision,
recall,
f1_score,
accuracy,
iou,
detection_count,
gt_count
});
}
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
// test.csv 파일 읽어서 저장한 여부로만 사용하기
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step2");
List<Object[]> batchArgs = new ArrayList<>();
} catch (IOException e) {
throw new RuntimeException(e);
for (CSVRecord record : parser) {
String model = record.get("model");
long TP = Long.parseLong(record.get("TP"));
long FP = Long.parseLong(record.get("FP"));
long FN = Long.parseLong(record.get("FN"));
float precision = Float.parseFloat(record.get("precision"));
float recall = Float.parseFloat(record.get("recall"));
float f1_score = Float.parseFloat(record.get("f1_score"));
float accuracy = Float.parseFloat(record.get("accuracy"));
float iou = Float.parseFloat(record.get("iou"));
long detection_count = Long.parseLong(record.get("detection_count"));
long gt_count = Long.parseLong(record.get("gt_count"));
batchArgs.add(
new Object[] {
modelInfo.getModelId(),
model,
TP,
FP,
FN,
precision,
recall,
f1_score,
accuracy,
iou,
detection_count,
gt_count
});
}
// 패키징할 파일 만들기
modelTestMetricsJobCoreService.updatePackingStart(
modelInfo.getModelId(), ZonedDateTime.now());
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
ModelMetricJsonDto jsonDto =
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
try {
writeJsonFile(
jsonDto,
Paths.get(
responseDir
+ "/"
+ modelInfo.getUuid()
+ "/"
+ jsonDto.getModelVersion()
+ ".json"));
} catch (IOException e) {
throw new RuntimeException(e);
}
// test.csv 파일 읽어서 저장한 여부로만 사용하기
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
} catch (IOException e) {
throw new RuntimeException(e);
}
ModelTestFileName fileInfo =
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
// 패키징할 파일 만들기
modelTestMetricsJobCoreService.updatePackingStart(modelInfo.getModelId(), ZonedDateTime.now());
Path zipPath =
ModelMetricJsonDto jsonDto =
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
try {
writeJsonFile(
jsonDto,
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
Set<String> targetNames =
Set.of(
"model_config.py",
fileInfo.getBestEpochFileName() + ".pth",
fileInfo.getModelVersion() + ".json");
responseDir + "/" + modelInfo.getUuid() + "/" + jsonDto.getModelVersion() + ".json"));
} catch (IOException e) {
throw new RuntimeException(e);
}
List<Path> files = new ArrayList<>();
try (Stream<Path> s = Files.list(responsePath)) {
files.addAll(
s.filter(Files::isRegularFile)
.filter(p -> targetNames.contains(p.getFileName().toString()))
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
files.addAll(
s.filter(Files::isRegularFile)
.limit(1) // yolov8_6th-6m.pt 파일 1개만
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
ModelTestFileName fileInfo =
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
try {
zipFiles(files, zipPath);
Path zipPath =
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
Set<String> targetNames =
Set.of(
"model_config.py",
fileInfo.getBestEpochFileName() + ".pth",
fileInfo.getModelVersion() + ".json");
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
} catch (IOException e) {
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
throw new RuntimeException(e);
}
List<Path> files = new ArrayList<>();
try (Stream<Path> s = Files.list(responsePath)) {
files.addAll(
s.filter(Files::isRegularFile)
.filter(p -> targetNames.contains(p.getFileName().toString()))
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
files.addAll(
s.filter(Files::isRegularFile)
.limit(1) // yolov8_6th-6m.pt 파일 1개만
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try {
zipFiles(files, zipPath);
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
} catch (IOException e) {
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
throw new RuntimeException(e);
}
}

View File

@@ -48,115 +48,135 @@ public class ModelTrainMetricsJobService {
for (ResponsePathDto modelInfo : modelIds) {
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch"));
long iteration = Long.parseLong(record.get("Iteration"));
double Loss = Double.parseDouble(record.get("Loss"));
double LR = Double.parseDouble(record.get("LR"));
float time = Float.parseFloat(record.get("Time"));
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
}
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch"));
Float aAcc = parseFloatSafe(record.get("aAcc"));
Float mFscore = parseFloatSafe(record.get("mFscore"));
Float mPrecision = parseFloatSafe(record.get("mPrecision"));
Float mRecall = parseFloatSafe(record.get("mRecall"));
Float mIoU = parseFloatSafe(record.get("mIoU"));
Float mAcc = parseFloatSafe(record.get("mAcc"));
Float changed_fscore = parseFloatSafe(record.get("changed_fscore"));
Float changed_precision = parseFloatSafe(record.get("changed_precision"));
Float changed_recall = parseFloatSafe(record.get("changed_recall"));
Float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore"));
Float unchanged_precision = parseFloatSafe(record.get("unchanged_precision"));
Float unchanged_recall = parseFloatSafe(record.get("unchanged_recall"));
batchArgs.add(
new Object[] {
modelInfo.getModelId(),
epoch,
aAcc,
mFscore,
mPrecision,
mRecall,
mIoU,
mAcc,
changed_fscore,
changed_precision,
changed_recall,
unchanged_fscore,
unchanged_precision,
unchanged_recall
});
}
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
Integer epoch = null;
boolean exists;
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
try (Stream<Path> s = Files.list(responsePath)) {
epoch =
s.filter(Files::isRegularFile)
.map(
p -> {
Matcher matcher = pattern.matcher(p.getFileName().toString());
if (matcher.matches()) {
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
}
return null;
})
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
} catch (IOException e) {
throw new RuntimeException(e);
}
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch);
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step1");
createFile(modelInfo);
}
}
/** 단건 결과 csv 파일 정보 등록 */
public void trainValidMetricCsvFile(Long modelId) {
ResponsePathDto modelInfo =
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelId(modelId);
if (modelInfo == null) {
return;
}
createFile(modelInfo);
}
/**
* 학습 csv 파일 db 등록
*
* @param modelInfo
*/
private void createFile(ResponsePathDto modelInfo) {
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch"));
long iteration = Long.parseLong(record.get("Iteration"));
double Loss = Double.parseDouble(record.get("Loss"));
double LR = Double.parseDouble(record.get("LR"));
float time = Float.parseFloat(record.get("Time"));
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
}
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch"));
Float aAcc = parseFloatSafe(record.get("aAcc"));
Float mFscore = parseFloatSafe(record.get("mFscore"));
Float mPrecision = parseFloatSafe(record.get("mPrecision"));
Float mRecall = parseFloatSafe(record.get("mRecall"));
Float mIoU = parseFloatSafe(record.get("mIoU"));
Float mAcc = parseFloatSafe(record.get("mAcc"));
Float changed_fscore = parseFloatSafe(record.get("changed_fscore"));
Float changed_precision = parseFloatSafe(record.get("changed_precision"));
Float changed_recall = parseFloatSafe(record.get("changed_recall"));
Float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore"));
Float unchanged_precision = parseFloatSafe(record.get("unchanged_precision"));
Float unchanged_recall = parseFloatSafe(record.get("unchanged_recall"));
batchArgs.add(
new Object[] {
modelInfo.getModelId(),
epoch,
aAcc,
mFscore,
mPrecision,
mRecall,
mIoU,
mAcc,
changed_fscore,
changed_precision,
changed_recall,
unchanged_fscore,
unchanged_precision,
unchanged_recall
});
}
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
Integer epoch = null;
boolean exists;
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
try (Stream<Path> s = Files.list(responsePath)) {
epoch =
s.filter(Files::isRegularFile)
.map(
p -> {
Matcher matcher = pattern.matcher(p.getFileName().toString());
if (matcher.matches()) {
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
}
return null;
})
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
} catch (IOException e) {
throw new RuntimeException(e);
}
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch);
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step1");
}
private Float parseFloatSafe(String value) {
try {
if (value == null) return null;

View File

@@ -1,13 +1,17 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.JobStatusType;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.DockerInspectState;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.OutputResult;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.io.IOException;
import java.nio.file.Files;
@@ -16,6 +20,7 @@ import java.nio.file.Paths;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
@@ -33,11 +38,14 @@ public class TrainJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
private final ModelTestMetricsJobService modelTestMetricsJobService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher;
private final TmpDatasetService tmpDatasetService;
private final DataSetCountersService dataSetCounters;
private final TrainUtilService trainUtilService;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.response_dir}")
@@ -309,4 +317,133 @@ public class TrainJobService {
}
return modelUuid;
}
/**
* 동작되고잇는 pid을 확인하고 pid 목록에 있으면 진행중, 없으면 종료(결과 csv, zip 파일 확인하여 완료 여부 체크)
*
* @param uuid
* @param modelId
*/
public void status(UUID uuid, Long modelId) {
ModelTrainJobDto job =
modelTrainJobCoreService
.findLatestByModelId(modelId)
.orElseThrow(() -> new NoSuchElementException("job not found"));
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(modelId);
// TODO 실행중 상태인것만 변경해야하면 주석 해제
// if(job.getStatusCd().equals(JobStatusType.RUNNING.getId())) {
// return;
// }
String containerName = job.getContainerName();
try {
// docker inspect로 컨테이너 상태 조회
DockerInspectState state = trainUtilService.inspectContainer(containerName);
// 컨테이너가 "없음"
// - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음
// - 즉 "컨테이너 없음"이 무조건 실패는 아님
if (!state.exists()) {
log.warn("container missing. try file-based reconcile. container={}", containerName);
// 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정"
OutputResult out = trainUtilService.probeOutputs(job);
// 산출물이 충분하면 성공 처리
if (out.completed()) {
// 테스트 완료인지 zip파일로 확인
if (trainUtilService.existsZipFile(uuid)) {
// 테스트 완료일때
log.info("outputs look completed. mark SUCCESS. jobId={}", job.getId());
// job 완료처리
modelTrainJobCoreService.markSuccess(job.getId(), 0);
// 학습 완료가 아니면 완료 업데이트
if (!model.getStep1Status().equals(TrainStatusType.COMPLETED.getId())) {
// model 상태 변경 (학습)
modelTrainMngCoreService.markStep1Success(job.getModelId());
// 학습 결과 csv 파일 정보 등록
modelTrainMetricsJobService.trainValidMetricCsvFile(modelId);
}
// 테스트 완료가 아니면 완료 업데이트
if (!model.getStep2Status().equals(TrainStatusType.COMPLETED.getId())) {
// model 상태 변경 (테스트)
modelTrainMngCoreService.markStep2Success(job.getModelId());
// 테스트 결과 csv 파일 정보 등록
modelTestMetricsJobService.testValidMetricCsvFiles(modelId);
}
} else {
// 학습 완료일때
log.info("outputs look completed. mark SUCCESS. jobId={}", job.getId());
modelTrainJobCoreService.markSuccess(job.getId(), 0);
// 학습 완료가 아니면 완료 업데이트
if (!model.getStep1Status().equals(TrainStatusType.COMPLETED.getId())) {
// model 상태 변경 (학습)
modelTrainMngCoreService.markStep1Success(job.getModelId());
// 학습 결과 csv 파일 정보 등록
modelTrainMetricsJobService.trainValidMetricCsvFile(modelId);
}
}
} else {
// 산출물이 부족하면 중단처리
// 산출물이 부족하면 "중단/보류"로 처리
// 운영자가 재시작 할 수 있게 한다.
log.warn(
"outputs incomplete. mark PAUSED/STOP for restart. jobId={} reason={}",
job.getId(),
out.reason());
// PAUSED/STOP
modelTrainJobCoreService.markPaused(
job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
// STOP으로 변경
markStepStopByJobType(
job, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE: " + out.reason());
}
} else {
// 컨테이너가 있으면 진행중처리
Map<String, Object> params = job.getParamsJson();
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
if (isEval) {
// 테스트 진행중 상태로 업데이트
modelTrainMngCoreService.markStep2InProgress(job.getModelId(), job.getId());
} else {
// 학습 진행중 상태로 업데이트
modelTrainMngCoreService.markStep1InProgress(job.getModelId(), job.getId());
}
// job 테이블 진행중으로 업데이트
modelTrainJobCoreService.updateJobStatus(job.getId(), JobStatusType.RUNNING.getId());
}
} catch (Exception e) {
log.error("container inspect failed. container={}", containerName, e);
}
}
/**
* jobType에 따라 학습 관리 테이블의 "에러 단계"를 업데이트.
*
* <p>예: - jobType == "EVAL" → step2(평가 단계) 에러 - 그 외 → step1 혹은 전체 에러
*/
private void markStepStopByJobType(ModelTrainJobDto job, String msg) {
Map<String, Object> params = job.getParamsJson();
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
if (isEval) {
modelTrainMngCoreService.markStep2Stop(job.getModelId(), msg);
} else {
modelTrainMngCoreService.markStep1Stop(job.getModelId(), msg);
}
}
}

View File

@@ -0,0 +1,228 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.DockerInspectState;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.OutputResult;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Service
@RequiredArgsConstructor
public class TrainUtilService {
private final ModelTrainMngCoreService modelTrainMngCoreService;
/**
* Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response
*
* <p>컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다.
*/
@Value("${train.docker.response_dir}")
private String responseDir;
/**
* docker inspect를 사용해서 컨테이너 상태를 조회합니다.
*
* <p>사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}}
*
* <p>예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137"
*
* <p>주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초
* 타임아웃을 둠
*/
public DockerInspectState inspectContainer(String containerName)
throws IOException, InterruptedException {
ProcessBuilder pb =
new ProcessBuilder(
"docker",
"inspect",
"-f",
"{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}",
containerName);
pb.redirectErrorStream(true);
Process p = pb.start();
String output;
try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
output = br.readLine();
}
boolean finished = p.waitFor(5, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
throw new IOException("docker inspect timeout");
}
int code = p.exitValue();
if (code != 0 || output == null || output.isBlank()) {
return DockerInspectState.missing();
}
String[] parts = output.trim().split("\\s+");
String status = parts.length > 0 ? parts[0] : "unknown";
boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]);
Integer exitCode = null;
if (parts.length > 2) {
try {
exitCode = Integer.parseInt(parts[2]);
} catch (Exception ignore) {
}
}
return new DockerInspectState(true, running, exitCode, status);
}
/**
* 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다.
*
* <p>판정 규칙(보수적으로 설계): 1) total_epoch가 paramsJson에 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수
* >= total_epoch 이어야 함 3) *.pth 파일이 total_epoch 이상 존재하거나, best*.pth(또는 *best*.pth)가 존재해야 함
*
* <p>왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를
* 실패로 오판할 수 있음.
*/
public OutputResult probeOutputs(ModelTrainJobDto job) {
try {
Path outDir = resolveOutputDir(job);
if (outDir == null || !Files.isDirectory(outDir)) {
return new OutputResult(false, "output-dir-missing");
}
Integer totalEpoch = extractTotalEpoch(job).orElse(null);
if (totalEpoch == null || totalEpoch <= 0) {
return new OutputResult(false, "total-epoch-missing");
}
Integer valInterval = extractValInterval(job).orElse(null);
if (valInterval == null || valInterval <= 0) {
return new OutputResult(false, "val-interval-missing");
}
Path valCsv = outDir.resolve("val.csv");
if (!Files.exists(valCsv)) {
return new OutputResult(false, "val.csv-missing");
}
long lines = countNonHeaderLines(valCsv);
int expectedLines = totalEpoch / valInterval;
if (lines >= expectedLines) {
return new OutputResult(true, "ok");
}
return new OutputResult(false, "val.csv-lines-mismatch");
} catch (Exception e) {
return new OutputResult(false, "probe-error");
}
}
/**
* 테스트 완료후 zip 파일 있는지 확인
*
* @param uuid
* @return
*/
public boolean existsZipFile(UUID uuid) {
Path path = Paths.get(responseDir, uuid.toString());
if (!Files.isDirectory(path)) {
return false;
}
String pattern = "*" + uuid + "*.zip";
try (DirectoryStream<Path> stream = Files.newDirectoryStream(path, pattern)) {
return stream.iterator().hasNext();
} catch (IOException e) {
return false;
}
}
/**
* responseDir 아래에서 job 산출물 디렉토리를 찾습니다.
*
* <p>가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다.
*
* <p>현재 기본 탐색 순서: 1) {responseDir}/{jobId} 2) {responseDir}/{modelId} 3)
* {responseDir}/{containerName} 4) 마지막 fallback: responseDir 자체
*
* <p>추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId)
* 형태를 1순위로 두세요.
*/
private Path resolveOutputDir(ModelTrainJobDto job) {
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId());
Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics");
return Files.isDirectory(base) ? base : null;
}
/**
* paramsJson에서 total_epoch 값을 추출합니다.
*
* <p>키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase)
*
* <p>예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...}
*/
private Optional<Integer> extractTotalEpoch(ModelTrainJobDto job) {
Map<String, Object> params = job.getParamsJson();
if (params == null) return Optional.empty();
Object v = params.get("total_epoch");
if (v == null) v = params.get("totalEpoch");
try {
return v == null ? Optional.empty() : Optional.of(Integer.parseInt(String.valueOf(v)));
} catch (Exception e) {
return Optional.empty();
}
}
/** paramsJson에서 valInterval 추출 */
private Optional<Integer> extractValInterval(ModelTrainJobDto job) {
Map<String, Object> params = job.getParamsJson();
if (params == null) return Optional.empty();
Object v = params.get("valInterval");
try {
return v == null ? Optional.empty() : Optional.of(Integer.parseInt(String.valueOf(v)));
} catch (Exception e) {
return Optional.empty();
}
}
/**
* CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다.
*
* <p>가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록)
*
* <p>주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외
*/
private long countNonHeaderLines(Path csv) throws IOException {
try (Stream<String> lines = Files.lines(csv, StandardCharsets.UTF_8)) {
return lines.skip(1).filter(s -> s != null && !s.isBlank()).count();
}
}
}