Merge pull request 'feat/training_260202' (#136) from feat/training_260202 into develop
Reviewed-on: #136
This commit was merged in pull request #136.
This commit is contained in:
@@ -2,10 +2,8 @@ package com.kamco.cd.training;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.scheduling.annotation.EnableAsync;
|
||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
||||
|
||||
@EnableAsync
|
||||
@SpringBootApplication
|
||||
@EnableScheduling
|
||||
public class KamcoTrainingApplication {
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package com.kamco.cd.training.common.enums;
|
||||
|
||||
import com.kamco.cd.training.common.utils.enums.EnumType;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum JobStatusType implements EnumType {
|
||||
QUEUED("대기중"),
|
||||
RUNNING("실행중"),
|
||||
SUCCESS("성공"),
|
||||
FAILED("실패"),
|
||||
CANCELED("취소");
|
||||
|
||||
private final String desc;
|
||||
|
||||
@Override
|
||||
public String getId() {
|
||||
return name();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getText() {
|
||||
return desc;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.kamco.cd.training.common.enums;
|
||||
|
||||
import com.kamco.cd.training.common.utils.enums.EnumType;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum JobType implements EnumType {
|
||||
TRAIN("학습"),
|
||||
TEST("테스트");
|
||||
|
||||
private final String desc;
|
||||
|
||||
@Override
|
||||
public String getId() {
|
||||
return name();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getText() {
|
||||
return desc;
|
||||
}
|
||||
}
|
||||
23
src/main/java/com/kamco/cd/training/config/AsyncConfig.java
Normal file
23
src/main/java/com/kamco/cd/training/config/AsyncConfig.java
Normal file
@@ -0,0 +1,23 @@
|
||||
package com.kamco.cd.training.config;
|
||||
|
||||
import java.util.concurrent.Executor;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.scheduling.annotation.EnableAsync;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||
|
||||
@Configuration
|
||||
@EnableAsync
|
||||
public class AsyncConfig {
|
||||
|
||||
@Bean(name = "trainJobExecutor")
|
||||
public Executor trainJobExecutor() {
|
||||
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||
executor.setCorePoolSize(4); // 동시에 4개 실행
|
||||
executor.setMaxPoolSize(8); // 최대 8개
|
||||
executor.setQueueCapacity(200); // 대기 큐
|
||||
executor.setThreadNamePrefix("train-job-");
|
||||
executor.initialize();
|
||||
return executor;
|
||||
}
|
||||
}
|
||||
@@ -115,7 +115,7 @@ public class HyperParamDto {
|
||||
@JsonFormatDttm private ZonedDateTime createDttm;
|
||||
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
||||
private String memo;
|
||||
private Long totalUseCnt;
|
||||
private Integer totalUseCnt;
|
||||
}
|
||||
|
||||
@Getter
|
||||
|
||||
@@ -94,7 +94,7 @@ public class ModelTrainMngService {
|
||||
// 모델 config 저장
|
||||
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
||||
|
||||
// 임시파일 생성
|
||||
// 데이터셋 임시파일 생성
|
||||
trainJobService.createTmpFile(modelUuid);
|
||||
return modelUuid;
|
||||
}
|
||||
|
||||
@@ -34,7 +34,6 @@ public class HyperParamCoreService {
|
||||
|
||||
ModelHyperParamEntity entity = new ModelHyperParamEntity();
|
||||
entity.setHyperVer(firstVersion);
|
||||
|
||||
applyHyperParam(entity, createReq);
|
||||
|
||||
// user
|
||||
@@ -172,7 +171,7 @@ public class HyperParamCoreService {
|
||||
} else {
|
||||
entity.setCropSize("256,256");
|
||||
}
|
||||
// entity.setCropSize(src.getCropSize());
|
||||
entity.setCropSize(src.getCropSize());
|
||||
|
||||
// Important
|
||||
entity.setModelType(model); // 20250212 modeltype추가
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
@@ -9,6 +10,7 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@@ -61,7 +63,7 @@ public class ModelTrainJobCoreService {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
job.setStatusCd("RUNNING");
|
||||
job.setContainerName(containerName);
|
||||
@@ -87,7 +89,7 @@ public class ModelTrainJobCoreService {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
job.setStatusCd("SUCCESS");
|
||||
job.setExitCode(exitCode);
|
||||
@@ -106,7 +108,7 @@ public class ModelTrainJobCoreService {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
job.setStatusCd("FAILED");
|
||||
job.setExitCode(exitCode);
|
||||
@@ -122,7 +124,7 @@ public class ModelTrainJobCoreService {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
job.setStatusCd("STOPPED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
@@ -133,7 +135,7 @@ public class ModelTrainJobCoreService {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findByContainerName(containerName)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
job.setCurrentEpoch(epoch);
|
||||
|
||||
@@ -143,4 +145,17 @@ public class ModelTrainJobCoreService {
|
||||
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
|
||||
modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch);
|
||||
}
|
||||
|
||||
/**
|
||||
* 실행중인 학습이 있는지 조회
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public ModelTrainJobDto findRunningJobs() {
|
||||
ModelTrainJobEntity entity = modelTrainJobRepository.findRunningJobs().orElse(null);
|
||||
if (entity == null) {
|
||||
return null;
|
||||
}
|
||||
return entity.toDto();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,12 @@ public class ModelTrainMngCoreService {
|
||||
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
||||
throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND);
|
||||
}
|
||||
// 하이퍼 파라미터 사용 횟수 업데이트
|
||||
hyperParamEntity.setTotalUseCnt(
|
||||
hyperParamEntity.getTotalUseCnt() == null ? 1 : hyperParamEntity.getTotalUseCnt() + 1);
|
||||
|
||||
// 최근 사용일시 업데이트
|
||||
hyperParamEntity.setLastUsedDttm(ZonedDateTime.now());
|
||||
|
||||
String modelVer =
|
||||
String.join(
|
||||
|
||||
@@ -310,6 +310,9 @@ public class ModelHyperParamEntity {
|
||||
@Column(name = "default_param")
|
||||
private Boolean isDefault = false;
|
||||
|
||||
@Column(name = "total_use_cnt")
|
||||
private Integer totalUseCnt = 0;
|
||||
|
||||
public HyperParamDto.Basic toDto() {
|
||||
return new HyperParamDto.Basic(
|
||||
this.modelType,
|
||||
|
||||
@@ -129,7 +129,8 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
modelHyperParamEntity.hyperVer,
|
||||
modelHyperParamEntity.createdDttm,
|
||||
modelHyperParamEntity.lastUsedDttm,
|
||||
modelHyperParamEntity.memo))
|
||||
modelHyperParamEntity.memo,
|
||||
modelHyperParamEntity.totalUseCnt))
|
||||
.from(modelHyperParamEntity)
|
||||
.where(builder);
|
||||
|
||||
@@ -154,6 +155,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
asc
|
||||
? modelHyperParamEntity.lastUsedDttm.asc()
|
||||
: modelHyperParamEntity.lastUsedDttm.desc());
|
||||
case "totalUseCnt" ->
|
||||
query.orderBy(
|
||||
asc
|
||||
? modelHyperParamEntity.totalUseCnt.asc()
|
||||
: modelHyperParamEntity.totalUseCnt.desc());
|
||||
|
||||
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
|
||||
}
|
||||
|
||||
@@ -11,4 +11,6 @@ public interface ModelTrainJobRepositoryCustom {
|
||||
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
|
||||
|
||||
void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch);
|
||||
|
||||
Optional<ModelTrainJobEntity> findRunningJobs();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.JobStatusType;
|
||||
import com.kamco.cd.training.common.enums.JobType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
@@ -21,7 +24,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
||||
/** modelId의 attempt_no 최대값. (없으면 0) */
|
||||
@Override
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
QModelTrainJobEntity j = modelTrainJobEntity;
|
||||
|
||||
Integer max =
|
||||
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
|
||||
@@ -35,7 +38,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
||||
*/
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
QModelTrainJobEntity j = modelTrainJobEntity;
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
|
||||
@@ -45,7 +48,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
QModelTrainJobEntity j = modelTrainJobEntity;
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
queryFactory
|
||||
@@ -78,4 +81,20 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
||||
.values(modelId, nextAttemptNo, jobId, epoch)
|
||||
.execute();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findRunningJobs() {
|
||||
return Optional.ofNullable(
|
||||
queryFactory
|
||||
.select(modelTrainJobEntity)
|
||||
.from(modelTrainJobEntity)
|
||||
.where(
|
||||
modelTrainJobEntity
|
||||
.statusCd
|
||||
.eq(JobStatusType.RUNNING.getId())
|
||||
.and(modelTrainJobEntity.jobType.eq(JobType.TRAIN.getId())))
|
||||
.orderBy(modelTrainJobEntity.id.desc())
|
||||
.limit(1)
|
||||
.fetchOne());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
/** 실행중 학습이 있을때 처리 */
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
@Log4j2
|
||||
@Transactional(readOnly = true)
|
||||
public class JobRecoveryOnStartupService {
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
|
||||
@EventListener(ApplicationReadyEvent.class)
|
||||
public void recover() {
|
||||
// RUNNING 중인 학습이 있는지 조회
|
||||
ModelTrainJobDto runningJobs = modelTrainJobCoreService.findRunningJobs();
|
||||
|
||||
if (runningJobs == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
String containerName = runningJobs.getContainerName();
|
||||
|
||||
try {
|
||||
boolean containerAlive = isContainerRunning(containerName);
|
||||
|
||||
if (containerAlive) {
|
||||
// 컨테이너 살아있으면 → RUNNING 유지
|
||||
log.info("[RECOVERY] container still running: {}", containerName);
|
||||
|
||||
} else {
|
||||
// 컨테이너 죽었으면 → FAILED 처리
|
||||
log.info("[RECOVERY] container not found. mark FAILED: {}", containerName);
|
||||
|
||||
modelTrainJobCoreService.markFailed(
|
||||
runningJobs.getId(), null, "SERVER_RESTART_CONTAINER_NOT_FOUND");
|
||||
}
|
||||
} catch (IOException e) {
|
||||
log.error("[RECOVERY] container check failed. mark FAILED: {}", containerName, e);
|
||||
|
||||
modelTrainJobCoreService.markFailed(
|
||||
runningJobs.getId(), null, "SERVER_RESTART_CONTAINER_CHECK_ERROR");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* docker 실행중인지 확인하기
|
||||
*
|
||||
* @param containerName container name
|
||||
* @return true, false
|
||||
* @throws IOException
|
||||
*/
|
||||
private boolean isContainerRunning(String containerName) throws IOException {
|
||||
|
||||
ProcessBuilder pb =
|
||||
new ProcessBuilder("docker", "inspect", "-f", "{{.State.Running}}", containerName);
|
||||
|
||||
Process p = pb.start();
|
||||
BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()));
|
||||
|
||||
String line = br.readLine();
|
||||
return "true".equals(line);
|
||||
}
|
||||
}
|
||||
@@ -296,14 +296,8 @@ public class TrainJobService {
|
||||
e);
|
||||
|
||||
// 런타임 예외로 래핑하되, 메시지에 핵심 정보 포함
|
||||
throw new IllegalStateException(
|
||||
"tmp dataset build failed: modelUuid="
|
||||
+ modelUuid
|
||||
+ ", modelId="
|
||||
+ modelId
|
||||
+ ", tmpRaw="
|
||||
+ raw,
|
||||
e);
|
||||
throw new CustomApiException(
|
||||
"INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "임시 데이터셋 생성에 실패했습니다.");
|
||||
}
|
||||
return modelUuid;
|
||||
}
|
||||
|
||||
@@ -30,10 +30,12 @@ public class TrainJobWorker {
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Async
|
||||
@Async("trainJobExecutor")
|
||||
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
||||
public void handle(ModelTrainJobQueuedEvent event) {
|
||||
|
||||
log.info("[JOB] thread={}, jobId={}", Thread.currentThread().getName(), event.getJobId());
|
||||
|
||||
Long jobId = event.getJobId();
|
||||
|
||||
ModelTrainJobDto job =
|
||||
@@ -89,7 +91,6 @@ public class TrainJobWorker {
|
||||
|
||||
// 도커 실행 후 로그 수집
|
||||
result = dockerTrainService.runEvalSync(containerName, evalReq);
|
||||
|
||||
} else {
|
||||
// step1 진행중 처리
|
||||
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||
|
||||
Reference in New Issue
Block a user