테스트 실행 추가 #40
@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@@ -20,12 +21,12 @@ public class ModelTrainJobCoreService {
|
||||
return modelTrainJobRepository.findMaxAttemptNo(modelId);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
return modelTrainJobRepository.findLatestByModelId(modelId);
|
||||
public Optional<ModelTrainJobDto> findLatestByModelId(Long modelId) {
|
||||
return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobEntity> findById(Long jobId) {
|
||||
return modelTrainJobRepository.findById(jobId);
|
||||
public Optional<ModelTrainJobDto> findById(Long jobId) {
|
||||
return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto);
|
||||
}
|
||||
|
||||
/** QUEUED Job 생성 */
|
||||
@@ -95,7 +96,7 @@ public class ModelTrainJobCoreService {
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("CANCELED");
|
||||
job.setStatusCd("STOPPED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ import org.springframework.transaction.annotation.Transactional;
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainMngCoreService {
|
||||
|
||||
private final ModelMngRepository modelMngRepository;
|
||||
private final ModelDatasetRepository modelDatasetRepository;
|
||||
private final ModelDatasetMappRepository modelDatasetMapRepository;
|
||||
@@ -323,7 +324,7 @@ public class ModelTrainMngCoreService {
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/** 오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
@@ -332,7 +333,25 @@ public class ModelTrainMngCoreService {
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setStep1State(TrainStatusType.ERROR.getId());
|
||||
master.setLastError(errorMessage);
|
||||
master.setUpdatedUid(userUtil.getId());
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markStep2Error(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setStep2State(TrainStatusType.ERROR.getId());
|
||||
master.setLastError(errorMessage);
|
||||
master.setUpdatedUid(userUtil.getId());
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
@Transactional
|
||||
@@ -358,4 +377,58 @@ public class ModelTrainMngCoreService {
|
||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||
return modelMngRepository.findTrainRunRequest(modelId);
|
||||
}
|
||||
|
||||
public void markStep1InProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setCurrentAttemptId(jobId);
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void markStep2InProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStep2StrtDttm(ZonedDateTime.now());
|
||||
entity.setStep2State(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setCurrentAttemptId(jobId);
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void markStep1Success(Long modelId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep1EndDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void markStep2Success(Long modelId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep2EndDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
@@ -76,4 +77,19 @@ public class ModelTrainJobEntity {
|
||||
@Size(max = 100)
|
||||
@Column(name = "locked_by", length = 100)
|
||||
private String lockedBy;
|
||||
|
||||
public ModelTrainJobDto toDto() {
|
||||
return new ModelTrainJobDto(
|
||||
this.id,
|
||||
this.modelId,
|
||||
this.attemptNo,
|
||||
this.statusCd,
|
||||
this.exitCode,
|
||||
this.errorMessage,
|
||||
this.containerName,
|
||||
this.paramsJson,
|
||||
this.queuedDttm,
|
||||
this.startedDttm,
|
||||
this.finishedDttm);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,7 +134,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
modelHyperParamEntity.contrastRange,
|
||||
modelHyperParamEntity.saturationRange,
|
||||
modelHyperParamEntity.hueDelta,
|
||||
Expressions.nullExpression(Integer.class)))
|
||||
Expressions.nullExpression(Integer.class),
|
||||
Expressions.nullExpression(String.class)))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelHyperParamEntity)
|
||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||
|
||||
@@ -7,6 +7,4 @@ public interface ModelTrainJobRepositoryCustom {
|
||||
int findMaxAttemptNo(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
|
||||
}
|
||||
|
||||
@@ -1,34 +1,43 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import jakarta.persistence.EntityManager;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
|
||||
|
||||
private final EntityManager em;
|
||||
private final JPAQueryFactory queryFactory;
|
||||
|
||||
private JPAQueryFactory queryFactory() {
|
||||
return new JPAQueryFactory(em);
|
||||
public ModelTrainJobRepositoryImpl(EntityManager em) {
|
||||
this.queryFactory = new JPAQueryFactory(em);
|
||||
}
|
||||
|
||||
/** modelId의 attempt_no 최대값. (없으면 0) */
|
||||
@Override
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
return 0;
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
Integer max =
|
||||
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
|
||||
|
||||
return max != null ? max : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* modelId의 최신 job 1건 (보통 id desc / queuedDttm desc 등) - attemptNo 기준으로도 가능하지만, 여기선 id desc가 가장
|
||||
* 단순.
|
||||
*/
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
return Optional.empty();
|
||||
}
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() {
|
||||
return Optional.empty();
|
||||
ModelTrainJobEntity job =
|
||||
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.train;
|
||||
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.train.service.TestJobService;
|
||||
import com.kamco.cd.training.train.service.TrainJobService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
@@ -12,6 +13,7 @@ import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@@ -22,6 +24,7 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
public class TrainApiController {
|
||||
|
||||
private final TrainJobService trainJobService;
|
||||
private final TestJobService testJobService;
|
||||
|
||||
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
||||
@ApiResponses(
|
||||
@@ -36,7 +39,7 @@ public class TrainApiController {
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@RequestMapping("/run/{uuid}")
|
||||
@PostMapping("/run/{uuid}")
|
||||
public ApiResponseDto<String> run(
|
||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||
@PathVariable
|
||||
@@ -45,4 +48,120 @@ public class TrainApiController {
|
||||
trainJobService.enqueue(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "학습 재실행", description = "학습 재실행 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "재실행 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/restart/{uuid}")
|
||||
public ApiResponseDto<String> restart(
|
||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
Long jobId = trainJobService.restart(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "학습 이어하기", description = "학습 이어하기 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "이어하기 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/resume/{uuid}")
|
||||
public ApiResponseDto<String> resume(
|
||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
Long jobId = trainJobService.resume(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "학습 취소", description = "학습 취소 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "취소 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/cancel/{uuid}")
|
||||
public ApiResponseDto<String> cancel(
|
||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
trainJobService.cancel(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "test 실행", description = "test 실행 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "test 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/test/run/{epoch}/{uuid}")
|
||||
public ApiResponseDto<String> run(
|
||||
@Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
|
||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
testJobService.enqueue(modelId, uuid, epoch);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "학습 취소", description = "학습 취소 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "취소 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/test/cancel/{uuid}")
|
||||
public ApiResponseDto<String> cancelTest(
|
||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
testJobService.cancel(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class EvalRunRequest {
|
||||
private String uuid;
|
||||
private int epoch; // best_changed_fscore_epoch_1.pth
|
||||
private Integer timeoutSeconds;
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public class ModelTrainJobDto {
|
||||
|
||||
private Long id;
|
||||
private Long modelId;
|
||||
private Integer attemptNo;
|
||||
private String statusCd;
|
||||
private Integer exitCode;
|
||||
private String errorMessage;
|
||||
private String containerName;
|
||||
private Map<String, Object> paramsJson;
|
||||
private ZonedDateTime queuedDttm;
|
||||
private ZonedDateTime startedDttm;
|
||||
private ZonedDateTime finishedDttm;
|
||||
}
|
||||
@@ -79,4 +79,5 @@ public class TrainRunRequest {
|
||||
// 실행 타임아웃
|
||||
// ========================
|
||||
private Integer timeoutSeconds;
|
||||
private String resumeFrom;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
import java.io.BufferedReader;
|
||||
@@ -7,7 +8,6 @@ import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -40,53 +40,72 @@ public class DockerTrainService {
|
||||
private boolean ipcHost;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception {
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||
|
||||
// 실행 식별용 jobId 생성
|
||||
String jobId = UUID.randomUUID().toString().substring(0, 8);
|
||||
|
||||
// 컨테이너 이름 생성 (중복 방지 목적)
|
||||
String containerName = containerPrefix + "-" + jobId;
|
||||
|
||||
// docker run 명령어 조립
|
||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||
|
||||
// 프로세스 실행
|
||||
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||
|
||||
// stderr를 stdout으로 합쳐서 한 스트림으로 처리
|
||||
pb.redirectErrorStream(true);
|
||||
|
||||
Process p = pb.start();
|
||||
|
||||
// 실행 로그 수집
|
||||
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||
StringBuilder log = new StringBuilder();
|
||||
|
||||
Thread logThread =
|
||||
new Thread(
|
||||
() -> {
|
||||
try (BufferedReader br =
|
||||
new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
|
||||
new BufferedReader(
|
||||
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
synchronized (log) {
|
||||
log.append(line).append('\n');
|
||||
}
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
},
|
||||
"train-log-" + containerName);
|
||||
|
||||
// 지정된 timeout 내에 종료 대기
|
||||
int timeoutSeconds = 7200; // 기본 2시간
|
||||
logThread.setDaemon(true);
|
||||
logThread.start();
|
||||
|
||||
int timeoutSeconds = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
|
||||
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
|
||||
|
||||
if (!finished) {
|
||||
// 타임아웃 발생 시 컨테이너 강제 제거
|
||||
// docker run 프로세스도 같이 끊어야 readLine이 풀림
|
||||
p.destroy();
|
||||
if (!p.waitFor(2, TimeUnit.SECONDS)) {
|
||||
p.destroyForcibly();
|
||||
}
|
||||
killContainer(containerName);
|
||||
|
||||
return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString());
|
||||
String logs;
|
||||
synchronized (log) {
|
||||
logs = log.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(
|
||||
null, // jobId (없으면 null)
|
||||
containerName,
|
||||
-1,
|
||||
"TIMEOUT",
|
||||
logs);
|
||||
}
|
||||
|
||||
// 종료 코드 확인 (0=정상)
|
||||
int exit = p.exitValue();
|
||||
|
||||
return new TrainRunResult(
|
||||
jobId, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", log.toString());
|
||||
// 로그 스레드가 마무리할 시간을 조금 줌(없어도 되지만 로그 누락 방지용)
|
||||
logThread.join(500);
|
||||
|
||||
String logs;
|
||||
synchronized (log) {
|
||||
logs = log.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -205,6 +224,7 @@ public class DockerTrainService {
|
||||
addArg(c, "--saturation-range", req.getSaturationRange());
|
||||
addArg(c, "--hue-delta", req.getHueDelta());
|
||||
|
||||
addArg(c, "--resume-from", req.getResumeFrom());
|
||||
return c;
|
||||
}
|
||||
|
||||
@@ -218,7 +238,7 @@ public class DockerTrainService {
|
||||
}
|
||||
|
||||
/** 컨테이너 강제 종료 및 제거 */
|
||||
private void killContainer(String containerName) {
|
||||
public void killContainer(String containerName) {
|
||||
try {
|
||||
new ProcessBuilder("docker", "rm", "-f", containerName)
|
||||
.redirectErrorStream(true)
|
||||
@@ -227,4 +247,100 @@ public class DockerTrainService {
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
}
|
||||
|
||||
public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception {
|
||||
|
||||
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
||||
|
||||
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||
pb.redirectErrorStream(true);
|
||||
|
||||
Process p = pb.start();
|
||||
|
||||
StringBuilder log = new StringBuilder();
|
||||
Thread logThread =
|
||||
new Thread(
|
||||
() -> {
|
||||
try (BufferedReader br =
|
||||
new BufferedReader(
|
||||
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
synchronized (log) {
|
||||
log.append(line).append('\n');
|
||||
}
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
});
|
||||
|
||||
logThread.setDaemon(true);
|
||||
logThread.start();
|
||||
|
||||
int timeout = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
|
||||
boolean finished = p.waitFor(timeout, TimeUnit.SECONDS);
|
||||
|
||||
if (!finished) {
|
||||
p.destroyForcibly();
|
||||
killContainer(containerName);
|
||||
|
||||
String logs;
|
||||
synchronized (log) {
|
||||
logs = log.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs);
|
||||
}
|
||||
|
||||
int exit = p.exitValue();
|
||||
logThread.join(500);
|
||||
|
||||
String logs;
|
||||
synchronized (log) {
|
||||
logs = log.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
||||
}
|
||||
|
||||
private List<String> buildDockerEvalCommand(String containerName, EvalRunRequest req) {
|
||||
|
||||
String uuid = req.getUuid();
|
||||
Integer epoch = req.getEpoch();
|
||||
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
|
||||
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
|
||||
|
||||
String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth";
|
||||
|
||||
List<String> c = new ArrayList<>();
|
||||
|
||||
c.add("docker");
|
||||
c.add("run");
|
||||
c.add("--name");
|
||||
c.add(containerName);
|
||||
c.add("--rm");
|
||||
|
||||
c.add("--gpus");
|
||||
c.add("all");
|
||||
if (ipcHost) c.add("--ipc=host");
|
||||
c.add("--shm-size=" + shmSize);
|
||||
|
||||
c.add("-v");
|
||||
c.add(requestDir + ":/data");
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
c.add(image);
|
||||
|
||||
c.add("python");
|
||||
c.add("/workspace/change-detection-code/run_evaluation_pipeline.py");
|
||||
|
||||
c.add("--dataset_dir");
|
||||
c.add("/data/" + uuid);
|
||||
|
||||
c.add("--model");
|
||||
c.add("/checkpoints/" + uuid + "/" + modelFile);
|
||||
|
||||
return c;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional(readOnly = true)
|
||||
public class TestJobService {
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||
|
||||
// 마스터 확인
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
Map<String, Object> params = new java.util.LinkedHashMap<>();
|
||||
params.put("jobType", "EVAL");
|
||||
params.put("uuid", uuid);
|
||||
params.put("epoch", epoch);
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId, nextAttemptNo, params, ZonedDateTime.now());
|
||||
|
||||
// step2 시작으로 마킹
|
||||
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
return jobId;
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
Long jobId = master.getCurrentAttemptId();
|
||||
if (jobId == null) {
|
||||
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||
}
|
||||
|
||||
var job =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalStateException("Job not found"));
|
||||
|
||||
String containerName = job.getContainerName();
|
||||
|
||||
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
|
||||
if (containerName != null && !containerName.isBlank()) {
|
||||
dockerTrainService.killContainer(containerName);
|
||||
}
|
||||
|
||||
// 2) 상태 업데이트 (항상 수행)
|
||||
modelTrainJobCoreService.markCanceled(jobId);
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
}
|
||||
}
|
||||
@@ -7,10 +7,14 @@ import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
@@ -22,9 +26,14 @@ public class TrainJobService {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
public Long getModelIdByUuid(UUID uuid) {
|
||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||
}
|
||||
@@ -36,6 +45,7 @@ public class TrainJobService {
|
||||
// 마스터 존재 확인(없으면 예외)
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
// 파라미터 조회
|
||||
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
|
||||
|
||||
if (trainRunRequest == null) {
|
||||
@@ -46,6 +56,7 @@ public class TrainJobService {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
||||
paramsMap.put("jobType", "TRAIN");
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
@@ -57,16 +68,66 @@ public class TrainJobService {
|
||||
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
|
||||
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 재시작 버튼
|
||||
* 재시작
|
||||
*
|
||||
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
|
||||
*/
|
||||
@Transactional
|
||||
public Long restart(Long modelId) {
|
||||
return createNextAttempt(modelId, ResumeMode.NONE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 이어하기
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public Long resume(Long modelId) {
|
||||
return createNextAttempt(modelId, ResumeMode.REQUIRE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 중단
|
||||
*
|
||||
* <p>- job 상태 CANCELED - master 상태 STOPPED
|
||||
*
|
||||
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
|
||||
*/
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
Long jobId = master.getCurrentAttemptId();
|
||||
if (jobId == null) {
|
||||
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||
}
|
||||
|
||||
var job =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalStateException("Job not found"));
|
||||
|
||||
String containerName = job.getContainerName();
|
||||
|
||||
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
|
||||
if (containerName != null && !containerName.isBlank()) {
|
||||
dockerTrainService.killContainer(containerName);
|
||||
}
|
||||
|
||||
// 2) 상태 업데이트 (항상 수행)
|
||||
modelTrainJobCoreService.markCanceled(jobId);
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
}
|
||||
|
||||
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
@@ -81,39 +142,72 @@ public class TrainJobService {
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
// 이전 params_json 재사용 (재현성)
|
||||
Map<String, Object> params = lastJob.getParamsJson();
|
||||
if (params == null || params.isEmpty()) {
|
||||
throw new IllegalStateException("이전 실행 params_json이 없습니다.");
|
||||
}
|
||||
|
||||
// mode에 따라 resume 옵션 주입/제거
|
||||
Map<String, Object> nextParams = new java.util.LinkedHashMap<>(params);
|
||||
|
||||
if (mode == ResumeMode.NONE) {
|
||||
// 이어하기 관련 키가 있다면 제거 (완전 새로 시작 보장)
|
||||
nextParams.remove("resumeFrom");
|
||||
nextParams.remove("resume");
|
||||
} else if (mode == ResumeMode.REQUIRE) {
|
||||
// 체크포인트 탐지해서 resumeFrom 세팅
|
||||
String resumeFrom = findResumeFromOrNull(nextParams);
|
||||
if (resumeFrom == null) {
|
||||
throw new IllegalStateException("이어하기 체크포인트가 없습니다.");
|
||||
}
|
||||
nextParams.put("resumeFrom", resumeFrom);
|
||||
nextParams.put("resume", true);
|
||||
}
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId,
|
||||
nextAttemptNo,
|
||||
lastJob.getParamsJson(), // Map<String,Object> 그대로 재사용
|
||||
ZonedDateTime.now());
|
||||
modelId, nextAttemptNo, nextParams, ZonedDateTime.now());
|
||||
|
||||
modelTrainMngCoreService.clearLastError(modelId);
|
||||
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
||||
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 중단 버튼
|
||||
*
|
||||
* <p>- job 상태 CANCELED - master 상태 STOPPED
|
||||
*
|
||||
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
|
||||
*/
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
Long attemptId = master.getCurrentAttemptId();
|
||||
if (attemptId == null) {
|
||||
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||
private enum ResumeMode {
|
||||
NONE, // 새로 시작
|
||||
REQUIRE // 이어하기
|
||||
}
|
||||
|
||||
modelTrainJobCoreService.markCanceled(attemptId);
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
||||
if (paramsJson == null) return null;
|
||||
|
||||
Object out = paramsJson.get("outputFolder");
|
||||
if (out == null) return null;
|
||||
|
||||
String outputFolder = String.valueOf(out).trim(); // uuid
|
||||
if (outputFolder.isEmpty()) return null;
|
||||
|
||||
// 호스트 기준 경로
|
||||
Path outDir = Paths.get(responseDir, outputFolder);
|
||||
|
||||
Path last = outDir.resolve("last_checkpoint");
|
||||
if (!Files.isRegularFile(last)) return null;
|
||||
|
||||
try {
|
||||
String ckptFile = Files.readString(last).trim(); // epoch_10.pth
|
||||
if (ckptFile.isEmpty()) return null;
|
||||
|
||||
Path ckptHost = outDir.resolve(ckptFile);
|
||||
if (!Files.isRegularFile(ckptHost)) return null;
|
||||
|
||||
// 컨테이너 경로 반환
|
||||
return "/checkpoints/" + outputFolder + "/" + ckptFile;
|
||||
|
||||
} catch (Exception e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
@@ -27,55 +29,82 @@ public class TrainJobWorker {
|
||||
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
||||
public void handle(ModelTrainJobQueuedEvent event) {
|
||||
|
||||
Long jobId = event.getJobId(); // record면 event.jobId()
|
||||
Long jobId = event.getJobId();
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
ModelTrainJobDto job =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
Long modelId = job.getModelId();
|
||||
|
||||
// enqueue에서 params_json 저장해놨으니 그걸로 TrainRunRequest 복원하는게 제일 일관적
|
||||
TrainRunRequest req = toTrainRunRequest(job.getParamsJson());
|
||||
// req가 null이면 실패 처리
|
||||
if (req == null) {
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)");
|
||||
modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패");
|
||||
if (TrainStatusType.STOPPED.getId().equals(job.getStatusCd())) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 컨테이너 이름은 "jobId 기반"으로 고정하는 게 cancel/restart에 유리
|
||||
String containerName = "train-" + jobId; // prefix 쓰고싶으면 @Value 받아서 붙이면 됨
|
||||
Long modelId = job.getModelId();
|
||||
Map<String, Object> params = job.getParamsJson();
|
||||
|
||||
// logPath/lockedBy는 너 환경에 맞게
|
||||
String logPath = null;
|
||||
String lockedBy = "TRAIN_WORKER";
|
||||
String jobType = params != null ? String.valueOf(params.get("jobType")) : null;
|
||||
|
||||
// RUNNING 표시
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy);
|
||||
boolean isEval = "EVAL".equals(jobType);
|
||||
|
||||
String containerName = (isEval ? "eval-" : "train-") + jobId;
|
||||
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER");
|
||||
|
||||
try {
|
||||
// DockerTrainService가 내부에서 컨테이너 이름을 랜덤으로 만들고 있어서
|
||||
// markRunning에서 저장한 containerName과 실제 컨테이너명이 달라질 수 있음.
|
||||
// 아래 "추천 수정" 참고.
|
||||
TrainRunResult result = dockerTrainService.runTrainSync(req);
|
||||
TrainRunResult result;
|
||||
|
||||
if (isEval) {
|
||||
String uuid = String.valueOf(params.get("uuid"));
|
||||
int epoch = (int) params.get("epoch");
|
||||
|
||||
EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null);
|
||||
result = dockerTrainService.runEvalSync(evalReq, containerName);
|
||||
|
||||
} else {
|
||||
TrainRunRequest trainReq = toTrainRunRequest(params);
|
||||
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
||||
}
|
||||
|
||||
ModelTrainJobDto latest =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
if (TrainStatusType.STOPPED.getId().equals(latest.getStatusCd())) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (result.getExitCode() == 0) {
|
||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||
modelTrainMngCoreService.markSuccess(modelId); // 너 modelTrainMngCoreService에 있는 이름으로 맞춰
|
||||
|
||||
if (isEval) {
|
||||
modelTrainMngCoreService.markStep2Success(modelId);
|
||||
} else {
|
||||
modelTrainMngCoreService.markStep1Success(modelId);
|
||||
}
|
||||
|
||||
} else {
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||
|
||||
if (isEval) {
|
||||
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
||||
} else {
|
||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
modelTrainJobCoreService.markFailed(jobId, null, e.toString());
|
||||
|
||||
if ("EVAL".equals(params.get("jobType"))) {
|
||||
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
||||
} else {
|
||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private TrainRunRequest toTrainRunRequest(Map<String, Object> paramsJson) {
|
||||
if (paramsJson == null || paramsJson.isEmpty()) {
|
||||
|
||||
Reference in New Issue
Block a user