테스트 실행 추가

This commit is contained in:
2026-02-11 21:58:25 +09:00
parent 1249a80da5
commit 2f8bd1f98c
14 changed files with 670 additions and 98 deletions

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.Optional;
@@ -20,12 +21,12 @@ public class ModelTrainJobCoreService {
return modelTrainJobRepository.findMaxAttemptNo(modelId);
}
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
return modelTrainJobRepository.findLatestByModelId(modelId);
public Optional<ModelTrainJobDto> findLatestByModelId(Long modelId) {
return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto);
}
public Optional<ModelTrainJobEntity> findById(Long jobId) {
return modelTrainJobRepository.findById(jobId);
public Optional<ModelTrainJobDto> findById(Long jobId) {
return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto);
}
/** QUEUED Job 생성 */
@@ -95,7 +96,7 @@ public class ModelTrainJobCoreService {
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("CANCELED");
job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now());
}
}

View File

@@ -36,6 +36,7 @@ import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
public class ModelTrainMngCoreService {
private final ModelMngRepository modelMngRepository;
private final ModelDatasetRepository modelDatasetRepository;
private final ModelDatasetMappRepository modelDatasetMapRepository;
@@ -323,7 +324,7 @@ public class ModelTrainMngCoreService {
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/** 오류 처리(옵션) - Worker가 실패 시 호출 */
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
@@ -332,7 +333,25 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep1State(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
master.setUpdatedUid(userUtil.getId());
master.setUpdatedDttm(ZonedDateTime.now());
}
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep2State(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
master.setUpdatedUid(userUtil.getId());
master.setUpdatedDttm(ZonedDateTime.now());
}
@Transactional
@@ -358,4 +377,58 @@ public class ModelTrainMngCoreService {
public TrainRunRequest findTrainRunRequest(Long modelId) {
return modelMngRepository.findTrainRunRequest(modelId);
}
public void markStep1InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep1StrtDttm(ZonedDateTime.now());
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
public void markStep2InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep2StrtDttm(ZonedDateTime.now());
entity.setStep2State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
public void markStep1Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep1State(TrainStatusType.COMPLETED.getId());
entity.setStep1EndDttm(ZonedDateTime.now());
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
public void markStep2Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep2State(TrainStatusType.COMPLETED.getId());
entity.setStep2EndDttm(ZonedDateTime.now());
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
@@ -76,4 +77,19 @@ public class ModelTrainJobEntity {
@Size(max = 100)
@Column(name = "locked_by", length = 100)
private String lockedBy;
public ModelTrainJobDto toDto() {
return new ModelTrainJobDto(
this.id,
this.modelId,
this.attemptNo,
this.statusCd,
this.exitCode,
this.errorMessage,
this.containerName,
this.paramsJson,
this.queuedDttm,
this.startedDttm,
this.finishedDttm);
}
}

View File

@@ -134,7 +134,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelHyperParamEntity.contrastRange,
modelHyperParamEntity.saturationRange,
modelHyperParamEntity.hueDelta,
Expressions.nullExpression(Integer.class)))
Expressions.nullExpression(Integer.class),
Expressions.nullExpression(String.class)))
.from(modelMasterEntity)
.leftJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))

View File

@@ -7,6 +7,4 @@ public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
}

View File

@@ -1,34 +1,43 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory;
import jakarta.persistence.EntityManager;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
private final EntityManager em;
private final JPAQueryFactory queryFactory;
private JPAQueryFactory queryFactory() {
return new JPAQueryFactory(em);
public ModelTrainJobRepositoryImpl(EntityManager em) {
this.queryFactory = new JPAQueryFactory(em);
}
/** modelId의 attempt_no 최대값. (없으면 0) */
@Override
public int findMaxAttemptNo(Long modelId) {
return 0;
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
Integer max =
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
return max != null ? max : 0;
}
/**
* modelId의 최신 job 1건 (보통 id desc / queuedDttm desc 등) - attemptNo 기준으로도 가능하지만, 여기선 id desc가 가장
* 단순.
*/
@Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
return Optional.empty();
}
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
@Override
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() {
return Optional.empty();
ModelTrainJobEntity job =
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
return Optional.ofNullable(job);
}
}