Merge pull request 'feat/training_260202' (#136) from feat/training_260202 into develop

Reviewed-on: #136
This commit was merged in pull request #136.
This commit is contained in:
2026-02-24 15:11:12 +09:00
16 changed files with 218 additions and 25 deletions

View File

@@ -2,10 +2,8 @@ package com.kamco.cd.training;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
@EnableAsync
@SpringBootApplication
@EnableScheduling
public class KamcoTrainingApplication {

View File

@@ -0,0 +1,27 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
 * Status codes for a job, each paired with a Korean display label.
 *
 * <p>{@link #getId()} returns the constant name and {@link #getText()} returns the label.
 *
 * <p>NOTE(review): ModelTrainJobCoreService also writes the literal status {@code "STOPPED"},
 * which has no constant here — confirm whether it should map to {@link #CANCELED}.
 */
@Getter
@AllArgsConstructor
public enum JobStatusType implements EnumType {
  // label: "waiting"
  QUEUED("대기중"),
  // label: "running"
  RUNNING("실행중"),
  // label: "success"
  SUCCESS("성공"),
  // label: "failure"
  FAILED("실패"),
  // label: "canceled"
  CANCELED("취소");

  /** Display label supplied to the constructor. */
  private final String desc;

  /** Returns the constant name (for example {@code "RUNNING"}). */
  @Override
  public String getId() {
    return this.name();
  }

  /** Returns the display label of this constant. */
  @Override
  public String getText() {
    return this.desc;
  }
}

View File

@@ -0,0 +1,24 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
 * Kinds of job, each paired with a Korean display label.
 *
 * <p>{@link #getId()} returns the constant name and {@link #getText()} returns the label.
 */
@Getter
@AllArgsConstructor
public enum JobType implements EnumType {
  // label: "training"
  TRAIN("학습"),
  // label: "test"
  TEST("테스트");

  /** Display label supplied to the constructor. */
  private final String desc;

  /** Returns the constant name (for example {@code "TRAIN"}). */
  @Override
  public String getId() {
    return this.name();
  }

  /** Returns the display label of this constant. */
  @Override
  public String getText() {
    return this.desc;
  }
}

View File

@@ -0,0 +1,23 @@
package com.kamco.cd.training.config;
import java.util.concurrent.Executor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
/** Enables async execution and registers the thread pool bean named {@code trainJobExecutor}. */
@Configuration
@EnableAsync
public class AsyncConfig {

  /** Core pool size: the number of jobs intended to run at the same time. */
  private static final int CORE_POOL_SIZE = 4;

  /** Upper limit of threads in the pool. */
  private static final int MAX_POOL_SIZE = 8;

  /** Capacity of the waiting queue. */
  private static final int QUEUE_CAPACITY = 200;

  /** Prefix given to the names of the pool's threads. */
  private static final String THREAD_NAME_PREFIX = "train-job-";

  /**
   * Builds and initializes the executor for train jobs.
   *
   * <p>NOTE(review): with a bounded queue a thread pool normally grows past the core size only
   * once the queue is full — confirm that more than {@value #CORE_POOL_SIZE} concurrent jobs is
   * not expected before {@value #QUEUE_CAPACITY} jobs are waiting.
   *
   * @return the initialized executor
   */
  @Bean(name = "trainJobExecutor")
  public Executor trainJobExecutor() {
    ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
    pool.setThreadNamePrefix(THREAD_NAME_PREFIX);
    pool.setCorePoolSize(CORE_POOL_SIZE);
    pool.setMaxPoolSize(MAX_POOL_SIZE);
    pool.setQueueCapacity(QUEUE_CAPACITY);
    pool.initialize();
    return pool;
  }
}

View File

@@ -115,7 +115,7 @@ public class HyperParamDto {
@JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private String memo;
private Long totalUseCnt;
private Integer totalUseCnt;
}
@Getter

View File

@@ -94,7 +94,7 @@ public class ModelTrainMngService {
// 모델 config 저장
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
// 임시파일 생성
// 데이터셋 임시파일 생성
trainJobService.createTmpFile(modelUuid);
return modelUuid;
}

View File

@@ -34,7 +34,6 @@ public class HyperParamCoreService {
ModelHyperParamEntity entity = new ModelHyperParamEntity();
entity.setHyperVer(firstVersion);
applyHyperParam(entity, createReq);
// user
@@ -172,7 +171,7 @@ public class HyperParamCoreService {
} else {
entity.setCropSize("256,256");
}
// entity.setCropSize(src.getCropSize());
entity.setCropSize(src.getCropSize());
// Important
entity.setModelType(model); // 20250212 modeltype추가

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
@@ -9,6 +10,7 @@ import java.util.Objects;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -61,7 +63,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("RUNNING");
job.setContainerName(containerName);
@@ -87,7 +89,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("SUCCESS");
job.setExitCode(exitCode);
@@ -106,7 +108,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("FAILED");
job.setExitCode(exitCode);
@@ -122,7 +124,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now());
@@ -133,7 +135,7 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job =
modelTrainJobRepository
.findByContainerName(containerName)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setCurrentEpoch(epoch);
@@ -143,4 +145,17 @@ public class ModelTrainJobCoreService {
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch);
}
/**
 * Looks up whether a training job is currently running.
 *
 * @return the DTO of the entity returned by the repository's {@code findRunningJobs()}, or
 *     {@code null} when the repository returns an empty result
 */
public ModelTrainJobDto findRunningJobs() {
  return modelTrainJobRepository.findRunningJobs().map(found -> found.toDto()).orElse(null);
}
}

View File

@@ -104,6 +104,12 @@ public class ModelTrainMngCoreService {
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND);
}
// 하이퍼 파라미터 사용 횟수 업데이트
hyperParamEntity.setTotalUseCnt(
hyperParamEntity.getTotalUseCnt() == null ? 1 : hyperParamEntity.getTotalUseCnt() + 1);
// 최근 사용일시 업데이트
hyperParamEntity.setLastUsedDttm(ZonedDateTime.now());
String modelVer =
String.join(

View File

@@ -310,6 +310,9 @@ public class ModelHyperParamEntity {
@Column(name = "default_param")
private Boolean isDefault = false;
@Column(name = "total_use_cnt")
private Integer totalUseCnt = 0;
public HyperParamDto.Basic toDto() {
return new HyperParamDto.Basic(
this.modelType,

View File

@@ -129,7 +129,8 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
modelHyperParamEntity.hyperVer,
modelHyperParamEntity.createdDttm,
modelHyperParamEntity.lastUsedDttm,
modelHyperParamEntity.memo))
modelHyperParamEntity.memo,
modelHyperParamEntity.totalUseCnt))
.from(modelHyperParamEntity)
.where(builder);
@@ -154,6 +155,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
asc
? modelHyperParamEntity.lastUsedDttm.asc()
: modelHyperParamEntity.lastUsedDttm.desc());
case "totalUseCnt" ->
query.orderBy(
asc
? modelHyperParamEntity.totalUseCnt.asc()
: modelHyperParamEntity.totalUseCnt.desc());
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
}

View File

@@ -11,4 +11,6 @@ public interface ModelTrainJobRepositoryCustom {
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch);
Optional<ModelTrainJobEntity> findRunningJobs();
}

View File

@@ -1,7 +1,10 @@
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity;
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
import com.kamco.cd.training.common.enums.JobStatusType;
import com.kamco.cd.training.common.enums.JobType;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -21,7 +24,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
/** modelId의 attempt_no 최대값. (없으면 0) */
@Override
public int findMaxAttemptNo(Long modelId) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
QModelTrainJobEntity j = modelTrainJobEntity;
Integer max =
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
@@ -35,7 +38,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
*/
@Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
QModelTrainJobEntity j = modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
@@ -45,7 +48,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
@Override
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
QModelTrainJobEntity j = modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory
@@ -78,4 +81,20 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
.values(modelId, nextAttemptNo, jobId, epoch)
.execute();
}
/**
 * Fetches the job with the highest id whose status code is {@code RUNNING} and whose job type is
 * {@code TRAIN}. Despite the plural name, at most one job is returned.
 *
 * @return the matching job, or an empty Optional when no row matches
 */
@Override
public Optional<ModelTrainJobEntity> findRunningJobs() {
  QModelTrainJobEntity j = modelTrainJobEntity;

  ModelTrainJobEntity latestRunning =
      queryFactory
          .selectFrom(j)
          .where(j.statusCd.eq(JobStatusType.RUNNING.getId()), j.jobType.eq(JobType.TRAIN.getId()))
          .orderBy(j.id.desc())
          .fetchFirst();

  return Optional.ofNullable(latestRunning);
}
}

View File

@@ -0,0 +1,76 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
/**
 * Reconciles a training job that is still recorded as RUNNING when the application becomes ready.
 *
 * <p>The job's docker container is inspected: if it is still running the job is left untouched,
 * otherwise (or when the check itself fails) the job is marked FAILED.
 */
@Component
@RequiredArgsConstructor
@Log4j2
// Deliberately NOT readOnly: recover() may call markFailed(...), which updates the job. Inside a
// read-only transaction that update could be rejected by the database or never flushed.
@Transactional
public class JobRecoveryOnStartupService {

  /** Upper bound, in seconds, for the {@code docker inspect} call so startup cannot hang. */
  private static final long DOCKER_INSPECT_TIMEOUT_SECONDS = 30L;

  private final ModelTrainJobCoreService modelTrainJobCoreService;

  // Not referenced by this class at the moment; kept so the generated constructor is unchanged.
  private final ModelTrainMngCoreService modelTrainMngCoreService;

  /**
   * Runs once on {@link ApplicationReadyEvent}. Does nothing when no running job is found.
   *
   * <p>A job whose container is not running is marked FAILED with reason {@code
   * SERVER_RESTART_CONTAINER_NOT_FOUND}; when the check throws an {@link IOException} (including
   * a timeout) the reason is {@code SERVER_RESTART_CONTAINER_CHECK_ERROR}.
   */
  @EventListener(ApplicationReadyEvent.class)
  public void recover() {
    // Look up a training job that is still in RUNNING state.
    ModelTrainJobDto runningJob = modelTrainJobCoreService.findRunningJobs();
    if (runningJob == null) {
      return;
    }

    String containerName = runningJob.getContainerName();

    try {
      if (isContainerRunning(containerName)) {
        // Container is alive -> keep the job as RUNNING.
        log.info("[RECOVERY] container still running: {}", containerName);
      } else {
        // Container is gone -> mark the job FAILED.
        log.info("[RECOVERY] container not found. mark FAILED: {}", containerName);
        modelTrainJobCoreService.markFailed(
            runningJob.getId(), null, "SERVER_RESTART_CONTAINER_NOT_FOUND");
      }
    } catch (IOException e) {
      log.error("[RECOVERY] container check failed. mark FAILED: {}", containerName, e);
      modelTrainJobCoreService.markFailed(
          runningJob.getId(), null, "SERVER_RESTART_CONTAINER_CHECK_ERROR");
    }
  }

  /**
   * Checks whether a docker container is running by executing {@code docker inspect -f
   * {{.State.Running}} <containerName>} and comparing the first output line with {@code "true"}.
   *
   * @param containerName container name; {@code null} or blank is treated as "not running"
   * @return {@code true} only when the first stdout line is exactly {@code "true"}
   * @throws IOException if the process cannot be started or read, does not finish within {@link
   *     #DOCKER_INSPECT_TIMEOUT_SECONDS} seconds, or the waiting thread is interrupted
   */
  private boolean isContainerRunning(String containerName) throws IOException {
    // Without a name there is nothing to inspect; a null argument would make start() throw an NPE.
    if (containerName == null || containerName.isBlank()) {
      return false;
    }

    ProcessBuilder pb =
        new ProcessBuilder("docker", "inspect", "-f", "{{.State.Running}}", containerName);
    // stderr is never read here; discard it so it cannot fill a pipe or leak a descriptor.
    pb.redirectError(ProcessBuilder.Redirect.DISCARD);

    Process process = pb.start();
    try {
      // Waiting before reading is safe: the command prints a single short line.
      if (!process.waitFor(DOCKER_INSPECT_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
        throw new IOException(
            "docker inspect timed out after " + DOCKER_INSPECT_TIMEOUT_SECONDS + "s");
      }
      try (BufferedReader reader =
          new BufferedReader(
              new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
        return "true".equals(reader.readLine());
      }
    } catch (InterruptedException e) {
      // Restore the interrupt flag and report the check as failed.
      Thread.currentThread().interrupt();
      throw new IOException("interrupted while waiting for docker inspect", e);
    } finally {
      if (process.isAlive()) {
        process.destroyForcibly();
      }
    }
  }
}

View File

@@ -296,14 +296,8 @@ public class TrainJobService {
e);
// 런타임 예외로 래핑하되, 메시지에 핵심 정보 포함
throw new IllegalStateException(
"tmp dataset build failed: modelUuid="
+ modelUuid
+ ", modelId="
+ modelId
+ ", tmpRaw="
+ raw,
e);
throw new CustomApiException(
"INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "임시 데이터셋 생성에 실패했습니다.");
}
return modelUuid;
}

View File

@@ -30,10 +30,12 @@ public class TrainJobWorker {
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
@Async
@Async("trainJobExecutor")
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
public void handle(ModelTrainJobQueuedEvent event) {
log.info("[JOB] thread={}, jobId={}", Thread.currentThread().getName(), event.getJobId());
Long jobId = event.getJobId();
ModelTrainJobDto job =
@@ -89,7 +91,6 @@ public class TrainJobWorker {
// 도커 실행 후 로그 수집
result = dockerTrainService.runEvalSync(containerName, evalReq);
} else {
// step1 진행중 처리
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);