학습실행 주석 추가
This commit is contained in:
@@ -57,6 +57,12 @@ public class ModelTrainDetailCoreService {
|
||||
return modelDetailRepository.getModelDetailSummary(uuid);
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼 파리미터 요약정보
|
||||
*
|
||||
* @param uuid 모델마스터 uuid
|
||||
* @return
|
||||
*/
|
||||
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
|
||||
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
||||
}
|
||||
|
||||
@@ -52,7 +52,12 @@ public class ModelTrainJobCoreService {
|
||||
/** 실행 시작 처리 */
|
||||
@Transactional
|
||||
public void markRunning(
|
||||
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
|
||||
Long jobId,
|
||||
String containerName,
|
||||
String logPath,
|
||||
String lockedBy,
|
||||
Integer totalEpoch,
|
||||
String jobType) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
@@ -64,13 +69,19 @@ public class ModelTrainJobCoreService {
|
||||
job.setStartedDttm(ZonedDateTime.now());
|
||||
job.setLockedDttm(ZonedDateTime.now());
|
||||
job.setLockedBy(lockedBy);
|
||||
job.setJobType(jobType);
|
||||
|
||||
if (totalEpoch != null) {
|
||||
job.setTotalEpoch(totalEpoch);
|
||||
}
|
||||
}
|
||||
|
||||
/** 성공 처리 */
|
||||
/**
|
||||
* 성공 처리
|
||||
*
|
||||
* @param jobId
|
||||
* @param exitCode
|
||||
*/
|
||||
@Transactional
|
||||
public void markSuccess(Long jobId, int exitCode) {
|
||||
ModelTrainJobEntity job =
|
||||
@@ -83,7 +94,13 @@ public class ModelTrainJobCoreService {
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** 실패 처리 */
|
||||
/**
|
||||
* 실패 처리
|
||||
*
|
||||
* @param jobId
|
||||
* @param exitCode
|
||||
* @param errorMessage
|
||||
*/
|
||||
@Transactional
|
||||
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
||||
ModelTrainJobEntity job =
|
||||
|
||||
@@ -384,7 +384,12 @@ public class ModelTrainMngCoreService {
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
/**
|
||||
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*
|
||||
* @param modelId
|
||||
* @param errorMessage
|
||||
*/
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
@@ -399,7 +404,12 @@ public class ModelTrainMngCoreService {
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
/**
|
||||
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*
|
||||
* @param modelId
|
||||
* @param errorMessage
|
||||
*/
|
||||
@Transactional
|
||||
public void markStep2Error(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
|
||||
@@ -83,6 +83,9 @@ public class ModelTrainJobEntity {
|
||||
@Column(name = "current_epoch")
|
||||
private Integer currentEpoch;
|
||||
|
||||
@Column(name = "job_type")
|
||||
private String jobType;
|
||||
|
||||
public ModelTrainJobDto toDto() {
|
||||
return new ModelTrainJobDto(
|
||||
this.id,
|
||||
|
||||
@@ -52,7 +52,14 @@ public class DockerTrainService {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
/**
|
||||
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
|
||||
*
|
||||
* @param req
|
||||
* @param containerName
|
||||
* @return
|
||||
* @throws Exception
|
||||
*/
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||
|
||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||
@@ -267,8 +274,7 @@ public class DockerTrainService {
|
||||
addArg(c, "--input-size", req.getInputSize());
|
||||
addArg(c, "--crop-size", req.getCropSize());
|
||||
addArg(c, "--batch-size", req.getBatchSize());
|
||||
addArg(c, "--gpu-ids", req.getGpuIds());
|
||||
// addArg(c, "--gpus", req.getGpus());
|
||||
addArg(c, "--gpu-ids", req.getGpuIds()); // null
|
||||
addArg(c, "--lr", req.getLearningRate());
|
||||
addArg(c, "--backbone", req.getBackbone());
|
||||
addArg(c, "--epochs", req.getEpochs());
|
||||
@@ -342,6 +348,14 @@ public class DockerTrainService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
|
||||
*
|
||||
* @param containerName
|
||||
* @param req
|
||||
* @return
|
||||
* @throws Exception
|
||||
*/
|
||||
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
|
||||
|
||||
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
||||
|
||||
@@ -48,20 +48,8 @@ public class ModelTestMetricsJobService {
|
||||
@Value("${file.pt-path}")
|
||||
private String ptPathDir;
|
||||
|
||||
/**
|
||||
* 실행중인 profile
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private boolean isLocalProfile() {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
public void findTestValidMetricCsvFiles() throws IOException {
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
/** 결과 csv 파일 정보 등록 */
|
||||
public void findTestValidMetricCsvFiles() {
|
||||
|
||||
List<ResponsePathDto> modelIds =
|
||||
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
||||
|
||||
@@ -36,20 +36,8 @@ public class ModelTrainMetricsJobService {
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
/**
|
||||
* 실행중인 profile
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private boolean isLocalProfile() {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
/** 결과 csv 파일 정보 등록 */
|
||||
public void findTrainValidMetricCsvFiles() {
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
List<ResponsePathDto> modelIds =
|
||||
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
||||
|
||||
@@ -23,6 +23,14 @@ public class TestJobService {
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final DataSetCountersService dataSetCounters;
|
||||
|
||||
/**
|
||||
* 실행 예약 (QUEUE 등록)
|
||||
*
|
||||
* @param modelId
|
||||
* @param uuid
|
||||
* @param epoch
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||
|
||||
@@ -58,6 +66,11 @@ public class TestJobService {
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 취소
|
||||
*
|
||||
* @param modelId
|
||||
*/
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
|
||||
@@ -47,7 +47,12 @@ public class TrainJobService {
|
||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||
}
|
||||
|
||||
/** 실행 예약 (QUEUE 등록) */
|
||||
/**
|
||||
* 실행 예약 (QUEUE 등록)
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId) {
|
||||
|
||||
@@ -139,6 +144,13 @@ public class TrainJobService {
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 이어하기
|
||||
*
|
||||
* @param modelId 모델 id
|
||||
* @param mode NONE 새로 시작, REQUIRE 이어하기
|
||||
* @return
|
||||
*/
|
||||
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
@@ -189,6 +201,12 @@ public class TrainJobService {
|
||||
REQUIRE // 이어하기
|
||||
}
|
||||
|
||||
/**
|
||||
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
|
||||
*
|
||||
* @param paramsJson
|
||||
* @return
|
||||
*/
|
||||
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
||||
if (paramsJson == null) return null;
|
||||
|
||||
@@ -230,6 +248,12 @@ public class TrainJobService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
|
||||
*
|
||||
* @param modelUuid
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public UUID createTmpFile(UUID modelUuid) {
|
||||
UUID tmpUuid = UUID.randomUUID();
|
||||
|
||||
@@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
|
||||
import org.springframework.transaction.event.TransactionPhase;
|
||||
import org.springframework.transaction.event.TransactionalEventListener;
|
||||
|
||||
/** job 실행 */
|
||||
@Log4j2
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
@@ -54,6 +55,8 @@ public class TrainJobWorker {
|
||||
String containerName =
|
||||
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
||||
|
||||
String type = isEval ? "TEST" : "TRAIN";
|
||||
|
||||
Integer totalEpoch = null;
|
||||
if (params.containsKey("totalEpoch")) {
|
||||
if (params.get("totalEpoch") != null) {
|
||||
@@ -61,12 +64,15 @@ public class TrainJobWorker {
|
||||
}
|
||||
}
|
||||
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
|
||||
// 실행 시작 처리
|
||||
modelTrainJobCoreService.markRunning(
|
||||
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
|
||||
log.info("[JOB] markRunning done jobId={}", jobId);
|
||||
try {
|
||||
TrainRunResult result;
|
||||
|
||||
if (isEval) {
|
||||
// step2 진행중 처리
|
||||
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||
String uuid = String.valueOf(params.get("uuid"));
|
||||
int epoch = (int) params.get("epoch");
|
||||
@@ -81,11 +87,14 @@ public class TrainJobWorker {
|
||||
evalReq.setOutputFolder(outputFolder);
|
||||
log.info("[JOB] selected test epoch={}", epoch);
|
||||
|
||||
// 도커 실행 후 로그 수집
|
||||
result = dockerTrainService.runEvalSync(containerName, evalReq);
|
||||
|
||||
} else {
|
||||
// step1 진행중 처리
|
||||
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||
TrainRunRequest trainReq = toTrainRunRequest(params);
|
||||
// 도커 실행 후 로그 수집
|
||||
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
||||
}
|
||||
|
||||
@@ -99,24 +108,31 @@ public class TrainJobWorker {
|
||||
}
|
||||
|
||||
if (result.getExitCode() == 0) {
|
||||
// 성공 처리
|
||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||
|
||||
if (isEval) {
|
||||
// step2 완료처리
|
||||
modelTrainMngCoreService.markStep2Success(modelId);
|
||||
// 결과 csv 파일 정보 등록
|
||||
modelTestMetricsJobService.findTestValidMetricCsvFiles();
|
||||
} else {
|
||||
modelTrainMngCoreService.markStep1Success(modelId);
|
||||
// 결과 csv 파일 정보 등록
|
||||
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
|
||||
}
|
||||
|
||||
} else {
|
||||
String failMsg = result.getStatus() + "\n" + result.getLogs();
|
||||
// 실패 처리
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||
|
||||
if (isEval) {
|
||||
// 오류 정보 등록
|
||||
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
||||
} else {
|
||||
// 오류 정보 등록
|
||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||
}
|
||||
}
|
||||
@@ -125,8 +141,10 @@ public class TrainJobWorker {
|
||||
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
|
||||
|
||||
if ("EVAL".equals(params.get("jobType"))) {
|
||||
// 오류 정보 등록
|
||||
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
||||
} else {
|
||||
// 오류 정보 등록
|
||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user