학습실행 주석 추가

This commit is contained in:
2026-02-23 12:30:54 +09:00
parent c2978e41c2
commit 8f75b16dc6
10 changed files with 118 additions and 37 deletions

View File

@@ -57,6 +57,12 @@ public class ModelTrainDetailCoreService {
return modelDetailRepository.getModelDetailSummary(uuid);
}
/**
 * Fetches the hyper-parameter summary for a model.
 *
 * @param uuid UUID of the model master
 * @return the summary returned by {@code modelDetailRepository.getByModelHyperParamSummary(uuid)}
 */
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
  // Pure delegation: the repository performs the lookup.
  final HyperSummary hyperSummary = modelDetailRepository.getByModelHyperParamSummary(uuid);
  return hyperSummary;
}

View File

@@ -52,7 +52,12 @@ public class ModelTrainJobCoreService {
/** 실행 시작 처리 */
@Transactional
public void markRunning(
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
Long jobId,
String containerName,
String logPath,
String lockedBy,
Integer totalEpoch,
String jobType) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
@@ -64,13 +69,19 @@ public class ModelTrainJobCoreService {
job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy);
job.setJobType(jobType);
if (totalEpoch != null) {
job.setTotalEpoch(totalEpoch);
}
}
/** 성공 처리 */
/**
* 성공 처리
*
* @param jobId
* @param exitCode
*/
@Transactional
public void markSuccess(Long jobId, int exitCode) {
ModelTrainJobEntity job =
@@ -83,7 +94,13 @@ public class ModelTrainJobCoreService {
job.setFinishedDttm(ZonedDateTime.now());
}
/** 실패 처리 */
/**
* 실패 처리
*
* @param jobId
* @param exitCode
* @param errorMessage
*/
@Transactional
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =

View File

@@ -384,7 +384,12 @@ public class ModelTrainMngCoreService {
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
/**
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
@@ -399,7 +404,12 @@ public class ModelTrainMngCoreService {
master.setUpdatedDttm(ZonedDateTime.now());
}
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
/**
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
@Transactional
public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master =

View File

@@ -83,6 +83,9 @@ public class ModelTrainJobEntity {
@Column(name = "current_epoch")
private Integer currentEpoch;
@Column(name = "job_type")
private String jobType;
public ModelTrainJobDto toDto() {
return new ModelTrainJobDto(
this.id,

View File

@@ -52,7 +52,14 @@ public class DockerTrainService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
/**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param req
* @param containerName
* @return
* @throws Exception
*/
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerRunCommand(containerName, req);
@@ -267,8 +274,7 @@ public class DockerTrainService {
addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize());
addArg(c, "--batch-size", req.getBatchSize());
addArg(c, "--gpu-ids", req.getGpuIds());
// addArg(c, "--gpus", req.getGpus());
addArg(c, "--gpu-ids", req.getGpuIds()); // null
addArg(c, "--lr", req.getLearningRate());
addArg(c, "--backbone", req.getBackbone());
addArg(c, "--epochs", req.getEpochs());
@@ -342,6 +348,14 @@ public class DockerTrainService {
}
}
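/**
 * Runs the Docker evaluation (test) container synchronously: docker run is executed on the
 * calling thread, waits until the container exits, and returns the result after collecting
 * stdout/stderr logs.
 *
 * @param containerName container name passed to the docker eval command builder
 * @param req evaluation run request used to build the docker command
 * @return result of the evaluation run
 * @throws Exception if running the container fails
 */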
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
List<String> cmd = buildDockerEvalCommand(containerName, req);

View File

@@ -48,20 +48,8 @@ public class ModelTestMetricsJobService {
@Value("${file.pt-path}")
private String ptPathDir;
/**
 * Checks whether the configured {@code profile} value is {@code local}, ignoring case.
 *
 * @return {@code true} when {@code profile} equals "local" case-insensitively; {@code false}
 *     otherwise, including when {@code profile} is {@code null}
 */
private boolean isLocalProfile() {
  // Calling equalsIgnoreCase on the constant keeps the check null-safe.
  final String localProfileName = "local";
  return localProfileName.equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
public void findTestValidMetricCsvFiles() throws IOException {
// if (isLocalProfile()) {
// return;
// }
/** 결과 csv 파일 정보 등록 */
public void findTestValidMetricCsvFiles() {
List<ResponsePathDto> modelIds =
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();

View File

@@ -36,20 +36,8 @@ public class ModelTrainMetricsJobService {
@Value("${train.docker.responseDir}")
private String responseDir;
/**
 * Checks whether the configured {@code profile} value is {@code local}, ignoring case.
 *
 * @return {@code true} when {@code profile} equals "local" case-insensitively; {@code false}
 *     otherwise, including when {@code profile} is {@code null}
 */
private boolean isLocalProfile() {
  // Calling equalsIgnoreCase on the constant keeps the check null-safe.
  final String localProfileName = "local";
  return localProfileName.equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
/** 결과 csv 파일 정보 등록 */
public void findTrainValidMetricCsvFiles() {
// if (isLocalProfile()) {
// return;
// }
List<ResponsePathDto> modelIds =
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();

View File

@@ -23,6 +23,14 @@ public class TestJobService {
private final ApplicationEventPublisher eventPublisher;
private final DataSetCountersService dataSetCounters;
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @param uuid
* @param epoch
* @return
*/
@Transactional
public Long enqueue(Long modelId, UUID uuid, int epoch) {
@@ -58,6 +66,11 @@ public class TestJobService {
return jobId;
}
/**
* 취소
*
* @param modelId
*/
@Transactional
public void cancel(Long modelId) {

View File

@@ -47,7 +47,12 @@ public class TrainJobService {
return modelTrainMngCoreService.findModelIdByUuid(uuid);
}
/** 실행 예약 (QUEUE 등록) */
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @return
*/
@Transactional
public Long enqueue(Long modelId) {
@@ -139,6 +144,13 @@ public class TrainJobService {
modelTrainMngCoreService.markStopped(modelId);
}
/**
* 학습 이어하기
*
* @param modelId 모델 id
* @param mode NONE 새로 시작, REQUIRE 이어하기
* @return
*/
private Long createNextAttempt(Long modelId, ResumeMode mode) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
@@ -189,6 +201,12 @@ public class TrainJobService {
REQUIRE // 이어하기
}
/**
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
*
* @param paramsJson
* @return
*/
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
if (paramsJson == null) return null;
@@ -230,6 +248,12 @@ public class TrainJobService {
}
}
/**
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
*
* @param modelUuid
* @return
*/
@Transactional
public UUID createTmpFile(UUID modelUuid) {
UUID tmpUuid = UUID.randomUUID();

View File

@@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
/** job 실행 */
@Log4j2
@Component
@RequiredArgsConstructor
@@ -54,6 +55,8 @@ public class TrainJobWorker {
String containerName =
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
String type = isEval ? "TEST" : "TRAIN";
Integer totalEpoch = null;
if (params.containsKey("totalEpoch")) {
if (params.get("totalEpoch") != null) {
@@ -61,12 +64,15 @@ public class TrainJobWorker {
}
}
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
// 실행 시작 처리
modelTrainJobCoreService.markRunning(
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
log.info("[JOB] markRunning done jobId={}", jobId);
try {
TrainRunResult result;
if (isEval) {
// step2 진행중 처리
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
String uuid = String.valueOf(params.get("uuid"));
int epoch = (int) params.get("epoch");
@@ -81,11 +87,14 @@ public class TrainJobWorker {
evalReq.setOutputFolder(outputFolder);
log.info("[JOB] selected test epoch={}", epoch);
// 도커 실행 후 로그 수집
result = dockerTrainService.runEvalSync(containerName, evalReq);
} else {
// step1 진행중 처리
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
TrainRunRequest trainReq = toTrainRunRequest(params);
// 도커 실행 후 로그 수집
result = dockerTrainService.runTrainSync(trainReq, containerName);
}
@@ -99,24 +108,31 @@ public class TrainJobWorker {
}
if (result.getExitCode() == 0) {
// 성공 처리
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
if (isEval) {
// step2 완료처리
modelTrainMngCoreService.markStep2Success(modelId);
// 결과 csv 파일 정보 등록
modelTestMetricsJobService.findTestValidMetricCsvFiles();
} else {
modelTrainMngCoreService.markStep1Success(modelId);
// 결과 csv 파일 정보 등록
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
}
} else {
String failMsg = result.getStatus() + "\n" + result.getLogs();
// 실패 처리
modelTrainJobCoreService.markFailed(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
if (isEval) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
} else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
}
}
@@ -125,8 +141,10 @@ public class TrainJobWorker {
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
if ("EVAL".equals(params.get("jobType"))) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
} else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, e.getMessage());
}
}