추론 실행 추가

This commit is contained in:
2026-02-11 20:21:25 +09:00
parent 35767adba1
commit 1249a80da5
21 changed files with 1049 additions and 3 deletions

View File

@@ -0,0 +1,101 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class ModelTrainJobCoreService {
private final ModelTrainJobRepository modelTrainJobRepository;
public int findMaxAttemptNo(Long modelId) {
return modelTrainJobRepository.findMaxAttemptNo(modelId);
}
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
return modelTrainJobRepository.findLatestByModelId(modelId);
}
public Optional<ModelTrainJobEntity> findById(Long jobId) {
return modelTrainJobRepository.findById(jobId);
}
/** QUEUED Job 생성 */
@Transactional
public Long createQueuedJob(
Long modelId, int attemptNo, Map<String, Object> paramsJson, ZonedDateTime queuedDttm) {
ModelTrainJobEntity job = new ModelTrainJobEntity();
job.setModelId(modelId);
job.setAttemptNo(attemptNo);
job.setStatusCd("QUEUED");
job.setParamsJson(paramsJson);
job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now());
modelTrainJobRepository.save(job);
return job.getId();
}
/** 실행 시작 처리 */
@Transactional
public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("RUNNING");
job.setContainerName(containerName);
job.setLogPath(logPath);
job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy);
}
/** 성공 처리 */
@Transactional
public void markSuccess(Long jobId, int exitCode) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("SUCCESS");
job.setExitCode(exitCode);
job.setFinishedDttm(ZonedDateTime.now());
}
/** 실패 처리 */
@Transactional
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("FAILED");
job.setExitCode(exitCode);
job.setErrorMessage(errorMessage);
job.setFinishedDttm(ZonedDateTime.now());
}
/** 취소 처리 */
@Transactional
public void markCanceled(Long jobId) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("CANCELED");
job.setFinishedDttm(ZonedDateTime.now());
}
}

View File

@@ -23,6 +23,7 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.UUID;
@@ -30,6 +31,7 @@ import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@@ -213,6 +215,20 @@ public class ModelTrainMngCoreService {
}
}
/**
* uuid로 model id 조회
*
* @param uuid
* @return
*/
public Long findModelIdByUuid(UUID uuid) {
ModelMasterEntity entity =
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.getId();
}
/**
* 모델학습 아이디로 config정보 조회
*
@@ -245,4 +261,101 @@ public class ModelTrainMngCoreService {
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
return datasetRepository.getDatasetSelectG2G3List(req);
}
/**
* 모델관리 조회
*
* @param id
* @return
*/
public ModelTrainMngDto.Basic findModelById(Long id) {
ModelMasterEntity entity =
modelMngRepository
.findById(id)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
return entity.toDto();
}
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
@Transactional
public void markInProgress(Long modelId, Long jobId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
master.setCurrentAttemptId(jobId);
// 필요하면 시작시간도 여기서 찍어줌
}
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
@Transactional
public void clearLastError(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setLastError(null);
}
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
@Transactional
public void markStopped(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.STOPPED.getId());
}
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
@Transactional
public void markCompleted(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/** 오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
}
@Transactional
public void markSuccess(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
// 모델 상태 완료 처리
master.setStatusCd(TrainStatusType.COMPLETED.getId());
// (선택) 마지막 에러 메시지 비우기
master.setLastError(null);
}
/**
* 학습 실행에 필요한 파라미터 조회
*
* @param modelId
* @return
*/
public TrainRunRequest findTrainRunRequest(Long modelId) {
return modelMngRepository.findTrainRunRequest(modelId);
}
}

View File

@@ -97,6 +97,12 @@ public class ModelMasterEntity {
@Column(name = "step2_metric_save_yn")
private Boolean step2MetricSaveYn;
@Column(name = "current_attempt_id")
private Long currentAttemptId;
@Column(name = "last_error")
private String lastError;
public ModelTrainMngDto.Basic toDto() {
return new ModelTrainMngDto.Basic(
this.id,
@@ -111,6 +117,7 @@ public class ModelMasterEntity {
this.step2State,
this.statusCd,
this.trainType,
this.modelNo);
this.modelNo,
this.currentAttemptId);
}
}

View File

@@ -0,0 +1,79 @@
package com.kamco.cd.training.postgres.entity;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.ZonedDateTime;
import java.util.Map;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
@Getter
@Setter
@Entity
@Table(name = "tb_model_train_job")
public class ModelTrainJobEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "id", nullable = false)
private Long id;
@NotNull
@Column(name = "model_id", nullable = false)
private Long modelId;
@NotNull
@Column(name = "attempt_no", nullable = false)
private Integer attemptNo;
@Size(max = 30)
@NotNull
@Column(name = "status_cd", nullable = false, length = 30)
private String statusCd;
@NotNull
@Column(name = "params_json", nullable = false)
@JdbcTypeCode(SqlTypes.JSON)
private Map<String, Object> paramsJson;
@Size(max = 200)
@Column(name = "container_name", length = 200)
private String containerName;
@Size(max = 500)
@Column(name = "log_path", length = 500)
private String logPath;
@Column(name = "exit_code")
private Integer exitCode;
@Size(max = 2000)
@Column(name = "error_message", length = 2000)
private String errorMessage;
@ColumnDefault("now()")
@Column(name = "queued_dttm")
private ZonedDateTime queuedDttm;
@Column(name = "started_dttm")
private ZonedDateTime startedDttm;
@Column(name = "finished_dttm")
private ZonedDateTime finishedDttm;
@Column(name = "locked_dttm")
private ZonedDateTime lockedDttm;
@Size(max = 100)
@Column(name = "locked_by", length = 100)
private String lockedBy;
}

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.util.Optional;
import java.util.UUID;
import org.springframework.data.domain.Page;
@@ -19,4 +20,6 @@ public interface ModelMngRepositoryCustom {
Optional<ModelMasterEntity> findByUuid(UUID uuid);
Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn);
TrainRunRequest findTrainRunRequest(Long modelId);
}

View File

@@ -1,10 +1,15 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import java.util.Optional;
@@ -82,4 +87,60 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
public Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn) {
return Optional.empty();
}
@Override
public TrainRunRequest findTrainRunRequest(Long modelId) {
queryFactory
.select(
Projections.constructor(
TrainRunRequest.class,
modelMasterEntity.uuid, // datasetFolder
modelMasterEntity.uuid, // outputFolder
modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize,
modelHyperParamEntity.batchSize,
modelHyperParamEntity.gpuIds,
modelHyperParamEntity.gpuCnt,
modelHyperParamEntity.learningRate,
modelHyperParamEntity.backbone,
modelHyperParamEntity.epochCnt,
modelHyperParamEntity.trainNumWorkers,
modelHyperParamEntity.valNumWorkers,
modelHyperParamEntity.testNumWorkers,
modelHyperParamEntity.trainShuffle,
modelHyperParamEntity.trainPersistent,
modelHyperParamEntity.valPersistent,
modelHyperParamEntity.dropPathRate,
modelHyperParamEntity.frozenStages,
modelHyperParamEntity.neckPolicy,
modelHyperParamEntity.classWeight,
modelHyperParamEntity.decoderChannels,
modelHyperParamEntity.weightDecay,
modelHyperParamEntity.layerDecayRate,
modelHyperParamEntity.ignoreIndex,
modelHyperParamEntity.ddpFindUnusedParams,
modelHyperParamEntity.numLayers,
modelHyperParamEntity.metrics,
modelHyperParamEntity.saveBest,
modelHyperParamEntity.saveBestRule,
modelHyperParamEntity.valInterval,
modelHyperParamEntity.logInterval,
modelHyperParamEntity.visInterval,
modelHyperParamEntity.rotProb,
modelHyperParamEntity.rotDegree,
modelHyperParamEntity.flipProb,
modelHyperParamEntity.exchangeProb,
modelHyperParamEntity.brightnessDelta,
modelHyperParamEntity.contrastRange,
modelHyperParamEntity.saturationRange,
modelHyperParamEntity.hueDelta,
Expressions.nullExpression(Integer.class)))
.from(modelMasterEntity)
.leftJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
.leftJoin(modelConfigEntity)
.on(modelConfigEntity.model.id.eq(modelMasterEntity.id))
.fetchOne();
return null;
}
}

View File

@@ -0,0 +1,7 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import org.springframework.data.jpa.repository.JpaRepository;
public interface ModelTrainJobRepository
extends JpaRepository<ModelTrainJobEntity, Long>, ModelTrainJobRepositoryCustom {}

View File

@@ -0,0 +1,12 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import java.util.Optional;
public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
}

View File

@@ -0,0 +1,34 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory;
import jakarta.persistence.EntityManager;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
private final EntityManager em;
private JPAQueryFactory queryFactory() {
return new JPAQueryFactory(em);
}
@Override
public int findMaxAttemptNo(Long modelId) {
return 0;
}
@Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
return Optional.empty();
}
@Override
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() {
return Optional.empty();
}
}