추론 실행 추가

This commit is contained in:
2026-02-11 20:21:25 +09:00
parent 35767adba1
commit 1249a80da5
21 changed files with 1049 additions and 3 deletions

View File

@@ -2,8 +2,10 @@ package com.kamco.cd.training;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
@EnableAsync
@SpringBootApplication
@EnableScheduling
public class KamcoTrainingApplication {

View File

@@ -40,6 +40,7 @@ public class ModelTrainMngDto {
private String statusCd;
private String trainType;
private String modelNo;
private Long currentAttemptId;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;

View File

@@ -12,6 +12,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.service.TrainJobService;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@@ -27,6 +28,7 @@ import org.springframework.transaction.annotation.Transactional;
@Slf4j
public class ModelTrainMngService {
private final TrainJobService trainJobService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final HyperParamCoreService hyperParamCoreService;
@@ -62,8 +64,8 @@ public class ModelTrainMngService {
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
// 전이 학습은 모델 선택 필수
if (req.getTrainType().equals(TrainType.TRANSFER.getId())) {
if (req.getBeforeModelId() != null) {
if (TrainType.TRANSFER.getId().equals(req.getTrainType())) {
if (req.getBeforeModelId() == null) {
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요.");
}
}
@@ -87,6 +89,11 @@ public class ModelTrainMngService {
// 모델 config 저장
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
// 저장 다 끝난 뒤에 job enqueue
if (Boolean.TRUE.equals(req.getIsStart())) {
trainJobService.enqueue(modelId); // job 저장 + 이벤트 발행(실행은 AFTER_COMMIT)
}
}
/**

View File

@@ -0,0 +1,101 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class ModelTrainJobCoreService {
private final ModelTrainJobRepository modelTrainJobRepository;
public int findMaxAttemptNo(Long modelId) {
return modelTrainJobRepository.findMaxAttemptNo(modelId);
}
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
return modelTrainJobRepository.findLatestByModelId(modelId);
}
public Optional<ModelTrainJobEntity> findById(Long jobId) {
return modelTrainJobRepository.findById(jobId);
}
/** QUEUED Job 생성 */
@Transactional
public Long createQueuedJob(
Long modelId, int attemptNo, Map<String, Object> paramsJson, ZonedDateTime queuedDttm) {
ModelTrainJobEntity job = new ModelTrainJobEntity();
job.setModelId(modelId);
job.setAttemptNo(attemptNo);
job.setStatusCd("QUEUED");
job.setParamsJson(paramsJson);
job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now());
modelTrainJobRepository.save(job);
return job.getId();
}
/** 실행 시작 처리 */
@Transactional
public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("RUNNING");
job.setContainerName(containerName);
job.setLogPath(logPath);
job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy);
}
/** 성공 처리 */
@Transactional
public void markSuccess(Long jobId, int exitCode) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("SUCCESS");
job.setExitCode(exitCode);
job.setFinishedDttm(ZonedDateTime.now());
}
/** 실패 처리 */
@Transactional
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("FAILED");
job.setExitCode(exitCode);
job.setErrorMessage(errorMessage);
job.setFinishedDttm(ZonedDateTime.now());
}
/** 취소 처리 */
@Transactional
public void markCanceled(Long jobId) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("CANCELED");
job.setFinishedDttm(ZonedDateTime.now());
}
}

View File

@@ -23,6 +23,7 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.UUID;
@@ -30,6 +31,7 @@ import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@@ -213,6 +215,20 @@ public class ModelTrainMngCoreService {
}
}
/**
* uuid로 model id 조회
*
* @param uuid
* @return
*/
public Long findModelIdByUuid(UUID uuid) {
ModelMasterEntity entity =
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.getId();
}
/**
* 모델학습 아이디로 config정보 조회
*
@@ -245,4 +261,101 @@ public class ModelTrainMngCoreService {
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
return datasetRepository.getDatasetSelectG2G3List(req);
}
/**
* 모델관리 조회
*
* @param id
* @return
*/
public ModelTrainMngDto.Basic findModelById(Long id) {
ModelMasterEntity entity =
modelMngRepository
.findById(id)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
return entity.toDto();
}
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
@Transactional
public void markInProgress(Long modelId, Long jobId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
master.setCurrentAttemptId(jobId);
// 필요하면 시작시간도 여기서 찍어줌
}
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
@Transactional
public void clearLastError(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setLastError(null);
}
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
@Transactional
public void markStopped(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.STOPPED.getId());
}
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
@Transactional
public void markCompleted(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/** 오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
}
@Transactional
public void markSuccess(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
// 모델 상태 완료 처리
master.setStatusCd(TrainStatusType.COMPLETED.getId());
// (선택) 마지막 에러 메시지 비우기
master.setLastError(null);
}
/**
* 학습 실행에 필요한 파라미터 조회
*
* @param modelId
* @return
*/
public TrainRunRequest findTrainRunRequest(Long modelId) {
return modelMngRepository.findTrainRunRequest(modelId);
}
}

View File

@@ -97,6 +97,12 @@ public class ModelMasterEntity {
@Column(name = "step2_metric_save_yn")
private Boolean step2MetricSaveYn;
@Column(name = "current_attempt_id")
private Long currentAttemptId;
@Column(name = "last_error")
private String lastError;
public ModelTrainMngDto.Basic toDto() {
return new ModelTrainMngDto.Basic(
this.id,
@@ -111,6 +117,7 @@ public class ModelMasterEntity {
this.step2State,
this.statusCd,
this.trainType,
this.modelNo);
this.modelNo,
this.currentAttemptId);
}
}

View File

@@ -0,0 +1,79 @@
package com.kamco.cd.training.postgres.entity;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.ZonedDateTime;
import java.util.Map;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
@Getter
@Setter
@Entity
@Table(name = "tb_model_train_job")
public class ModelTrainJobEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "id", nullable = false)
private Long id;
@NotNull
@Column(name = "model_id", nullable = false)
private Long modelId;
@NotNull
@Column(name = "attempt_no", nullable = false)
private Integer attemptNo;
@Size(max = 30)
@NotNull
@Column(name = "status_cd", nullable = false, length = 30)
private String statusCd;
@NotNull
@Column(name = "params_json", nullable = false)
@JdbcTypeCode(SqlTypes.JSON)
private Map<String, Object> paramsJson;
@Size(max = 200)
@Column(name = "container_name", length = 200)
private String containerName;
@Size(max = 500)
@Column(name = "log_path", length = 500)
private String logPath;
@Column(name = "exit_code")
private Integer exitCode;
@Size(max = 2000)
@Column(name = "error_message", length = 2000)
private String errorMessage;
@ColumnDefault("now()")
@Column(name = "queued_dttm")
private ZonedDateTime queuedDttm;
@Column(name = "started_dttm")
private ZonedDateTime startedDttm;
@Column(name = "finished_dttm")
private ZonedDateTime finishedDttm;
@Column(name = "locked_dttm")
private ZonedDateTime lockedDttm;
@Size(max = 100)
@Column(name = "locked_by", length = 100)
private String lockedBy;
}

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.util.Optional;
import java.util.UUID;
import org.springframework.data.domain.Page;
@@ -19,4 +20,6 @@ public interface ModelMngRepositoryCustom {
Optional<ModelMasterEntity> findByUuid(UUID uuid);
Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn);
TrainRunRequest findTrainRunRequest(Long modelId);
}

View File

@@ -1,10 +1,15 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import java.util.Optional;
@@ -82,4 +87,60 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
public Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn) {
return Optional.empty();
}
@Override
public TrainRunRequest findTrainRunRequest(Long modelId) {
queryFactory
.select(
Projections.constructor(
TrainRunRequest.class,
modelMasterEntity.uuid, // datasetFolder
modelMasterEntity.uuid, // outputFolder
modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize,
modelHyperParamEntity.batchSize,
modelHyperParamEntity.gpuIds,
modelHyperParamEntity.gpuCnt,
modelHyperParamEntity.learningRate,
modelHyperParamEntity.backbone,
modelHyperParamEntity.epochCnt,
modelHyperParamEntity.trainNumWorkers,
modelHyperParamEntity.valNumWorkers,
modelHyperParamEntity.testNumWorkers,
modelHyperParamEntity.trainShuffle,
modelHyperParamEntity.trainPersistent,
modelHyperParamEntity.valPersistent,
modelHyperParamEntity.dropPathRate,
modelHyperParamEntity.frozenStages,
modelHyperParamEntity.neckPolicy,
modelHyperParamEntity.classWeight,
modelHyperParamEntity.decoderChannels,
modelHyperParamEntity.weightDecay,
modelHyperParamEntity.layerDecayRate,
modelHyperParamEntity.ignoreIndex,
modelHyperParamEntity.ddpFindUnusedParams,
modelHyperParamEntity.numLayers,
modelHyperParamEntity.metrics,
modelHyperParamEntity.saveBest,
modelHyperParamEntity.saveBestRule,
modelHyperParamEntity.valInterval,
modelHyperParamEntity.logInterval,
modelHyperParamEntity.visInterval,
modelHyperParamEntity.rotProb,
modelHyperParamEntity.rotDegree,
modelHyperParamEntity.flipProb,
modelHyperParamEntity.exchangeProb,
modelHyperParamEntity.brightnessDelta,
modelHyperParamEntity.contrastRange,
modelHyperParamEntity.saturationRange,
modelHyperParamEntity.hueDelta,
Expressions.nullExpression(Integer.class)))
.from(modelMasterEntity)
.leftJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
.leftJoin(modelConfigEntity)
.on(modelConfigEntity.model.id.eq(modelMasterEntity.id))
.fetchOne();
return null;
}
}

View File

@@ -0,0 +1,7 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import org.springframework.data.jpa.repository.JpaRepository;
public interface ModelTrainJobRepository
extends JpaRepository<ModelTrainJobEntity, Long>, ModelTrainJobRepositoryCustom {}

View File

@@ -0,0 +1,12 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import java.util.Optional;
public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
}

View File

@@ -0,0 +1,34 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory;
import jakarta.persistence.EntityManager;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
private final EntityManager em;
private JPAQueryFactory queryFactory() {
return new JPAQueryFactory(em);
}
@Override
public int findMaxAttemptNo(Long modelId) {
return 0;
}
@Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
return Optional.empty();
}
@Override
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() {
return Optional.empty();
}
}

View File

@@ -0,0 +1,48 @@
package com.kamco.cd.training.train;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.train.service.TrainJobService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
@RequiredArgsConstructor
@RestController
@RequestMapping("/api/train")
public class TrainApiController {
private final TrainJobService trainJobService;
@Operation(summary = "학습 실행", description = "학습 실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "실행 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@RequestMapping("/run/{uuid}")
public ApiResponseDto<String> run(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.enqueue(modelId);
return ApiResponseDto.ok("ok");
}
}

View File

@@ -0,0 +1,15 @@
package com.kamco.cd.training.train.dto;
/** 학습 실행이 예약되었음을 알리는 이벤트 객체 */
public class ModelTrainJobQueuedEvent {
private final Long jobId;
public ModelTrainJobQueuedEvent(Long jobId) {
this.jobId = jobId;
}
public Long getJobId() {
return jobId;
}
}

View File

@@ -0,0 +1,82 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class TrainRunRequest {
// ========================
// 기본
// ========================
private String datasetFolder;
private String outputFolder;
private String inputSize;
private String cropSize;
private Integer batchSize;
private String gpuIds;
private Integer gpus;
private Double learningRate;
private String backbone;
private Integer epochs;
// ========================
// Data
// ========================
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// ========================
// Model Architecture
// ========================
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String classWeight;
private String decoderChannels;
// ========================
// Loss & Optimization
// ========================
private Double weightDecay;
private Double layerDecayRate;
private Integer ignoreIndex;
private Boolean ddpFindUnusedParams;
private Integer numLayers;
// ========================
// Evaluation
// ========================
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// ========================
// Augmentation
// ========================
private Double rotProb;
private String rotDegree;
private Double flipProb;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// ========================
// 실행 타임아웃
// ========================
private Integer timeoutSeconds;
}

View File

@@ -0,0 +1,20 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
/** 학습 실행 결과 반환 객체 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class TrainRunResult {
private String jobId;
private String containerName;
private int exitCode;
private String status;
private String logs;
}

View File

@@ -0,0 +1,230 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Service
public class DockerTrainService {
// 실행할 Docker 이미지명
@Value("${train.docker.image}")
private String image;
// 학습 요청 데이터가 위치한 호스트 디렉토리
@Value("${train.docker.requestDir}")
private String requestDir;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
// 컨테이너 이름 prefix
@Value("${train.docker.containerPrefix}")
private String containerPrefix;
// 공유메모리 사이즈 설정 (대용량 학습시 필요)
@Value("${train.docker.shmSize:16g}")
private String shmSize;
// IPC host 사용 여부
@Value("${train.docker.ipcHost:true}")
private boolean ipcHost;
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception {
// 실행 식별용 jobId 생성
String jobId = UUID.randomUUID().toString().substring(0, 8);
// 컨테이너 이름 생성 (중복 방지 목적)
String containerName = containerPrefix + "-" + jobId;
// docker run 명령어 조립
List<String> cmd = buildDockerRunCommand(containerName, req);
// 프로세스 실행
ProcessBuilder pb = new ProcessBuilder(cmd);
// stderr를 stdout으로 합쳐서 한 스트림으로 처리
pb.redirectErrorStream(true);
Process p = pb.start();
// 실행 로그 수집
StringBuilder log = new StringBuilder();
try (BufferedReader br =
new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
log.append(line).append('\n');
}
}
// 지정된 timeout 내에 종료 대기
int timeoutSeconds = 7200; // 기본 2시간
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
if (!finished) {
// 타임아웃 발생 시 컨테이너 강제 제거
killContainer(containerName);
return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString());
}
// 종료 코드 확인 (0=정상)
int exit = p.exitValue();
return new TrainRunResult(
jobId, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", log.toString());
}
/**
* docker run 명령어 리스트 구성 - 환경변수 설정 - GPU 옵션 설정 - 볼륨 마운트 - 컨테이너 내부 python 실행 명령 구성 - 요청값이
* null/blank면 해당 옵션은 "아예 생략"
*/
private List<String> buildDockerRunCommand(String containerName, TrainRunRequest req) {
List<String> c = new ArrayList<>();
c.add("docker");
c.add("run");
// 컨테이너 이름 지정
c.add("--name");
c.add(containerName);
// 실행 종료 시 자동 삭제
c.add("--rm");
// 환경변수 설정
c.add("-e");
c.add("OPENCV_LOG_LEVEL=ERROR");
c.add("-e");
c.add("NCCL_DEBUG=INFO");
c.add("-e");
c.add("NCCL_IB_DISABLE=1");
c.add("-e");
c.add("NCCL_P2P_DISABLE=0");
c.add("-e");
c.add("NCCL_SOCKET_IFNAME=eth0");
// GPU 전체 사용
c.add("--gpus");
c.add("all");
// IPC host 사용 여부
if (ipcHost) {
c.add("--ipc=host");
}
// 공유메모리 설정
c.add("--shm-size=" + shmSize);
// 메모리 관련 ulimit 설정
c.add("--ulimit");
c.add("memlock=-1");
c.add("--ulimit");
c.add("stack=67108864");
// 요청/결과 디렉토리 볼륨 마운트
c.add("-v");
c.add(requestDir + ":/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");
// 표준입력 유지 (-it 대신 -i만 사용)
c.add("-i");
// 사용할 이미지
c.add(image);
// ===== 컨테이너 내부 실행 명령 =====
c.add("python");
c.add("/workspace/change-detection-code/train_wrapper.py");
// ===== 기본 파라미터 =====
addArg(c, "--dataset-folder", req.getDatasetFolder());
addArg(c, "--output-folder", req.getOutputFolder());
addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize());
addArg(c, "--batch-size", req.getBatchSize());
addArg(c, "--gpu-ids", req.getGpuIds());
// addArg(c, "--gpus", req.getGpus());
addArg(c, "--lr", req.getLearningRate());
addArg(c, "--backbone", req.getBackbone());
addArg(c, "--epochs", req.getEpochs());
// ===== Data =====
addArg(c, "--train-num-workers", req.getTrainNumWorkers());
addArg(c, "--val-num-workers", req.getValNumWorkers());
addArg(c, "--test-num-workers", req.getTestNumWorkers());
addArg(c, "--train-shuffle", req.getTrainShuffle());
addArg(c, "--train-persistent", req.getTrainPersistent());
addArg(c, "--val-persistent", req.getValPersistent());
// ===== Model Architecture =====
addArg(c, "--drop-path-rate", req.getDropPathRate());
addArg(c, "--frozen-stages", req.getFrozenStages());
addArg(c, "--neck-policy", req.getNeckPolicy());
addArg(c, "--class-weight", req.getClassWeight());
addArg(c, "--decoder-channels", req.getDecoderChannels());
// ===== Loss & Optimization =====
addArg(c, "--weight-decay", req.getWeightDecay());
addArg(c, "--layer-decay-rate", req.getLayerDecayRate());
addArg(c, "--ignore-index", req.getIgnoreIndex());
addArg(c, "--ddp-find-unused-params", req.getDdpFindUnusedParams());
addArg(c, "--num-layers", req.getNumLayers());
// ===== Evaluation =====
addArg(c, "--metrics", req.getMetrics());
addArg(c, "--save-best", req.getSaveBest());
addArg(c, "--save-best-rule", req.getSaveBestRule());
addArg(c, "--val-interval", req.getValInterval());
addArg(c, "--log-interval", req.getLogInterval());
addArg(c, "--vis-interval", req.getVisInterval());
// ===== Augmentation =====
addArg(c, "--rot-prob", req.getRotProb());
addArg(c, "--rot-degree", req.getRotDegree());
addArg(c, "--flip-prob", req.getFlipProb());
addArg(c, "--exchange-prob", req.getExchangeProb());
addArg(c, "--brightness-delta", req.getBrightnessDelta());
addArg(c, "--contrast-range", req.getContrastRange());
addArg(c, "--saturation-range", req.getSaturationRange());
addArg(c, "--hue-delta", req.getHueDelta());
return c;
}
/** 인자 추가(키 + 값) - null / blank면 아예 추가 안 함 */
private void addArg(List<String> c, String key, Object value) {
if (value == null) return;
String s = String.valueOf(value).trim();
if (s.isEmpty()) return;
c.add(key);
c.add(s);
}
/** 컨테이너 강제 종료 및 제거 */
private void killContainer(String containerName) {
try {
new ProcessBuilder("docker", "rm", "-f", containerName)
.redirectErrorStream(true)
.start()
.waitFor(10, TimeUnit.SECONDS);
} catch (Exception ignored) {
}
}
}

View File

@@ -0,0 +1,119 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class TrainJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher;
public Long getModelIdByUuid(UUID uuid) {
return modelTrainMngCoreService.findModelIdByUuid(uuid);
}
/** 실행 예약 (QUEUE 등록) */
@Transactional
public Long enqueue(Long modelId) {
// 마스터 존재 확인(없으면 예외)
modelTrainMngCoreService.findModelById(modelId);
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
if (trainRunRequest == null) {
throw new IllegalArgumentException("Model not found: " + modelId);
}
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
@SuppressWarnings("unchecked")
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, paramsMap, ZonedDateTime.now());
modelTrainMngCoreService.clearLastError(modelId);
modelTrainMngCoreService.markInProgress(modelId, jobId);
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
/**
* 재시작 버튼
*
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
*/
@Transactional
public Long restart(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
if (TrainStatusType.IN_PROGRESS.getId().equals(master.getStatusCd())) {
throw new IllegalStateException("이미 진행중입니다.");
}
var lastJob =
modelTrainJobCoreService
.findLatestByModelId(modelId)
.orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다."));
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId,
nextAttemptNo,
lastJob.getParamsJson(), // Map<String,Object> 그대로 재사용
ZonedDateTime.now());
modelTrainMngCoreService.clearLastError(modelId);
modelTrainMngCoreService.markInProgress(modelId, jobId);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
/**
* 중단 버튼
*
* <p>- job 상태 CANCELED - master 상태 STOPPED
*
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
*/
@Transactional
public void cancel(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
Long attemptId = master.getCurrentAttemptId();
if (attemptId == null) {
throw new IllegalStateException("실행중인 작업이 없습니다.");
}
modelTrainJobCoreService.markCanceled(attemptId);
modelTrainMngCoreService.markStopped(modelId);
}
}

View File

@@ -0,0 +1,87 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
@Component
@RequiredArgsConstructor
public class TrainJobWorker {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
@Async
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
public void handle(ModelTrainJobQueuedEvent event) {
Long jobId = event.getJobId(); // record면 event.jobId()
ModelTrainJobEntity job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
Long modelId = job.getModelId();
// enqueue에서 params_json 저장해놨으니 그걸로 TrainRunRequest 복원하는게 제일 일관적
TrainRunRequest req = toTrainRunRequest(job.getParamsJson());
// req가 null이면 실패 처리
if (req == null) {
modelTrainJobCoreService.markFailed(
jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)");
modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패");
return;
}
// 컨테이너 이름은 "jobId 기반"으로 고정하는 게 cancel/restart에 유리
String containerName = "train-" + jobId; // prefix 쓰고싶으면 @Value 받아서 붙이면 됨
// logPath/lockedBy는 너 환경에 맞게
String logPath = null;
String lockedBy = "TRAIN_WORKER";
// RUNNING 표시
modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy);
try {
// DockerTrainService가 내부에서 컨테이너 이름을 랜덤으로 만들고 있어서
// markRunning에서 저장한 containerName과 실제 컨테이너명이 달라질 수 있음.
// 아래 "추천 수정" 참고.
TrainRunResult result = dockerTrainService.runTrainSync(req);
if (result.getExitCode() == 0) {
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
modelTrainMngCoreService.markSuccess(modelId); // 너 modelTrainMngCoreService에 있는 이름으로 맞춰
} else {
modelTrainJobCoreService.markFailed(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
}
} catch (Exception e) {
modelTrainJobCoreService.markFailed(jobId, null, e.toString());
modelTrainMngCoreService.markError(modelId, e.getMessage());
}
}
private TrainRunRequest toTrainRunRequest(Map<String, Object> paramsJson) {
if (paramsJson == null || paramsJson.isEmpty()) {
return null;
}
return objectMapper.convertValue(paramsJson, TrainRunRequest.class);
}
}

View File

@@ -57,3 +57,12 @@ file:
dataset-dir: /home/kcomu/data/request/
dataset-tmp-dir: ${file.dataset-dir}tmp/
train:
docker:
image: "kamco-cd-train:love_latest"
requestDir: "/home/kcomu/data/request"
responseDir: "/home/kcomu/data/response"
containerPrefix: "kamco-cd-train"
shmSize: "16g"
ipcHost: true

View File

@@ -43,3 +43,12 @@ file:
dataset-dir: /home/kcomu/data/request/
dataset-tmp-dir: ${file.dataset-dir}tmp/
train:
docker:
image: "kamco-cd-train:love_latest"
requestDir: "/home/kcomu/data/request"
responseDir: "/home/kcomu/data/response"
containerPrefix: "kamco-cd-train"
shmSize: "16g"
ipcHost: true