추론 실행 추가
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional(readOnly = true)
|
||||
public class ModelTrainJobCoreService {
|
||||
|
||||
private final ModelTrainJobRepository modelTrainJobRepository;
|
||||
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
return modelTrainJobRepository.findMaxAttemptNo(modelId);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
return modelTrainJobRepository.findLatestByModelId(modelId);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobEntity> findById(Long jobId) {
|
||||
return modelTrainJobRepository.findById(jobId);
|
||||
}
|
||||
|
||||
/** QUEUED Job 생성 */
|
||||
@Transactional
|
||||
public Long createQueuedJob(
|
||||
Long modelId, int attemptNo, Map<String, Object> paramsJson, ZonedDateTime queuedDttm) {
|
||||
|
||||
ModelTrainJobEntity job = new ModelTrainJobEntity();
|
||||
job.setModelId(modelId);
|
||||
job.setAttemptNo(attemptNo);
|
||||
job.setStatusCd("QUEUED");
|
||||
job.setParamsJson(paramsJson);
|
||||
job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now());
|
||||
|
||||
modelTrainJobRepository.save(job);
|
||||
return job.getId();
|
||||
}
|
||||
|
||||
/** 실행 시작 처리 */
|
||||
@Transactional
|
||||
public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("RUNNING");
|
||||
job.setContainerName(containerName);
|
||||
job.setLogPath(logPath);
|
||||
job.setStartedDttm(ZonedDateTime.now());
|
||||
job.setLockedDttm(ZonedDateTime.now());
|
||||
job.setLockedBy(lockedBy);
|
||||
}
|
||||
|
||||
/** 성공 처리 */
|
||||
@Transactional
|
||||
public void markSuccess(Long jobId, int exitCode) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("SUCCESS");
|
||||
job.setExitCode(exitCode);
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** 실패 처리 */
|
||||
@Transactional
|
||||
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("FAILED");
|
||||
job.setExitCode(exitCode);
|
||||
job.setErrorMessage(errorMessage);
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** 취소 처리 */
|
||||
@Transactional
|
||||
public void markCanceled(Long jobId) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("CANCELED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
@@ -30,6 +31,7 @@ import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@@ -213,6 +215,20 @@ public class ModelTrainMngCoreService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* uuid로 model id 조회
|
||||
*
|
||||
* @param uuid
|
||||
* @return
|
||||
*/
|
||||
public Long findModelIdByUuid(UUID uuid) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.getId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델학습 아이디로 config정보 조회
|
||||
*
|
||||
@@ -245,4 +261,101 @@ public class ModelTrainMngCoreService {
|
||||
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
|
||||
return datasetRepository.getDatasetSelectG2G3List(req);
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델관리 조회
|
||||
*
|
||||
* @param id
|
||||
* @return
|
||||
*/
|
||||
public ModelTrainMngDto.Basic findModelById(Long id) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(id)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
|
||||
@Transactional
|
||||
public void markInProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
master.setCurrentAttemptId(jobId);
|
||||
|
||||
// 필요하면 시작시간도 여기서 찍어줌
|
||||
}
|
||||
|
||||
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
|
||||
@Transactional
|
||||
public void clearLastError(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setLastError(null);
|
||||
}
|
||||
|
||||
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
|
||||
@Transactional
|
||||
public void markStopped(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||
}
|
||||
|
||||
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
|
||||
@Transactional
|
||||
public void markCompleted(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/** 오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setLastError(errorMessage);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void markSuccess(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
// 모델 상태 완료 처리
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
|
||||
// (선택) 마지막 에러 메시지 비우기
|
||||
master.setLastError(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 실행에 필요한 파라미터 조회
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||
return modelMngRepository.findTrainRunRequest(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +97,12 @@ public class ModelMasterEntity {
|
||||
@Column(name = "step2_metric_save_yn")
|
||||
private Boolean step2MetricSaveYn;
|
||||
|
||||
@Column(name = "current_attempt_id")
|
||||
private Long currentAttemptId;
|
||||
|
||||
@Column(name = "last_error")
|
||||
private String lastError;
|
||||
|
||||
public ModelTrainMngDto.Basic toDto() {
|
||||
return new ModelTrainMngDto.Basic(
|
||||
this.id,
|
||||
@@ -111,6 +117,7 @@ public class ModelMasterEntity {
|
||||
this.step2State,
|
||||
this.statusCd,
|
||||
this.trainType,
|
||||
this.modelNo);
|
||||
this.modelNo,
|
||||
this.currentAttemptId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
import jakarta.persistence.GenerationType;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Table;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.hibernate.annotations.ColumnDefault;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@Entity
|
||||
@Table(name = "tb_model_train_job")
|
||||
public class ModelTrainJobEntity {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@Column(name = "id", nullable = false)
|
||||
private Long id;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "model_id", nullable = false)
|
||||
private Long modelId;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "attempt_no", nullable = false)
|
||||
private Integer attemptNo;
|
||||
|
||||
@Size(max = 30)
|
||||
@NotNull
|
||||
@Column(name = "status_cd", nullable = false, length = 30)
|
||||
private String statusCd;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "params_json", nullable = false)
|
||||
@JdbcTypeCode(SqlTypes.JSON)
|
||||
private Map<String, Object> paramsJson;
|
||||
|
||||
@Size(max = 200)
|
||||
@Column(name = "container_name", length = 200)
|
||||
private String containerName;
|
||||
|
||||
@Size(max = 500)
|
||||
@Column(name = "log_path", length = 500)
|
||||
private String logPath;
|
||||
|
||||
@Column(name = "exit_code")
|
||||
private Integer exitCode;
|
||||
|
||||
@Size(max = 2000)
|
||||
@Column(name = "error_message", length = 2000)
|
||||
private String errorMessage;
|
||||
|
||||
@ColumnDefault("now()")
|
||||
@Column(name = "queued_dttm")
|
||||
private ZonedDateTime queuedDttm;
|
||||
|
||||
@Column(name = "started_dttm")
|
||||
private ZonedDateTime startedDttm;
|
||||
|
||||
@Column(name = "finished_dttm")
|
||||
private ZonedDateTime finishedDttm;
|
||||
|
||||
@Column(name = "locked_dttm")
|
||||
private ZonedDateTime lockedDttm;
|
||||
|
||||
@Size(max = 100)
|
||||
@Column(name = "locked_by", length = 100)
|
||||
private String lockedBy;
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.springframework.data.domain.Page;
|
||||
@@ -19,4 +20,6 @@ public interface ModelMngRepositoryCustom {
|
||||
Optional<ModelMasterEntity> findByUuid(UUID uuid);
|
||||
|
||||
Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn);
|
||||
|
||||
TrainRunRequest findTrainRunRequest(Long modelId);
|
||||
}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.querydsl.core.BooleanBuilder;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.core.types.dsl.Expressions;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -82,4 +87,60 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
public Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||
queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
TrainRunRequest.class,
|
||||
modelMasterEntity.uuid, // datasetFolder
|
||||
modelMasterEntity.uuid, // outputFolder
|
||||
modelHyperParamEntity.inputSize,
|
||||
modelHyperParamEntity.cropSize,
|
||||
modelHyperParamEntity.batchSize,
|
||||
modelHyperParamEntity.gpuIds,
|
||||
modelHyperParamEntity.gpuCnt,
|
||||
modelHyperParamEntity.learningRate,
|
||||
modelHyperParamEntity.backbone,
|
||||
modelHyperParamEntity.epochCnt,
|
||||
modelHyperParamEntity.trainNumWorkers,
|
||||
modelHyperParamEntity.valNumWorkers,
|
||||
modelHyperParamEntity.testNumWorkers,
|
||||
modelHyperParamEntity.trainShuffle,
|
||||
modelHyperParamEntity.trainPersistent,
|
||||
modelHyperParamEntity.valPersistent,
|
||||
modelHyperParamEntity.dropPathRate,
|
||||
modelHyperParamEntity.frozenStages,
|
||||
modelHyperParamEntity.neckPolicy,
|
||||
modelHyperParamEntity.classWeight,
|
||||
modelHyperParamEntity.decoderChannels,
|
||||
modelHyperParamEntity.weightDecay,
|
||||
modelHyperParamEntity.layerDecayRate,
|
||||
modelHyperParamEntity.ignoreIndex,
|
||||
modelHyperParamEntity.ddpFindUnusedParams,
|
||||
modelHyperParamEntity.numLayers,
|
||||
modelHyperParamEntity.metrics,
|
||||
modelHyperParamEntity.saveBest,
|
||||
modelHyperParamEntity.saveBestRule,
|
||||
modelHyperParamEntity.valInterval,
|
||||
modelHyperParamEntity.logInterval,
|
||||
modelHyperParamEntity.visInterval,
|
||||
modelHyperParamEntity.rotProb,
|
||||
modelHyperParamEntity.rotDegree,
|
||||
modelHyperParamEntity.flipProb,
|
||||
modelHyperParamEntity.exchangeProb,
|
||||
modelHyperParamEntity.brightnessDelta,
|
||||
modelHyperParamEntity.contrastRange,
|
||||
modelHyperParamEntity.saturationRange,
|
||||
modelHyperParamEntity.hueDelta,
|
||||
Expressions.nullExpression(Integer.class)))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelHyperParamEntity)
|
||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||
.leftJoin(modelConfigEntity)
|
||||
.on(modelConfigEntity.model.id.eq(modelMasterEntity.id))
|
||||
.fetchOne();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
public interface ModelTrainJobRepository
|
||||
extends JpaRepository<ModelTrainJobEntity, Long>, ModelTrainJobRepositoryCustom {}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import java.util.Optional;
|
||||
|
||||
public interface ModelTrainJobRepositoryCustom {
|
||||
int findMaxAttemptNo(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import jakarta.persistence.EntityManager;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
|
||||
|
||||
private final EntityManager em;
|
||||
|
||||
private JPAQueryFactory queryFactory() {
|
||||
return new JPAQueryFactory(em);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user