테스트 실행 추가
This commit is contained in:
@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@@ -20,12 +21,12 @@ public class ModelTrainJobCoreService {
|
||||
return modelTrainJobRepository.findMaxAttemptNo(modelId);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
return modelTrainJobRepository.findLatestByModelId(modelId);
|
||||
public Optional<ModelTrainJobDto> findLatestByModelId(Long modelId) {
|
||||
return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobEntity> findById(Long jobId) {
|
||||
return modelTrainJobRepository.findById(jobId);
|
||||
public Optional<ModelTrainJobDto> findById(Long jobId) {
|
||||
return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto);
|
||||
}
|
||||
|
||||
/** QUEUED Job 생성 */
|
||||
@@ -95,7 +96,7 @@ public class ModelTrainJobCoreService {
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("CANCELED");
|
||||
job.setStatusCd("STOPPED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ import org.springframework.transaction.annotation.Transactional;
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainMngCoreService {
|
||||
|
||||
private final ModelMngRepository modelMngRepository;
|
||||
private final ModelDatasetRepository modelDatasetRepository;
|
||||
private final ModelDatasetMappRepository modelDatasetMapRepository;
|
||||
@@ -323,7 +324,7 @@ public class ModelTrainMngCoreService {
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/** 오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
@@ -332,7 +333,25 @@ public class ModelTrainMngCoreService {
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setStep1State(TrainStatusType.ERROR.getId());
|
||||
master.setLastError(errorMessage);
|
||||
master.setUpdatedUid(userUtil.getId());
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markStep2Error(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setStep2State(TrainStatusType.ERROR.getId());
|
||||
master.setLastError(errorMessage);
|
||||
master.setUpdatedUid(userUtil.getId());
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
@Transactional
|
||||
@@ -358,4 +377,58 @@ public class ModelTrainMngCoreService {
|
||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||
return modelMngRepository.findTrainRunRequest(modelId);
|
||||
}
|
||||
|
||||
public void markStep1InProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setCurrentAttemptId(jobId);
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void markStep2InProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStep2StrtDttm(ZonedDateTime.now());
|
||||
entity.setStep2State(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setCurrentAttemptId(jobId);
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void markStep1Success(Long modelId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep1EndDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void markStep2Success(Long modelId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep2EndDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
@@ -76,4 +77,19 @@ public class ModelTrainJobEntity {
|
||||
@Size(max = 100)
|
||||
@Column(name = "locked_by", length = 100)
|
||||
private String lockedBy;
|
||||
|
||||
public ModelTrainJobDto toDto() {
|
||||
return new ModelTrainJobDto(
|
||||
this.id,
|
||||
this.modelId,
|
||||
this.attemptNo,
|
||||
this.statusCd,
|
||||
this.exitCode,
|
||||
this.errorMessage,
|
||||
this.containerName,
|
||||
this.paramsJson,
|
||||
this.queuedDttm,
|
||||
this.startedDttm,
|
||||
this.finishedDttm);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,7 +134,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
modelHyperParamEntity.contrastRange,
|
||||
modelHyperParamEntity.saturationRange,
|
||||
modelHyperParamEntity.hueDelta,
|
||||
Expressions.nullExpression(Integer.class)))
|
||||
Expressions.nullExpression(Integer.class),
|
||||
Expressions.nullExpression(String.class)))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelHyperParamEntity)
|
||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||
|
||||
@@ -7,6 +7,4 @@ public interface ModelTrainJobRepositoryCustom {
|
||||
int findMaxAttemptNo(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
|
||||
}
|
||||
|
||||
@@ -1,34 +1,43 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import jakarta.persistence.EntityManager;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
|
||||
|
||||
private final EntityManager em;
|
||||
private final JPAQueryFactory queryFactory;
|
||||
|
||||
private JPAQueryFactory queryFactory() {
|
||||
return new JPAQueryFactory(em);
|
||||
public ModelTrainJobRepositoryImpl(EntityManager em) {
|
||||
this.queryFactory = new JPAQueryFactory(em);
|
||||
}
|
||||
|
||||
/** modelId의 attempt_no 최대값. (없으면 0) */
|
||||
@Override
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
return 0;
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
Integer max =
|
||||
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
|
||||
|
||||
return max != null ? max : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* modelId의 최신 job 1건 (보통 id desc / queuedDttm desc 등) - attemptNo 기준으로도 가능하지만, 여기선 id desc가 가장
|
||||
* 단순.
|
||||
*/
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
return Optional.empty();
|
||||
}
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() {
|
||||
return Optional.empty();
|
||||
ModelTrainJobEntity job =
|
||||
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user