학습실행 주석 추가
This commit is contained in:
@@ -57,6 +57,12 @@ public class ModelTrainDetailCoreService {
|
|||||||
return modelDetailRepository.getModelDetailSummary(uuid);
|
return modelDetailRepository.getModelDetailSummary(uuid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 하이퍼 파리미터 요약정보
|
||||||
|
*
|
||||||
|
* @param uuid 모델마스터 uuid
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
|
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
|
||||||
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,12 @@ public class ModelTrainJobCoreService {
|
|||||||
/** 실행 시작 처리 */
|
/** 실행 시작 처리 */
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markRunning(
|
public void markRunning(
|
||||||
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
|
Long jobId,
|
||||||
|
String containerName,
|
||||||
|
String logPath,
|
||||||
|
String lockedBy,
|
||||||
|
Integer totalEpoch,
|
||||||
|
String jobType) {
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
modelTrainJobRepository
|
modelTrainJobRepository
|
||||||
.findById(jobId)
|
.findById(jobId)
|
||||||
@@ -64,13 +69,19 @@ public class ModelTrainJobCoreService {
|
|||||||
job.setStartedDttm(ZonedDateTime.now());
|
job.setStartedDttm(ZonedDateTime.now());
|
||||||
job.setLockedDttm(ZonedDateTime.now());
|
job.setLockedDttm(ZonedDateTime.now());
|
||||||
job.setLockedBy(lockedBy);
|
job.setLockedBy(lockedBy);
|
||||||
|
job.setJobType(jobType);
|
||||||
|
|
||||||
if (totalEpoch != null) {
|
if (totalEpoch != null) {
|
||||||
job.setTotalEpoch(totalEpoch);
|
job.setTotalEpoch(totalEpoch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 성공 처리 */
|
/**
|
||||||
|
* 성공 처리
|
||||||
|
*
|
||||||
|
* @param jobId
|
||||||
|
* @param exitCode
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markSuccess(Long jobId, int exitCode) {
|
public void markSuccess(Long jobId, int exitCode) {
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
@@ -83,7 +94,13 @@ public class ModelTrainJobCoreService {
|
|||||||
job.setFinishedDttm(ZonedDateTime.now());
|
job.setFinishedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 실패 처리 */
|
/**
|
||||||
|
* 실패 처리
|
||||||
|
*
|
||||||
|
* @param jobId
|
||||||
|
* @param exitCode
|
||||||
|
* @param errorMessage
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
|
|||||||
@@ -384,7 +384,12 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
/**
|
||||||
|
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @param errorMessage
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markError(Long modelId, String errorMessage) {
|
public void markError(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -399,7 +404,12 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setUpdatedDttm(ZonedDateTime.now());
|
master.setUpdatedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
/**
|
||||||
|
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @param errorMessage
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStep2Error(Long modelId, String errorMessage) {
|
public void markStep2Error(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
|
|||||||
@@ -83,6 +83,9 @@ public class ModelTrainJobEntity {
|
|||||||
@Column(name = "current_epoch")
|
@Column(name = "current_epoch")
|
||||||
private Integer currentEpoch;
|
private Integer currentEpoch;
|
||||||
|
|
||||||
|
@Column(name = "job_type")
|
||||||
|
private String jobType;
|
||||||
|
|
||||||
public ModelTrainJobDto toDto() {
|
public ModelTrainJobDto toDto() {
|
||||||
return new ModelTrainJobDto(
|
return new ModelTrainJobDto(
|
||||||
this.id,
|
this.id,
|
||||||
|
|||||||
@@ -52,7 +52,14 @@ public class DockerTrainService {
|
|||||||
|
|
||||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
|
|
||||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
/**
|
||||||
|
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
|
||||||
|
*
|
||||||
|
* @param req
|
||||||
|
* @param containerName
|
||||||
|
* @return
|
||||||
|
* @throws Exception
|
||||||
|
*/
|
||||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||||
|
|
||||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||||
@@ -267,8 +274,7 @@ public class DockerTrainService {
|
|||||||
addArg(c, "--input-size", req.getInputSize());
|
addArg(c, "--input-size", req.getInputSize());
|
||||||
addArg(c, "--crop-size", req.getCropSize());
|
addArg(c, "--crop-size", req.getCropSize());
|
||||||
addArg(c, "--batch-size", req.getBatchSize());
|
addArg(c, "--batch-size", req.getBatchSize());
|
||||||
addArg(c, "--gpu-ids", req.getGpuIds());
|
addArg(c, "--gpu-ids", req.getGpuIds()); // null
|
||||||
// addArg(c, "--gpus", req.getGpus());
|
|
||||||
addArg(c, "--lr", req.getLearningRate());
|
addArg(c, "--lr", req.getLearningRate());
|
||||||
addArg(c, "--backbone", req.getBackbone());
|
addArg(c, "--backbone", req.getBackbone());
|
||||||
addArg(c, "--epochs", req.getEpochs());
|
addArg(c, "--epochs", req.getEpochs());
|
||||||
@@ -342,6 +348,14 @@ public class DockerTrainService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
|
||||||
|
*
|
||||||
|
* @param containerName
|
||||||
|
* @param req
|
||||||
|
* @return
|
||||||
|
* @throws Exception
|
||||||
|
*/
|
||||||
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
|
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
|
||||||
|
|
||||||
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
||||||
|
|||||||
@@ -48,20 +48,8 @@ public class ModelTestMetricsJobService {
|
|||||||
@Value("${file.pt-path}")
|
@Value("${file.pt-path}")
|
||||||
private String ptPathDir;
|
private String ptPathDir;
|
||||||
|
|
||||||
/**
|
/** 결과 csv 파일 정보 등록 */
|
||||||
* 실행중인 profile
|
public void findTestValidMetricCsvFiles() {
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private boolean isLocalProfile() {
|
|
||||||
return "local".equalsIgnoreCase(profile);
|
|
||||||
}
|
|
||||||
|
|
||||||
// @Scheduled(cron = "0 * * * * *")
|
|
||||||
public void findTestValidMetricCsvFiles() throws IOException {
|
|
||||||
// if (isLocalProfile()) {
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
List<ResponsePathDto> modelIds =
|
List<ResponsePathDto> modelIds =
|
||||||
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
||||||
|
|||||||
@@ -36,20 +36,8 @@ public class ModelTrainMetricsJobService {
|
|||||||
@Value("${train.docker.responseDir}")
|
@Value("${train.docker.responseDir}")
|
||||||
private String responseDir;
|
private String responseDir;
|
||||||
|
|
||||||
/**
|
/** 결과 csv 파일 정보 등록 */
|
||||||
* 실행중인 profile
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private boolean isLocalProfile() {
|
|
||||||
return "local".equalsIgnoreCase(profile);
|
|
||||||
}
|
|
||||||
|
|
||||||
// @Scheduled(cron = "0 * * * * *")
|
|
||||||
public void findTrainValidMetricCsvFiles() {
|
public void findTrainValidMetricCsvFiles() {
|
||||||
// if (isLocalProfile()) {
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
List<ResponsePathDto> modelIds =
|
List<ResponsePathDto> modelIds =
|
||||||
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
||||||
|
|||||||
@@ -23,6 +23,14 @@ public class TestJobService {
|
|||||||
private final ApplicationEventPublisher eventPublisher;
|
private final ApplicationEventPublisher eventPublisher;
|
||||||
private final DataSetCountersService dataSetCounters;
|
private final DataSetCountersService dataSetCounters;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 실행 예약 (QUEUE 등록)
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @param uuid
|
||||||
|
* @param epoch
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||||
|
|
||||||
@@ -58,6 +66,11 @@ public class TestJobService {
|
|||||||
return jobId;
|
return jobId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 취소
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void cancel(Long modelId) {
|
public void cancel(Long modelId) {
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,12 @@ public class TrainJobService {
|
|||||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 실행 예약 (QUEUE 등록) */
|
/**
|
||||||
|
* 실행 예약 (QUEUE 등록)
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public Long enqueue(Long modelId) {
|
public Long enqueue(Long modelId) {
|
||||||
|
|
||||||
@@ -139,6 +144,13 @@ public class TrainJobService {
|
|||||||
modelTrainMngCoreService.markStopped(modelId);
|
modelTrainMngCoreService.markStopped(modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습 이어하기
|
||||||
|
*
|
||||||
|
* @param modelId 모델 id
|
||||||
|
* @param mode NONE 새로 시작, REQUIRE 이어하기
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
||||||
|
|
||||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||||
@@ -189,6 +201,12 @@ public class TrainJobService {
|
|||||||
REQUIRE // 이어하기
|
REQUIRE // 이어하기
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
|
||||||
|
*
|
||||||
|
* @param paramsJson
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
||||||
if (paramsJson == null) return null;
|
if (paramsJson == null) return null;
|
||||||
|
|
||||||
@@ -230,6 +248,12 @@ public class TrainJobService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
|
||||||
|
*
|
||||||
|
* @param modelUuid
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public UUID createTmpFile(UUID modelUuid) {
|
public UUID createTmpFile(UUID modelUuid) {
|
||||||
UUID tmpUuid = UUID.randomUUID();
|
UUID tmpUuid = UUID.randomUUID();
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
|
|||||||
import org.springframework.transaction.event.TransactionPhase;
|
import org.springframework.transaction.event.TransactionPhase;
|
||||||
import org.springframework.transaction.event.TransactionalEventListener;
|
import org.springframework.transaction.event.TransactionalEventListener;
|
||||||
|
|
||||||
|
/** job 실행 */
|
||||||
@Log4j2
|
@Log4j2
|
||||||
@Component
|
@Component
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@@ -54,6 +55,8 @@ public class TrainJobWorker {
|
|||||||
String containerName =
|
String containerName =
|
||||||
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
||||||
|
|
||||||
|
String type = isEval ? "TEST" : "TRAIN";
|
||||||
|
|
||||||
Integer totalEpoch = null;
|
Integer totalEpoch = null;
|
||||||
if (params.containsKey("totalEpoch")) {
|
if (params.containsKey("totalEpoch")) {
|
||||||
if (params.get("totalEpoch") != null) {
|
if (params.get("totalEpoch") != null) {
|
||||||
@@ -61,12 +64,15 @@ public class TrainJobWorker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
|
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
|
||||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
|
// 실행 시작 처리
|
||||||
|
modelTrainJobCoreService.markRunning(
|
||||||
|
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
|
||||||
log.info("[JOB] markRunning done jobId={}", jobId);
|
log.info("[JOB] markRunning done jobId={}", jobId);
|
||||||
try {
|
try {
|
||||||
TrainRunResult result;
|
TrainRunResult result;
|
||||||
|
|
||||||
if (isEval) {
|
if (isEval) {
|
||||||
|
// step2 진행중 처리
|
||||||
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||||
String uuid = String.valueOf(params.get("uuid"));
|
String uuid = String.valueOf(params.get("uuid"));
|
||||||
int epoch = (int) params.get("epoch");
|
int epoch = (int) params.get("epoch");
|
||||||
@@ -81,11 +87,14 @@ public class TrainJobWorker {
|
|||||||
evalReq.setOutputFolder(outputFolder);
|
evalReq.setOutputFolder(outputFolder);
|
||||||
log.info("[JOB] selected test epoch={}", epoch);
|
log.info("[JOB] selected test epoch={}", epoch);
|
||||||
|
|
||||||
|
// 도커 실행 후 로그 수집
|
||||||
result = dockerTrainService.runEvalSync(containerName, evalReq);
|
result = dockerTrainService.runEvalSync(containerName, evalReq);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
// step1 진행중 처리
|
||||||
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||||
TrainRunRequest trainReq = toTrainRunRequest(params);
|
TrainRunRequest trainReq = toTrainRunRequest(params);
|
||||||
|
// 도커 실행 후 로그 수집
|
||||||
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,24 +108,31 @@ public class TrainJobWorker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (result.getExitCode() == 0) {
|
if (result.getExitCode() == 0) {
|
||||||
|
// 성공 처리
|
||||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||||
|
|
||||||
if (isEval) {
|
if (isEval) {
|
||||||
|
// step2 완료처리
|
||||||
modelTrainMngCoreService.markStep2Success(modelId);
|
modelTrainMngCoreService.markStep2Success(modelId);
|
||||||
|
// 결과 csv 파일 정보 등록
|
||||||
modelTestMetricsJobService.findTestValidMetricCsvFiles();
|
modelTestMetricsJobService.findTestValidMetricCsvFiles();
|
||||||
} else {
|
} else {
|
||||||
modelTrainMngCoreService.markStep1Success(modelId);
|
modelTrainMngCoreService.markStep1Success(modelId);
|
||||||
|
// 결과 csv 파일 정보 등록
|
||||||
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
|
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
String failMsg = result.getStatus() + "\n" + result.getLogs();
|
String failMsg = result.getStatus() + "\n" + result.getLogs();
|
||||||
|
// 실패 처리
|
||||||
modelTrainJobCoreService.markFailed(
|
modelTrainJobCoreService.markFailed(
|
||||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||||
|
|
||||||
if (isEval) {
|
if (isEval) {
|
||||||
|
// 오류 정보 등록
|
||||||
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
||||||
} else {
|
} else {
|
||||||
|
// 오류 정보 등록
|
||||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -125,8 +141,10 @@ public class TrainJobWorker {
|
|||||||
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
|
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
|
||||||
|
|
||||||
if ("EVAL".equals(params.get("jobType"))) {
|
if ("EVAL".equals(params.get("jobType"))) {
|
||||||
|
// 오류 정보 등록
|
||||||
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
||||||
} else {
|
} else {
|
||||||
|
// 오류 정보 등록
|
||||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user