feat/training_260202 #136

Merged
teddy merged 5 commits from feat/training_260202 into develop 2026-02-24 15:11:12 +09:00
9 changed files with 197 additions and 12 deletions
Showing only changes of commit 7c5f07683e - Show all commits

View File

@@ -2,10 +2,8 @@ package com.kamco.cd.training;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.scheduling.annotation.EnableScheduling;
@EnableAsync
@SpringBootApplication @SpringBootApplication
@EnableScheduling @EnableScheduling
public class KamcoTrainingApplication { public class KamcoTrainingApplication {

View File

@@ -0,0 +1,27 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
 * Status codes for a job.
 *
 * <p>Every constant carries a Korean display label. {@link #getId()} returns the constant name and
 * {@link #getText()} returns the label.
 *
 * <p>NOTE(review): {@code ModelTrainJobCoreService} also writes the status string {@code "STOPPED"},
 * which has no constant here (the closest is {@link #CANCELED}) - confirm which one is intended.
 */
@Getter
@AllArgsConstructor
public enum JobStatusType implements EnumType {
  /** Label means "waiting". */
  QUEUED("대기중"),
  /** Label means "running". */
  RUNNING("실행중"),
  /** Label means "success". */
  SUCCESS("성공"),
  /** Label means "failure". */
  FAILED("실패"),
  /** Label means "canceled". */
  CANCELED("취소");

  /** Display label of the status. */
  private final String desc;

  /**
   * @return the display label of this status
   */
  @Override
  public String getText() {
    return this.desc;
  }

  /**
   * @return the constant name, e.g. {@code "RUNNING"}
   */
  @Override
  public String getId() {
    return this.name();
  }
}

View File

@@ -0,0 +1,24 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
 * Kinds of job.
 *
 * <p>Every constant carries a Korean display label. {@link #getId()} returns the constant name and
 * {@link #getText()} returns the label.
 */
@Getter
@AllArgsConstructor
public enum JobType implements EnumType {
  /** Label means "training". */
  TRAIN("학습"),
  /** Label means "test". */
  TEST("테스트");

  /** Display label of the job type. */
  private final String desc;

  /**
   * @return the display label of this job type
   */
  @Override
  public String getText() {
    return this.desc;
  }

  /**
   * @return the constant name, e.g. {@code "TRAIN"}
   */
  @Override
  public String getId() {
    return this.name();
  }
}

View File

@@ -0,0 +1,23 @@
package com.kamco.cd.training.config;
import java.util.concurrent.Executor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
/** Enables Spring's asynchronous method execution and registers the executor for training jobs. */
@Configuration
@EnableAsync
public class AsyncConfig {

  /**
   * Creates the thread pool registered under the bean name {@code trainJobExecutor}.
   *
   * <p>Sizing: 4 core threads, a waiting queue of 200 tasks and an upper bound of 8 threads. Under
   * {@link java.util.concurrent.ThreadPoolExecutor} semantics the pool only grows beyond the core
   * size once the queue is full, so at most 4 tasks run at the same time until 200 more are
   * waiting.
   *
   * @return the initialized executor; its threads are named with the {@code train-job-} prefix
   */
  @Bean(name = "trainJobExecutor")
  public Executor trainJobExecutor() {
    ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
    pool.setThreadNamePrefix("train-job-");
    // Number of threads kept for regular load.
    pool.setCorePoolSize(4);
    // Hard upper bound on threads; only reached after the queue below has filled up.
    pool.setMaxPoolSize(8);
    // Number of tasks that may wait for a free thread.
    pool.setQueueCapacity(200);
    pool.initialize();
    return pool;
  }
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto; import com.kamco.cd.training.train.dto.ModelTrainJobDto;
@@ -9,6 +10,7 @@ import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2; import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@@ -61,7 +63,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("RUNNING"); job.setStatusCd("RUNNING");
job.setContainerName(containerName); job.setContainerName(containerName);
@@ -87,7 +89,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("SUCCESS"); job.setStatusCd("SUCCESS");
job.setExitCode(exitCode); job.setExitCode(exitCode);
@@ -106,7 +108,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("FAILED"); job.setStatusCd("FAILED");
job.setExitCode(exitCode); job.setExitCode(exitCode);
@@ -122,7 +124,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("STOPPED"); job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now()); job.setFinishedDttm(ZonedDateTime.now());
@@ -133,7 +135,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findByContainerName(containerName) .findByContainerName(containerName)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setCurrentEpoch(epoch); job.setCurrentEpoch(epoch);
@@ -143,4 +145,17 @@ public class ModelTrainJobCoreService {
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) { public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch); modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch);
} }
/**
 * Looks up the training job that is currently running, if there is one.
 *
 * @return the job found by the repository converted to a DTO, or {@code null} when the repository
 *     returns no job
 */
public ModelTrainJobDto findRunningJobs() {
  return modelTrainJobRepository
      .findRunningJobs()
      .map(runningJob -> runningJob.toDto())
      .orElse(null);
}
} }

View File

@@ -11,4 +11,6 @@ public interface ModelTrainJobRepositoryCustom {
Optional<ModelTrainJobEntity> findByContainerName(String containerName); Optional<ModelTrainJobEntity> findByContainerName(String containerName);
void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch); void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch);
Optional<ModelTrainJobEntity> findRunningJobs();
} }

View File

@@ -1,7 +1,10 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity; import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity;
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
import com.kamco.cd.training.common.enums.JobStatusType;
import com.kamco.cd.training.common.enums.JobType;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -21,7 +24,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
/** modelId의 attempt_no 최대값. (없으면 0) */ /** modelId의 attempt_no 최대값. (없으면 0) */
@Override @Override
public int findMaxAttemptNo(Long modelId) { public int findMaxAttemptNo(Long modelId) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity; QModelTrainJobEntity j = modelTrainJobEntity;
Integer max = Integer max =
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne(); queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
@@ -35,7 +38,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
*/ */
@Override @Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) { public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity; QModelTrainJobEntity j = modelTrainJobEntity;
ModelTrainJobEntity job = ModelTrainJobEntity job =
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst(); queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
@@ -45,7 +48,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
@Override @Override
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) { public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity; QModelTrainJobEntity j = modelTrainJobEntity;
ModelTrainJobEntity job = ModelTrainJobEntity job =
queryFactory queryFactory
@@ -78,4 +81,20 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
.values(modelId, nextAttemptNo, jobId, epoch) .values(modelId, nextAttemptNo, jobId, epoch)
.execute(); .execute();
} }
/**
 * Finds the most recent (highest id) job whose status is {@code RUNNING} and whose job type is
 * {@code TRAIN}.
 *
 * @return the matching job, or an empty {@link Optional} when no such job exists
 */
@Override
public Optional<ModelTrainJobEntity> findRunningJobs() {
  QModelTrainJobEntity j = modelTrainJobEntity;

  // Multiple predicates passed to where() are combined with AND.
  ModelTrainJobEntity latestRunning =
      queryFactory
          .selectFrom(j)
          .where(
              j.statusCd.eq(JobStatusType.RUNNING.getId()),
              j.jobType.eq(JobType.TRAIN.getId()))
          .orderBy(j.id.desc())
          .fetchFirst();

  return Optional.ofNullable(latestRunning);
}
} }

View File

@@ -0,0 +1,76 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
/**
 * Reconciles a training job that is still marked RUNNING when the application starts.
 *
 * <p>On {@link ApplicationReadyEvent} the running job reported by {@code
 * ModelTrainJobCoreService#findRunningJobs()} is looked up and its Docker container is inspected.
 * If the container is no longer running, or cannot be inspected, the job is marked FAILED.
 *
 * <p>NOTE(review): only the single job returned by {@code findRunningJobs()} is examined. If
 * several jobs can be RUNNING at the same time, the others are not reconciled - confirm.
 */
@Component
@RequiredArgsConstructor
@Log4j2
@Transactional(readOnly = true)
public class JobRecoveryOnStartupService {

  private final ModelTrainJobCoreService modelTrainJobCoreService;

  // Not used by this class; kept so the Lombok-generated constructor signature is unchanged.
  private final ModelTrainMngCoreService modelTrainMngCoreService;

  /**
   * Checks the running job's container and marks the job FAILED when the container is gone.
   *
   * <p>The method-level {@code @Transactional} overrides the class-level read-only default because
   * this method writes through {@code markFailed}. A write that joins a read-only transaction can
   * be silently skipped by the persistence provider. NOTE(review): {@code markFailed} is defined
   * elsewhere; this assumes it participates in the caller's transaction - confirm.
   */
  @EventListener(ApplicationReadyEvent.class)
  @Transactional
  public void recover() {
    // Look up the training job that is currently in RUNNING state.
    ModelTrainJobDto runningJobs = modelTrainJobCoreService.findRunningJobs();
    if (runningJobs == null) {
      return;
    }

    String containerName = runningJobs.getContainerName();
    try {
      // A missing or blank container name cannot refer to a live container. Treat it as "not
      // running" instead of letting ProcessBuilder throw a NullPointerException, which would
      // escape the IOException handler below and propagate out of the startup listener.
      boolean containerAlive =
          containerName != null
              && !containerName.trim().isEmpty()
              && isContainerRunning(containerName);

      if (containerAlive) {
        // Container is still alive: keep the job in RUNNING state.
        log.info("[RECOVERY] container still running: {}", containerName);
      } else {
        // Container is gone: mark the job FAILED.
        log.info("[RECOVERY] container not found. mark FAILED: {}", containerName);
        modelTrainJobCoreService.markFailed(
            runningJobs.getId(), null, "SERVER_RESTART_CONTAINER_NOT_FOUND");
      }
    } catch (IOException e) {
      log.error("[RECOVERY] container check failed. mark FAILED: {}", containerName, e);
      modelTrainJobCoreService.markFailed(
          runningJobs.getId(), null, "SERVER_RESTART_CONTAINER_CHECK_ERROR");
    }
  }

  /**
   * Asks Docker whether the given container is running.
   *
   * <p>Runs {@code docker inspect -f {{.State.Running}} <containerName>} and reads the first line
   * of its standard output.
   *
   * @param containerName container name
   * @return {@code true} only when the first output line is exactly {@code "true"}
   * @throws IOException if the docker process cannot be started or its output cannot be read
   */
  private boolean isContainerRunning(String containerName) throws IOException {
    Process inspect =
        new ProcessBuilder("docker", "inspect", "-f", "{{.State.Running}}", containerName).start();
    try (BufferedReader stdout =
        new BufferedReader(new InputStreamReader(inspect.getInputStream()))) {
      return "true".equals(stdout.readLine());
    } finally {
      // Release the child process and its remaining pipes whether or not the read succeeded.
      inspect.destroy();
    }
  }
}

View File

@@ -30,10 +30,12 @@ public class TrainJobWorker {
private final DockerTrainService dockerTrainService; private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper; private final ObjectMapper objectMapper;
@Async @Async("trainJobExecutor")
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT) @TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
public void handle(ModelTrainJobQueuedEvent event) { public void handle(ModelTrainJobQueuedEvent event) {
log.info("[JOB] thread={}, jobId={}", Thread.currentThread().getName(), event.getJobId());
Long jobId = event.getJobId(); Long jobId = event.getJobId();
ModelTrainJobDto job = ModelTrainJobDto job =
@@ -89,7 +91,6 @@ public class TrainJobWorker {
// 도커 실행 후 로그 수집 // 도커 실행 후 로그 수집
result = dockerTrainService.runEvalSync(containerName, evalReq); result = dockerTrainService.runEvalSync(containerName, evalReq);
} else { } else {
// step1 진행중 처리 // step1 진행중 처리
modelTrainMngCoreService.markStep1InProgress(modelId, jobId); modelTrainMngCoreService.markStep1InProgress(modelId, jobId);