추론 실행 추가
This commit is contained in:
@@ -2,8 +2,10 @@ package com.kamco.cd.training;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.scheduling.annotation.EnableAsync;
|
||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
||||
|
||||
@EnableAsync
|
||||
@SpringBootApplication
|
||||
@EnableScheduling
|
||||
public class KamcoTrainingApplication {
|
||||
|
||||
@@ -40,6 +40,7 @@ public class ModelTrainMngDto {
|
||||
private String statusCd;
|
||||
private String trainType;
|
||||
private String modelNo;
|
||||
private Long currentAttemptId;
|
||||
|
||||
public String getStatusName() {
|
||||
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.service.TrainJobService;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -27,6 +28,7 @@ import org.springframework.transaction.annotation.Transactional;
|
||||
@Slf4j
|
||||
public class ModelTrainMngService {
|
||||
|
||||
private final TrainJobService trainJobService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final HyperParamCoreService hyperParamCoreService;
|
||||
|
||||
@@ -62,8 +64,8 @@ public class ModelTrainMngService {
|
||||
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
|
||||
|
||||
// 전이 학습은 모델 선택 필수
|
||||
if (req.getTrainType().equals(TrainType.TRANSFER.getId())) {
|
||||
if (req.getBeforeModelId() != null) {
|
||||
if (TrainType.TRANSFER.getId().equals(req.getTrainType())) {
|
||||
if (req.getBeforeModelId() == null) {
|
||||
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요.");
|
||||
}
|
||||
}
|
||||
@@ -87,6 +89,11 @@ public class ModelTrainMngService {
|
||||
|
||||
// 모델 config 저장
|
||||
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
||||
|
||||
// 저장 다 끝난 뒤에 job enqueue
|
||||
if (Boolean.TRUE.equals(req.getIsStart())) {
|
||||
trainJobService.enqueue(modelId); // job 저장 + 이벤트 발행(실행은 AFTER_COMMIT)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional(readOnly = true)
|
||||
public class ModelTrainJobCoreService {
|
||||
|
||||
private final ModelTrainJobRepository modelTrainJobRepository;
|
||||
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
return modelTrainJobRepository.findMaxAttemptNo(modelId);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
return modelTrainJobRepository.findLatestByModelId(modelId);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobEntity> findById(Long jobId) {
|
||||
return modelTrainJobRepository.findById(jobId);
|
||||
}
|
||||
|
||||
/** QUEUED Job 생성 */
|
||||
@Transactional
|
||||
public Long createQueuedJob(
|
||||
Long modelId, int attemptNo, Map<String, Object> paramsJson, ZonedDateTime queuedDttm) {
|
||||
|
||||
ModelTrainJobEntity job = new ModelTrainJobEntity();
|
||||
job.setModelId(modelId);
|
||||
job.setAttemptNo(attemptNo);
|
||||
job.setStatusCd("QUEUED");
|
||||
job.setParamsJson(paramsJson);
|
||||
job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now());
|
||||
|
||||
modelTrainJobRepository.save(job);
|
||||
return job.getId();
|
||||
}
|
||||
|
||||
/** 실행 시작 처리 */
|
||||
@Transactional
|
||||
public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("RUNNING");
|
||||
job.setContainerName(containerName);
|
||||
job.setLogPath(logPath);
|
||||
job.setStartedDttm(ZonedDateTime.now());
|
||||
job.setLockedDttm(ZonedDateTime.now());
|
||||
job.setLockedBy(lockedBy);
|
||||
}
|
||||
|
||||
/** 성공 처리 */
|
||||
@Transactional
|
||||
public void markSuccess(Long jobId, int exitCode) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("SUCCESS");
|
||||
job.setExitCode(exitCode);
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** 실패 처리 */
|
||||
@Transactional
|
||||
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("FAILED");
|
||||
job.setExitCode(exitCode);
|
||||
job.setErrorMessage(errorMessage);
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** 취소 처리 */
|
||||
@Transactional
|
||||
public void markCanceled(Long jobId) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("CANCELED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
@@ -30,6 +31,7 @@ import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@@ -213,6 +215,20 @@ public class ModelTrainMngCoreService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* uuid로 model id 조회
|
||||
*
|
||||
* @param uuid
|
||||
* @return
|
||||
*/
|
||||
public Long findModelIdByUuid(UUID uuid) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.getId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델학습 아이디로 config정보 조회
|
||||
*
|
||||
@@ -245,4 +261,101 @@ public class ModelTrainMngCoreService {
|
||||
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
|
||||
return datasetRepository.getDatasetSelectG2G3List(req);
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델관리 조회
|
||||
*
|
||||
* @param id
|
||||
* @return
|
||||
*/
|
||||
public ModelTrainMngDto.Basic findModelById(Long id) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(id)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
|
||||
@Transactional
|
||||
public void markInProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
master.setCurrentAttemptId(jobId);
|
||||
|
||||
// 필요하면 시작시간도 여기서 찍어줌
|
||||
}
|
||||
|
||||
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
|
||||
@Transactional
|
||||
public void clearLastError(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setLastError(null);
|
||||
}
|
||||
|
||||
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
|
||||
@Transactional
|
||||
public void markStopped(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||
}
|
||||
|
||||
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
|
||||
@Transactional
|
||||
public void markCompleted(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/** 오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setLastError(errorMessage);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void markSuccess(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
// 모델 상태 완료 처리
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
|
||||
// (선택) 마지막 에러 메시지 비우기
|
||||
master.setLastError(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 실행에 필요한 파라미터 조회
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||
return modelMngRepository.findTrainRunRequest(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +97,12 @@ public class ModelMasterEntity {
|
||||
@Column(name = "step2_metric_save_yn")
|
||||
private Boolean step2MetricSaveYn;
|
||||
|
||||
@Column(name = "current_attempt_id")
|
||||
private Long currentAttemptId;
|
||||
|
||||
@Column(name = "last_error")
|
||||
private String lastError;
|
||||
|
||||
public ModelTrainMngDto.Basic toDto() {
|
||||
return new ModelTrainMngDto.Basic(
|
||||
this.id,
|
||||
@@ -111,6 +117,7 @@ public class ModelMasterEntity {
|
||||
this.step2State,
|
||||
this.statusCd,
|
||||
this.trainType,
|
||||
this.modelNo);
|
||||
this.modelNo,
|
||||
this.currentAttemptId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
import jakarta.persistence.GenerationType;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Table;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.hibernate.annotations.ColumnDefault;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@Entity
|
||||
@Table(name = "tb_model_train_job")
|
||||
public class ModelTrainJobEntity {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@Column(name = "id", nullable = false)
|
||||
private Long id;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "model_id", nullable = false)
|
||||
private Long modelId;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "attempt_no", nullable = false)
|
||||
private Integer attemptNo;
|
||||
|
||||
@Size(max = 30)
|
||||
@NotNull
|
||||
@Column(name = "status_cd", nullable = false, length = 30)
|
||||
private String statusCd;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "params_json", nullable = false)
|
||||
@JdbcTypeCode(SqlTypes.JSON)
|
||||
private Map<String, Object> paramsJson;
|
||||
|
||||
@Size(max = 200)
|
||||
@Column(name = "container_name", length = 200)
|
||||
private String containerName;
|
||||
|
||||
@Size(max = 500)
|
||||
@Column(name = "log_path", length = 500)
|
||||
private String logPath;
|
||||
|
||||
@Column(name = "exit_code")
|
||||
private Integer exitCode;
|
||||
|
||||
@Size(max = 2000)
|
||||
@Column(name = "error_message", length = 2000)
|
||||
private String errorMessage;
|
||||
|
||||
@ColumnDefault("now()")
|
||||
@Column(name = "queued_dttm")
|
||||
private ZonedDateTime queuedDttm;
|
||||
|
||||
@Column(name = "started_dttm")
|
||||
private ZonedDateTime startedDttm;
|
||||
|
||||
@Column(name = "finished_dttm")
|
||||
private ZonedDateTime finishedDttm;
|
||||
|
||||
@Column(name = "locked_dttm")
|
||||
private ZonedDateTime lockedDttm;
|
||||
|
||||
@Size(max = 100)
|
||||
@Column(name = "locked_by", length = 100)
|
||||
private String lockedBy;
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.springframework.data.domain.Page;
|
||||
@@ -19,4 +20,6 @@ public interface ModelMngRepositoryCustom {
|
||||
Optional<ModelMasterEntity> findByUuid(UUID uuid);
|
||||
|
||||
Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn);
|
||||
|
||||
TrainRunRequest findTrainRunRequest(Long modelId);
|
||||
}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.querydsl.core.BooleanBuilder;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.core.types.dsl.Expressions;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -82,4 +87,60 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
public Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||
queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
TrainRunRequest.class,
|
||||
modelMasterEntity.uuid, // datasetFolder
|
||||
modelMasterEntity.uuid, // outputFolder
|
||||
modelHyperParamEntity.inputSize,
|
||||
modelHyperParamEntity.cropSize,
|
||||
modelHyperParamEntity.batchSize,
|
||||
modelHyperParamEntity.gpuIds,
|
||||
modelHyperParamEntity.gpuCnt,
|
||||
modelHyperParamEntity.learningRate,
|
||||
modelHyperParamEntity.backbone,
|
||||
modelHyperParamEntity.epochCnt,
|
||||
modelHyperParamEntity.trainNumWorkers,
|
||||
modelHyperParamEntity.valNumWorkers,
|
||||
modelHyperParamEntity.testNumWorkers,
|
||||
modelHyperParamEntity.trainShuffle,
|
||||
modelHyperParamEntity.trainPersistent,
|
||||
modelHyperParamEntity.valPersistent,
|
||||
modelHyperParamEntity.dropPathRate,
|
||||
modelHyperParamEntity.frozenStages,
|
||||
modelHyperParamEntity.neckPolicy,
|
||||
modelHyperParamEntity.classWeight,
|
||||
modelHyperParamEntity.decoderChannels,
|
||||
modelHyperParamEntity.weightDecay,
|
||||
modelHyperParamEntity.layerDecayRate,
|
||||
modelHyperParamEntity.ignoreIndex,
|
||||
modelHyperParamEntity.ddpFindUnusedParams,
|
||||
modelHyperParamEntity.numLayers,
|
||||
modelHyperParamEntity.metrics,
|
||||
modelHyperParamEntity.saveBest,
|
||||
modelHyperParamEntity.saveBestRule,
|
||||
modelHyperParamEntity.valInterval,
|
||||
modelHyperParamEntity.logInterval,
|
||||
modelHyperParamEntity.visInterval,
|
||||
modelHyperParamEntity.rotProb,
|
||||
modelHyperParamEntity.rotDegree,
|
||||
modelHyperParamEntity.flipProb,
|
||||
modelHyperParamEntity.exchangeProb,
|
||||
modelHyperParamEntity.brightnessDelta,
|
||||
modelHyperParamEntity.contrastRange,
|
||||
modelHyperParamEntity.saturationRange,
|
||||
modelHyperParamEntity.hueDelta,
|
||||
Expressions.nullExpression(Integer.class)))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelHyperParamEntity)
|
||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||
.leftJoin(modelConfigEntity)
|
||||
.on(modelConfigEntity.model.id.eq(modelMasterEntity.id))
|
||||
.fetchOne();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
public interface ModelTrainJobRepository
|
||||
extends JpaRepository<ModelTrainJobEntity, Long>, ModelTrainJobRepositoryCustom {}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import java.util.Optional;
|
||||
|
||||
public interface ModelTrainJobRepositoryCustom {
|
||||
int findMaxAttemptNo(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import jakarta.persistence.EntityManager;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
|
||||
|
||||
private final EntityManager em;
|
||||
|
||||
private JPAQueryFactory queryFactory() {
|
||||
return new JPAQueryFactory(em);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.kamco.cd.training.train;
|
||||
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.train.service.TrainJobService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.media.Content;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
|
||||
@RequiredArgsConstructor
|
||||
@RestController
|
||||
@RequestMapping("/api/train")
|
||||
public class TrainApiController {
|
||||
|
||||
private final TrainJobService trainJobService;
|
||||
|
||||
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "실행 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@RequestMapping("/run/{uuid}")
|
||||
public ApiResponseDto<String> run(
|
||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
trainJobService.enqueue(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
/** Event object announcing that a training run has been queued. */
public class ModelTrainJobQueuedEvent {

  /** Id of the queued training job. */
  private final Long jobId;

  /**
   * Creates the event for one queued job.
   *
   * @param jobId id of the queued training job
   */
  public ModelTrainJobQueuedEvent(Long jobId) {
    this.jobId = jobId;
  }

  /**
   * Returns the id this event was created with.
   *
   * @return id of the queued training job
   */
  public Long getJobId() {
    return this.jobId;
  }
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class TrainRunRequest {
|
||||
|
||||
// ========================
|
||||
// 기본
|
||||
// ========================
|
||||
private String datasetFolder;
|
||||
private String outputFolder;
|
||||
private String inputSize;
|
||||
private String cropSize;
|
||||
private Integer batchSize;
|
||||
private String gpuIds;
|
||||
private Integer gpus;
|
||||
private Double learningRate;
|
||||
private String backbone;
|
||||
private Integer epochs;
|
||||
|
||||
// ========================
|
||||
// Data
|
||||
// ========================
|
||||
private Integer trainNumWorkers;
|
||||
private Integer valNumWorkers;
|
||||
private Integer testNumWorkers;
|
||||
private Boolean trainShuffle;
|
||||
private Boolean trainPersistent;
|
||||
private Boolean valPersistent;
|
||||
|
||||
// ========================
|
||||
// Model Architecture
|
||||
// ========================
|
||||
private Double dropPathRate;
|
||||
private Integer frozenStages;
|
||||
private String neckPolicy;
|
||||
private String classWeight;
|
||||
private String decoderChannels;
|
||||
|
||||
// ========================
|
||||
// Loss & Optimization
|
||||
// ========================
|
||||
private Double weightDecay;
|
||||
private Double layerDecayRate;
|
||||
private Integer ignoreIndex;
|
||||
private Boolean ddpFindUnusedParams;
|
||||
private Integer numLayers;
|
||||
|
||||
// ========================
|
||||
// Evaluation
|
||||
// ========================
|
||||
private String metrics;
|
||||
private String saveBest;
|
||||
private String saveBestRule;
|
||||
private Integer valInterval;
|
||||
private Integer logInterval;
|
||||
private Integer visInterval;
|
||||
|
||||
// ========================
|
||||
// Augmentation
|
||||
// ========================
|
||||
private Double rotProb;
|
||||
private String rotDegree;
|
||||
private Double flipProb;
|
||||
private Double exchangeProb;
|
||||
private Integer brightnessDelta;
|
||||
private String contrastRange;
|
||||
private String saturationRange;
|
||||
private Integer hueDelta;
|
||||
|
||||
// ========================
|
||||
// 실행 타임아웃
|
||||
// ========================
|
||||
private Integer timeoutSeconds;
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
/** 학습 실행 결과 반환 객체 */
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class TrainRunResult {
|
||||
|
||||
private String jobId;
|
||||
private String containerName;
|
||||
private int exitCode;
|
||||
private String status;
|
||||
private String logs;
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DockerTrainService {
|
||||
|
||||
// 실행할 Docker 이미지명
|
||||
@Value("${train.docker.image}")
|
||||
private String image;
|
||||
|
||||
// 학습 요청 데이터가 위치한 호스트 디렉토리
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
// 컨테이너 이름 prefix
|
||||
@Value("${train.docker.containerPrefix}")
|
||||
private String containerPrefix;
|
||||
|
||||
// 공유메모리 사이즈 설정 (대용량 학습시 필요)
|
||||
@Value("${train.docker.shmSize:16g}")
|
||||
private String shmSize;
|
||||
|
||||
// IPC host 사용 여부
|
||||
@Value("${train.docker.ipcHost:true}")
|
||||
private boolean ipcHost;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception {
|
||||
|
||||
// 실행 식별용 jobId 생성
|
||||
String jobId = UUID.randomUUID().toString().substring(0, 8);
|
||||
|
||||
// 컨테이너 이름 생성 (중복 방지 목적)
|
||||
String containerName = containerPrefix + "-" + jobId;
|
||||
|
||||
// docker run 명령어 조립
|
||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||
|
||||
// 프로세스 실행
|
||||
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||
|
||||
// stderr를 stdout으로 합쳐서 한 스트림으로 처리
|
||||
pb.redirectErrorStream(true);
|
||||
|
||||
Process p = pb.start();
|
||||
|
||||
// 실행 로그 수집
|
||||
StringBuilder log = new StringBuilder();
|
||||
|
||||
try (BufferedReader br =
|
||||
new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
log.append(line).append('\n');
|
||||
}
|
||||
}
|
||||
|
||||
// 지정된 timeout 내에 종료 대기
|
||||
int timeoutSeconds = 7200; // 기본 2시간
|
||||
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
|
||||
|
||||
if (!finished) {
|
||||
// 타임아웃 발생 시 컨테이너 강제 제거
|
||||
killContainer(containerName);
|
||||
|
||||
return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString());
|
||||
}
|
||||
|
||||
// 종료 코드 확인 (0=정상)
|
||||
int exit = p.exitValue();
|
||||
|
||||
return new TrainRunResult(
|
||||
jobId, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", log.toString());
|
||||
}
|
||||
|
||||
/**
|
||||
* docker run 명령어 리스트 구성 - 환경변수 설정 - GPU 옵션 설정 - 볼륨 마운트 - 컨테이너 내부 python 실행 명령 구성 - 요청값이
|
||||
* null/blank면 해당 옵션은 "아예 생략"
|
||||
*/
|
||||
private List<String> buildDockerRunCommand(String containerName, TrainRunRequest req) {
|
||||
|
||||
List<String> c = new ArrayList<>();
|
||||
|
||||
c.add("docker");
|
||||
c.add("run");
|
||||
|
||||
// 컨테이너 이름 지정
|
||||
c.add("--name");
|
||||
c.add(containerName);
|
||||
|
||||
// 실행 종료 시 자동 삭제
|
||||
c.add("--rm");
|
||||
|
||||
// 환경변수 설정
|
||||
c.add("-e");
|
||||
c.add("OPENCV_LOG_LEVEL=ERROR");
|
||||
c.add("-e");
|
||||
c.add("NCCL_DEBUG=INFO");
|
||||
c.add("-e");
|
||||
c.add("NCCL_IB_DISABLE=1");
|
||||
c.add("-e");
|
||||
c.add("NCCL_P2P_DISABLE=0");
|
||||
c.add("-e");
|
||||
c.add("NCCL_SOCKET_IFNAME=eth0");
|
||||
|
||||
// GPU 전체 사용
|
||||
c.add("--gpus");
|
||||
c.add("all");
|
||||
|
||||
// IPC host 사용 여부
|
||||
if (ipcHost) {
|
||||
c.add("--ipc=host");
|
||||
}
|
||||
|
||||
// 공유메모리 설정
|
||||
c.add("--shm-size=" + shmSize);
|
||||
|
||||
// 메모리 관련 ulimit 설정
|
||||
c.add("--ulimit");
|
||||
c.add("memlock=-1");
|
||||
c.add("--ulimit");
|
||||
c.add("stack=67108864");
|
||||
|
||||
// 요청/결과 디렉토리 볼륨 마운트
|
||||
c.add("-v");
|
||||
c.add(requestDir + ":/data");
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
// 표준입력 유지 (-it 대신 -i만 사용)
|
||||
c.add("-i");
|
||||
|
||||
// 사용할 이미지
|
||||
c.add(image);
|
||||
|
||||
// ===== 컨테이너 내부 실행 명령 =====
|
||||
c.add("python");
|
||||
c.add("/workspace/change-detection-code/train_wrapper.py");
|
||||
|
||||
// ===== 기본 파라미터 =====
|
||||
addArg(c, "--dataset-folder", req.getDatasetFolder());
|
||||
addArg(c, "--output-folder", req.getOutputFolder());
|
||||
addArg(c, "--input-size", req.getInputSize());
|
||||
addArg(c, "--crop-size", req.getCropSize());
|
||||
addArg(c, "--batch-size", req.getBatchSize());
|
||||
addArg(c, "--gpu-ids", req.getGpuIds());
|
||||
// addArg(c, "--gpus", req.getGpus());
|
||||
addArg(c, "--lr", req.getLearningRate());
|
||||
addArg(c, "--backbone", req.getBackbone());
|
||||
addArg(c, "--epochs", req.getEpochs());
|
||||
|
||||
// ===== Data =====
|
||||
addArg(c, "--train-num-workers", req.getTrainNumWorkers());
|
||||
addArg(c, "--val-num-workers", req.getValNumWorkers());
|
||||
addArg(c, "--test-num-workers", req.getTestNumWorkers());
|
||||
addArg(c, "--train-shuffle", req.getTrainShuffle());
|
||||
addArg(c, "--train-persistent", req.getTrainPersistent());
|
||||
addArg(c, "--val-persistent", req.getValPersistent());
|
||||
|
||||
// ===== Model Architecture =====
|
||||
addArg(c, "--drop-path-rate", req.getDropPathRate());
|
||||
addArg(c, "--frozen-stages", req.getFrozenStages());
|
||||
addArg(c, "--neck-policy", req.getNeckPolicy());
|
||||
addArg(c, "--class-weight", req.getClassWeight());
|
||||
addArg(c, "--decoder-channels", req.getDecoderChannels());
|
||||
|
||||
// ===== Loss & Optimization =====
|
||||
addArg(c, "--weight-decay", req.getWeightDecay());
|
||||
addArg(c, "--layer-decay-rate", req.getLayerDecayRate());
|
||||
addArg(c, "--ignore-index", req.getIgnoreIndex());
|
||||
addArg(c, "--ddp-find-unused-params", req.getDdpFindUnusedParams());
|
||||
addArg(c, "--num-layers", req.getNumLayers());
|
||||
|
||||
// ===== Evaluation =====
|
||||
addArg(c, "--metrics", req.getMetrics());
|
||||
addArg(c, "--save-best", req.getSaveBest());
|
||||
addArg(c, "--save-best-rule", req.getSaveBestRule());
|
||||
addArg(c, "--val-interval", req.getValInterval());
|
||||
addArg(c, "--log-interval", req.getLogInterval());
|
||||
addArg(c, "--vis-interval", req.getVisInterval());
|
||||
|
||||
// ===== Augmentation =====
|
||||
addArg(c, "--rot-prob", req.getRotProb());
|
||||
addArg(c, "--rot-degree", req.getRotDegree());
|
||||
addArg(c, "--flip-prob", req.getFlipProb());
|
||||
addArg(c, "--exchange-prob", req.getExchangeProb());
|
||||
addArg(c, "--brightness-delta", req.getBrightnessDelta());
|
||||
addArg(c, "--contrast-range", req.getContrastRange());
|
||||
addArg(c, "--saturation-range", req.getSaturationRange());
|
||||
addArg(c, "--hue-delta", req.getHueDelta());
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
/** 인자 추가(키 + 값) - null / blank면 아예 추가 안 함 */
|
||||
private void addArg(List<String> c, String key, Object value) {
|
||||
if (value == null) return;
|
||||
String s = String.valueOf(value).trim();
|
||||
if (s.isEmpty()) return;
|
||||
c.add(key);
|
||||
c.add(s);
|
||||
}
|
||||
|
||||
/** 컨테이너 강제 종료 및 제거 */
|
||||
private void killContainer(String containerName) {
|
||||
try {
|
||||
new ProcessBuilder("docker", "rm", "-f", containerName)
|
||||
.redirectErrorStream(true)
|
||||
.start()
|
||||
.waitFor(10, TimeUnit.SECONDS);
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional(readOnly = true)
|
||||
public class TrainJobService {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
|
||||
public Long getModelIdByUuid(UUID uuid) {
|
||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||
}
|
||||
|
||||
/** 실행 예약 (QUEUE 등록) */
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId) {
|
||||
|
||||
// 마스터 존재 확인(없으면 예외)
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
|
||||
|
||||
if (trainRunRequest == null) {
|
||||
throw new IllegalArgumentException("Model not found: " + modelId);
|
||||
}
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId, nextAttemptNo, paramsMap, ZonedDateTime.now());
|
||||
|
||||
modelTrainMngCoreService.clearLastError(modelId);
|
||||
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
||||
|
||||
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 재시작 버튼
|
||||
*
|
||||
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
|
||||
*/
|
||||
@Transactional
|
||||
public Long restart(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
if (TrainStatusType.IN_PROGRESS.getId().equals(master.getStatusCd())) {
|
||||
throw new IllegalStateException("이미 진행중입니다.");
|
||||
}
|
||||
|
||||
var lastJob =
|
||||
modelTrainJobCoreService
|
||||
.findLatestByModelId(modelId)
|
||||
.orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다."));
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId,
|
||||
nextAttemptNo,
|
||||
lastJob.getParamsJson(), // Map<String,Object> 그대로 재사용
|
||||
ZonedDateTime.now());
|
||||
|
||||
modelTrainMngCoreService.clearLastError(modelId);
|
||||
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
||||
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 중단 버튼
|
||||
*
|
||||
* <p>- job 상태 CANCELED - master 상태 STOPPED
|
||||
*
|
||||
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
|
||||
*/
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
Long attemptId = master.getCurrentAttemptId();
|
||||
if (attemptId == null) {
|
||||
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||
}
|
||||
|
||||
modelTrainJobCoreService.markCanceled(attemptId);
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
import java.util.Map;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.transaction.event.TransactionPhase;
|
||||
import org.springframework.transaction.event.TransactionalEventListener;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class TrainJobWorker {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Async
|
||||
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
||||
public void handle(ModelTrainJobQueuedEvent event) {
|
||||
|
||||
Long jobId = event.getJobId(); // record면 event.jobId()
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
Long modelId = job.getModelId();
|
||||
|
||||
// enqueue에서 params_json 저장해놨으니 그걸로 TrainRunRequest 복원하는게 제일 일관적
|
||||
TrainRunRequest req = toTrainRunRequest(job.getParamsJson());
|
||||
// req가 null이면 실패 처리
|
||||
if (req == null) {
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)");
|
||||
modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패");
|
||||
return;
|
||||
}
|
||||
|
||||
// 컨테이너 이름은 "jobId 기반"으로 고정하는 게 cancel/restart에 유리
|
||||
String containerName = "train-" + jobId; // prefix 쓰고싶으면 @Value 받아서 붙이면 됨
|
||||
|
||||
// logPath/lockedBy는 너 환경에 맞게
|
||||
String logPath = null;
|
||||
String lockedBy = "TRAIN_WORKER";
|
||||
|
||||
// RUNNING 표시
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy);
|
||||
|
||||
try {
|
||||
// DockerTrainService가 내부에서 컨테이너 이름을 랜덤으로 만들고 있어서
|
||||
// markRunning에서 저장한 containerName과 실제 컨테이너명이 달라질 수 있음.
|
||||
// 아래 "추천 수정" 참고.
|
||||
TrainRunResult result = dockerTrainService.runTrainSync(req);
|
||||
|
||||
if (result.getExitCode() == 0) {
|
||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||
modelTrainMngCoreService.markSuccess(modelId); // 너 modelTrainMngCoreService에 있는 이름으로 맞춰
|
||||
} else {
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
modelTrainJobCoreService.markFailed(jobId, null, e.toString());
|
||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private TrainRunRequest toTrainRunRequest(Map<String, Object> paramsJson) {
|
||||
if (paramsJson == null || paramsJson.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return objectMapper.convertValue(paramsJson, TrainRunRequest.class);
|
||||
}
|
||||
}
|
||||
@@ -57,3 +57,12 @@ file:
|
||||
|
||||
dataset-dir: /home/kcomu/data/request/
|
||||
dataset-tmp-dir: ${file.dataset-dir}tmp/
|
||||
|
||||
train:
|
||||
docker:
|
||||
image: "kamco-cd-train:love_latest"
|
||||
requestDir: "/home/kcomu/data/request"
|
||||
responseDir: "/home/kcomu/data/response"
|
||||
containerPrefix: "kamco-cd-train"
|
||||
shmSize: "16g"
|
||||
ipcHost: true
|
||||
|
||||
@@ -43,3 +43,12 @@ file:
|
||||
|
||||
dataset-dir: /home/kcomu/data/request/
|
||||
dataset-tmp-dir: ${file.dataset-dir}tmp/
|
||||
|
||||
train:
|
||||
docker:
|
||||
image: "kamco-cd-train:love_latest"
|
||||
requestDir: "/home/kcomu/data/request"
|
||||
responseDir: "/home/kcomu/data/response"
|
||||
containerPrefix: "kamco-cd-train"
|
||||
shmSize: "16g"
|
||||
ipcHost: true
|
||||
|
||||
Reference in New Issue
Block a user