테스트 실행 추가 #40
@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.core;
|
|||||||
|
|
||||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
@@ -20,12 +21,12 @@ public class ModelTrainJobCoreService {
|
|||||||
return modelTrainJobRepository.findMaxAttemptNo(modelId);
|
return modelTrainJobRepository.findMaxAttemptNo(modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
public Optional<ModelTrainJobDto> findLatestByModelId(Long modelId) {
|
||||||
return modelTrainJobRepository.findLatestByModelId(modelId);
|
return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Optional<ModelTrainJobEntity> findById(Long jobId) {
|
public Optional<ModelTrainJobDto> findById(Long jobId) {
|
||||||
return modelTrainJobRepository.findById(jobId);
|
return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** QUEUED Job 생성 */
|
/** QUEUED Job 생성 */
|
||||||
@@ -95,7 +96,7 @@ public class ModelTrainJobCoreService {
|
|||||||
.findById(jobId)
|
.findById(jobId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||||
|
|
||||||
job.setStatusCd("CANCELED");
|
job.setStatusCd("STOPPED");
|
||||||
job.setFinishedDttm(ZonedDateTime.now());
|
job.setFinishedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ import org.springframework.transaction.annotation.Transactional;
|
|||||||
@Service
|
@Service
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class ModelTrainMngCoreService {
|
public class ModelTrainMngCoreService {
|
||||||
|
|
||||||
private final ModelMngRepository modelMngRepository;
|
private final ModelMngRepository modelMngRepository;
|
||||||
private final ModelDatasetRepository modelDatasetRepository;
|
private final ModelDatasetRepository modelDatasetRepository;
|
||||||
private final ModelDatasetMappRepository modelDatasetMapRepository;
|
private final ModelDatasetMappRepository modelDatasetMapRepository;
|
||||||
@@ -323,7 +324,7 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 오류 처리(옵션) - Worker가 실패 시 호출 */
|
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markError(Long modelId, String errorMessage) {
|
public void markError(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -332,7 +333,25 @@ public class ModelTrainMngCoreService {
|
|||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||||
|
master.setStep1State(TrainStatusType.ERROR.getId());
|
||||||
master.setLastError(errorMessage);
|
master.setLastError(errorMessage);
|
||||||
|
master.setUpdatedUid(userUtil.getId());
|
||||||
|
master.setUpdatedDttm(ZonedDateTime.now());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||||
|
@Transactional
|
||||||
|
public void markStep2Error(Long modelId, String errorMessage) {
|
||||||
|
ModelMasterEntity master =
|
||||||
|
modelMngRepository
|
||||||
|
.findById(modelId)
|
||||||
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
|
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||||
|
master.setStep2State(TrainStatusType.ERROR.getId());
|
||||||
|
master.setLastError(errorMessage);
|
||||||
|
master.setUpdatedUid(userUtil.getId());
|
||||||
|
master.setUpdatedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
@@ -358,4 +377,58 @@ public class ModelTrainMngCoreService {
|
|||||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||||
return modelMngRepository.findTrainRunRequest(modelId);
|
return modelMngRepository.findTrainRunRequest(modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void markStep1InProgress(Long modelId, Long jobId) {
|
||||||
|
ModelMasterEntity entity =
|
||||||
|
modelMngRepository
|
||||||
|
.findById(modelId)
|
||||||
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
|
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
|
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||||
|
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
|
||||||
|
entity.setCurrentAttemptId(jobId);
|
||||||
|
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||||
|
entity.setUpdatedUid(userUtil.getId());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void markStep2InProgress(Long modelId, Long jobId) {
|
||||||
|
ModelMasterEntity entity =
|
||||||
|
modelMngRepository
|
||||||
|
.findById(modelId)
|
||||||
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
|
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
|
entity.setStep2StrtDttm(ZonedDateTime.now());
|
||||||
|
entity.setStep2State(TrainStatusType.IN_PROGRESS.getId());
|
||||||
|
entity.setCurrentAttemptId(jobId);
|
||||||
|
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||||
|
entity.setUpdatedUid(userUtil.getId());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void markStep1Success(Long modelId) {
|
||||||
|
ModelMasterEntity entity =
|
||||||
|
modelMngRepository
|
||||||
|
.findById(modelId)
|
||||||
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
|
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
|
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
||||||
|
entity.setStep1EndDttm(ZonedDateTime.now());
|
||||||
|
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||||
|
entity.setUpdatedUid(userUtil.getId());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void markStep2Success(Long modelId) {
|
||||||
|
ModelMasterEntity entity =
|
||||||
|
modelMngRepository
|
||||||
|
.findById(modelId)
|
||||||
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
|
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
|
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
||||||
|
entity.setStep2EndDttm(ZonedDateTime.now());
|
||||||
|
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||||
|
entity.setUpdatedUid(userUtil.getId());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.kamco.cd.training.postgres.entity;
|
package com.kamco.cd.training.postgres.entity;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
import jakarta.persistence.Column;
|
import jakarta.persistence.Column;
|
||||||
import jakarta.persistence.Entity;
|
import jakarta.persistence.Entity;
|
||||||
import jakarta.persistence.GeneratedValue;
|
import jakarta.persistence.GeneratedValue;
|
||||||
@@ -76,4 +77,19 @@ public class ModelTrainJobEntity {
|
|||||||
@Size(max = 100)
|
@Size(max = 100)
|
||||||
@Column(name = "locked_by", length = 100)
|
@Column(name = "locked_by", length = 100)
|
||||||
private String lockedBy;
|
private String lockedBy;
|
||||||
|
|
||||||
|
public ModelTrainJobDto toDto() {
|
||||||
|
return new ModelTrainJobDto(
|
||||||
|
this.id,
|
||||||
|
this.modelId,
|
||||||
|
this.attemptNo,
|
||||||
|
this.statusCd,
|
||||||
|
this.exitCode,
|
||||||
|
this.errorMessage,
|
||||||
|
this.containerName,
|
||||||
|
this.paramsJson,
|
||||||
|
this.queuedDttm,
|
||||||
|
this.startedDttm,
|
||||||
|
this.finishedDttm);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,7 +134,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
|||||||
modelHyperParamEntity.contrastRange,
|
modelHyperParamEntity.contrastRange,
|
||||||
modelHyperParamEntity.saturationRange,
|
modelHyperParamEntity.saturationRange,
|
||||||
modelHyperParamEntity.hueDelta,
|
modelHyperParamEntity.hueDelta,
|
||||||
Expressions.nullExpression(Integer.class)))
|
Expressions.nullExpression(Integer.class),
|
||||||
|
Expressions.nullExpression(String.class)))
|
||||||
.from(modelMasterEntity)
|
.from(modelMasterEntity)
|
||||||
.leftJoin(modelHyperParamEntity)
|
.leftJoin(modelHyperParamEntity)
|
||||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||||
|
|||||||
@@ -7,6 +7,4 @@ public interface ModelTrainJobRepositoryCustom {
|
|||||||
int findMaxAttemptNo(Long modelId);
|
int findMaxAttemptNo(Long modelId);
|
||||||
|
|
||||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||||
|
|
||||||
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,34 +1,43 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.train;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||||
|
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
|
||||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||||
import jakarta.persistence.EntityManager;
|
import jakarta.persistence.EntityManager;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import org.springframework.stereotype.Repository;
|
import org.springframework.stereotype.Repository;
|
||||||
|
|
||||||
@Repository
|
@Repository
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
|
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
|
||||||
|
|
||||||
private final EntityManager em;
|
private final JPAQueryFactory queryFactory;
|
||||||
|
|
||||||
private JPAQueryFactory queryFactory() {
|
public ModelTrainJobRepositoryImpl(EntityManager em) {
|
||||||
return new JPAQueryFactory(em);
|
this.queryFactory = new JPAQueryFactory(em);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** modelId의 attempt_no 최대값. (없으면 0) */
|
||||||
@Override
|
@Override
|
||||||
public int findMaxAttemptNo(Long modelId) {
|
public int findMaxAttemptNo(Long modelId) {
|
||||||
return 0;
|
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||||
|
|
||||||
|
Integer max =
|
||||||
|
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
|
||||||
|
|
||||||
|
return max != null ? max : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* modelId의 최신 job 1건 (보통 id desc / queuedDttm desc 등) - attemptNo 기준으로도 가능하지만, 여기선 id desc가 가장
|
||||||
|
* 단순.
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||||
return Optional.empty();
|
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
ModelTrainJobEntity job =
|
||||||
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() {
|
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
|
||||||
return Optional.empty();
|
|
||||||
|
return Optional.ofNullable(job);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.kamco.cd.training.train;
|
package com.kamco.cd.training.train;
|
||||||
|
|
||||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||||
|
import com.kamco.cd.training.train.service.TestJobService;
|
||||||
import com.kamco.cd.training.train.service.TrainJobService;
|
import com.kamco.cd.training.train.service.TrainJobService;
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.Parameter;
|
import io.swagger.v3.oas.annotations.Parameter;
|
||||||
@@ -12,6 +13,7 @@ import io.swagger.v3.oas.annotations.tags.Tag;
|
|||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.springframework.web.bind.annotation.PathVariable;
|
import org.springframework.web.bind.annotation.PathVariable;
|
||||||
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
|
||||||
@@ -22,6 +24,7 @@ import org.springframework.web.bind.annotation.RestController;
|
|||||||
public class TrainApiController {
|
public class TrainApiController {
|
||||||
|
|
||||||
private final TrainJobService trainJobService;
|
private final TrainJobService trainJobService;
|
||||||
|
private final TestJobService testJobService;
|
||||||
|
|
||||||
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
||||||
@ApiResponses(
|
@ApiResponses(
|
||||||
@@ -36,7 +39,7 @@ public class TrainApiController {
|
|||||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
})
|
})
|
||||||
@RequestMapping("/run/{uuid}")
|
@PostMapping("/run/{uuid}")
|
||||||
public ApiResponseDto<String> run(
|
public ApiResponseDto<String> run(
|
||||||
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||||
@PathVariable
|
@PathVariable
|
||||||
@@ -45,4 +48,120 @@ public class TrainApiController {
|
|||||||
trainJobService.enqueue(modelId);
|
trainJobService.enqueue(modelId);
|
||||||
return ApiResponseDto.ok("ok");
|
return ApiResponseDto.ok("ok");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "학습 재실행", description = "학습 재실행 API")
|
||||||
|
@ApiResponses(
|
||||||
|
value = {
|
||||||
|
@ApiResponse(
|
||||||
|
responseCode = "200",
|
||||||
|
description = "재실행 성공",
|
||||||
|
content =
|
||||||
|
@Content(
|
||||||
|
mediaType = "application/json",
|
||||||
|
schema = @Schema(implementation = String.class))),
|
||||||
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||||
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
|
})
|
||||||
|
@PostMapping("/restart/{uuid}")
|
||||||
|
public ApiResponseDto<String> restart(
|
||||||
|
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||||
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
|
Long jobId = trainJobService.restart(modelId);
|
||||||
|
return ApiResponseDto.ok("ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "학습 이어하기", description = "학습 이어하기 API")
|
||||||
|
@ApiResponses(
|
||||||
|
value = {
|
||||||
|
@ApiResponse(
|
||||||
|
responseCode = "200",
|
||||||
|
description = "이어하기 성공",
|
||||||
|
content =
|
||||||
|
@Content(
|
||||||
|
mediaType = "application/json",
|
||||||
|
schema = @Schema(implementation = String.class))),
|
||||||
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||||
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
|
})
|
||||||
|
@PostMapping("/resume/{uuid}")
|
||||||
|
public ApiResponseDto<String> resume(
|
||||||
|
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||||
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
|
Long jobId = trainJobService.resume(modelId);
|
||||||
|
return ApiResponseDto.ok("ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "학습 취소", description = "학습 취소 API")
|
||||||
|
@ApiResponses(
|
||||||
|
value = {
|
||||||
|
@ApiResponse(
|
||||||
|
responseCode = "200",
|
||||||
|
description = "취소 성공",
|
||||||
|
content =
|
||||||
|
@Content(
|
||||||
|
mediaType = "application/json",
|
||||||
|
schema = @Schema(implementation = String.class))),
|
||||||
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||||
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
|
})
|
||||||
|
@PostMapping("/cancel/{uuid}")
|
||||||
|
public ApiResponseDto<String> cancel(
|
||||||
|
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||||
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
|
trainJobService.cancel(modelId);
|
||||||
|
return ApiResponseDto.ok("ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "test 실행", description = "test 실행 API")
|
||||||
|
@ApiResponses(
|
||||||
|
value = {
|
||||||
|
@ApiResponse(
|
||||||
|
responseCode = "200",
|
||||||
|
description = "test 성공",
|
||||||
|
content =
|
||||||
|
@Content(
|
||||||
|
mediaType = "application/json",
|
||||||
|
schema = @Schema(implementation = String.class))),
|
||||||
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||||
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
|
})
|
||||||
|
@PostMapping("/test/run/{epoch}/{uuid}")
|
||||||
|
public ApiResponseDto<String> run(
|
||||||
|
@Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
|
||||||
|
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||||
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
|
testJobService.enqueue(modelId, uuid, epoch);
|
||||||
|
return ApiResponseDto.ok("ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "학습 취소", description = "학습 취소 API")
|
||||||
|
@ApiResponses(
|
||||||
|
value = {
|
||||||
|
@ApiResponse(
|
||||||
|
responseCode = "200",
|
||||||
|
description = "취소 성공",
|
||||||
|
content =
|
||||||
|
@Content(
|
||||||
|
mediaType = "application/json",
|
||||||
|
schema = @Schema(implementation = String.class))),
|
||||||
|
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||||
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
|
})
|
||||||
|
@PostMapping("/test/cancel/{uuid}")
|
||||||
|
public ApiResponseDto<String> cancelTest(
|
||||||
|
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
|
||||||
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
|
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||||
|
testJobService.cancel(modelId);
|
||||||
|
return ApiResponseDto.ok("ok");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.kamco.cd.training.train.dto;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class EvalRunRequest {
|
||||||
|
private String uuid;
|
||||||
|
private int epoch; // best_changed_fscore_epoch_1.pth
|
||||||
|
private Integer timeoutSeconds;
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package com.kamco.cd.training.train.dto;
|
||||||
|
|
||||||
|
import java.time.ZonedDateTime;
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class ModelTrainJobDto {
|
||||||
|
|
||||||
|
private Long id;
|
||||||
|
private Long modelId;
|
||||||
|
private Integer attemptNo;
|
||||||
|
private String statusCd;
|
||||||
|
private Integer exitCode;
|
||||||
|
private String errorMessage;
|
||||||
|
private String containerName;
|
||||||
|
private Map<String, Object> paramsJson;
|
||||||
|
private ZonedDateTime queuedDttm;
|
||||||
|
private ZonedDateTime startedDttm;
|
||||||
|
private ZonedDateTime finishedDttm;
|
||||||
|
}
|
||||||
@@ -79,4 +79,5 @@ public class TrainRunRequest {
|
|||||||
// 실행 타임아웃
|
// 실행 타임아웃
|
||||||
// ========================
|
// ========================
|
||||||
private Integer timeoutSeconds;
|
private Integer timeoutSeconds;
|
||||||
|
private String resumeFrom;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.kamco.cd.training.train.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
@@ -7,7 +8,6 @@ import java.io.InputStreamReader;
|
|||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.UUID;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -40,53 +40,72 @@ public class DockerTrainService {
|
|||||||
private boolean ipcHost;
|
private boolean ipcHost;
|
||||||
|
|
||||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||||
public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception {
|
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||||
|
|
||||||
// 실행 식별용 jobId 생성
|
|
||||||
String jobId = UUID.randomUUID().toString().substring(0, 8);
|
|
||||||
|
|
||||||
// 컨테이너 이름 생성 (중복 방지 목적)
|
|
||||||
String containerName = containerPrefix + "-" + jobId;
|
|
||||||
|
|
||||||
// docker run 명령어 조립
|
|
||||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||||
|
|
||||||
// 프로세스 실행
|
|
||||||
ProcessBuilder pb = new ProcessBuilder(cmd);
|
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||||
|
|
||||||
// stderr를 stdout으로 합쳐서 한 스트림으로 처리
|
|
||||||
pb.redirectErrorStream(true);
|
pb.redirectErrorStream(true);
|
||||||
|
|
||||||
Process p = pb.start();
|
Process p = pb.start();
|
||||||
|
|
||||||
// 실행 로그 수집
|
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||||
StringBuilder log = new StringBuilder();
|
StringBuilder log = new StringBuilder();
|
||||||
|
Thread logThread =
|
||||||
|
new Thread(
|
||||||
|
() -> {
|
||||||
|
try (BufferedReader br =
|
||||||
|
new BufferedReader(
|
||||||
|
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||||
|
String line;
|
||||||
|
while ((line = br.readLine()) != null) {
|
||||||
|
synchronized (log) {
|
||||||
|
log.append(line).append('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"train-log-" + containerName);
|
||||||
|
|
||||||
try (BufferedReader br =
|
logThread.setDaemon(true);
|
||||||
new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
logThread.start();
|
||||||
|
|
||||||
String line;
|
int timeoutSeconds = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
|
||||||
while ((line = br.readLine()) != null) {
|
|
||||||
log.append(line).append('\n');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 지정된 timeout 내에 종료 대기
|
|
||||||
int timeoutSeconds = 7200; // 기본 2시간
|
|
||||||
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
|
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
|
||||||
|
|
||||||
if (!finished) {
|
if (!finished) {
|
||||||
// 타임아웃 발생 시 컨테이너 강제 제거
|
// docker run 프로세스도 같이 끊어야 readLine이 풀림
|
||||||
|
p.destroy();
|
||||||
|
if (!p.waitFor(2, TimeUnit.SECONDS)) {
|
||||||
|
p.destroyForcibly();
|
||||||
|
}
|
||||||
killContainer(containerName);
|
killContainer(containerName);
|
||||||
|
|
||||||
return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString());
|
String logs;
|
||||||
|
synchronized (log) {
|
||||||
|
logs = log.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new TrainRunResult(
|
||||||
|
null, // jobId (없으면 null)
|
||||||
|
containerName,
|
||||||
|
-1,
|
||||||
|
"TIMEOUT",
|
||||||
|
logs);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 종료 코드 확인 (0=정상)
|
|
||||||
int exit = p.exitValue();
|
int exit = p.exitValue();
|
||||||
|
|
||||||
return new TrainRunResult(
|
// 로그 스레드가 마무리할 시간을 조금 줌(없어도 되지만 로그 누락 방지용)
|
||||||
jobId, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", log.toString());
|
logThread.join(500);
|
||||||
|
|
||||||
|
String logs;
|
||||||
|
synchronized (log) {
|
||||||
|
logs = log.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -205,6 +224,7 @@ public class DockerTrainService {
|
|||||||
addArg(c, "--saturation-range", req.getSaturationRange());
|
addArg(c, "--saturation-range", req.getSaturationRange());
|
||||||
addArg(c, "--hue-delta", req.getHueDelta());
|
addArg(c, "--hue-delta", req.getHueDelta());
|
||||||
|
|
||||||
|
addArg(c, "--resume-from", req.getResumeFrom());
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,7 +238,7 @@ public class DockerTrainService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** 컨테이너 강제 종료 및 제거 */
|
/** 컨테이너 강제 종료 및 제거 */
|
||||||
private void killContainer(String containerName) {
|
public void killContainer(String containerName) {
|
||||||
try {
|
try {
|
||||||
new ProcessBuilder("docker", "rm", "-f", containerName)
|
new ProcessBuilder("docker", "rm", "-f", containerName)
|
||||||
.redirectErrorStream(true)
|
.redirectErrorStream(true)
|
||||||
@@ -227,4 +247,100 @@ public class DockerTrainService {
|
|||||||
} catch (Exception ignored) {
|
} catch (Exception ignored) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception {
|
||||||
|
|
||||||
|
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
||||||
|
|
||||||
|
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||||
|
pb.redirectErrorStream(true);
|
||||||
|
|
||||||
|
Process p = pb.start();
|
||||||
|
|
||||||
|
StringBuilder log = new StringBuilder();
|
||||||
|
Thread logThread =
|
||||||
|
new Thread(
|
||||||
|
() -> {
|
||||||
|
try (BufferedReader br =
|
||||||
|
new BufferedReader(
|
||||||
|
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||||
|
String line;
|
||||||
|
while ((line = br.readLine()) != null) {
|
||||||
|
synchronized (log) {
|
||||||
|
log.append(line).append('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
logThread.setDaemon(true);
|
||||||
|
logThread.start();
|
||||||
|
|
||||||
|
int timeout = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
|
||||||
|
boolean finished = p.waitFor(timeout, TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
if (!finished) {
|
||||||
|
p.destroyForcibly();
|
||||||
|
killContainer(containerName);
|
||||||
|
|
||||||
|
String logs;
|
||||||
|
synchronized (log) {
|
||||||
|
logs = log.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs);
|
||||||
|
}
|
||||||
|
|
||||||
|
int exit = p.exitValue();
|
||||||
|
logThread.join(500);
|
||||||
|
|
||||||
|
String logs;
|
||||||
|
synchronized (log) {
|
||||||
|
logs = log.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> buildDockerEvalCommand(String containerName, EvalRunRequest req) {
|
||||||
|
|
||||||
|
String uuid = req.getUuid();
|
||||||
|
Integer epoch = req.getEpoch();
|
||||||
|
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
|
||||||
|
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
|
||||||
|
|
||||||
|
String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth";
|
||||||
|
|
||||||
|
List<String> c = new ArrayList<>();
|
||||||
|
|
||||||
|
c.add("docker");
|
||||||
|
c.add("run");
|
||||||
|
c.add("--name");
|
||||||
|
c.add(containerName);
|
||||||
|
c.add("--rm");
|
||||||
|
|
||||||
|
c.add("--gpus");
|
||||||
|
c.add("all");
|
||||||
|
if (ipcHost) c.add("--ipc=host");
|
||||||
|
c.add("--shm-size=" + shmSize);
|
||||||
|
|
||||||
|
c.add("-v");
|
||||||
|
c.add(requestDir + ":/data");
|
||||||
|
c.add("-v");
|
||||||
|
c.add(responseDir + ":/checkpoints");
|
||||||
|
|
||||||
|
c.add(image);
|
||||||
|
|
||||||
|
c.add("python");
|
||||||
|
c.add("/workspace/change-detection-code/run_evaluation_pipeline.py");
|
||||||
|
|
||||||
|
c.add("--dataset_dir");
|
||||||
|
c.add("/data/" + uuid);
|
||||||
|
|
||||||
|
c.add("--model");
|
||||||
|
c.add("/checkpoints/" + uuid + "/" + modelFile);
|
||||||
|
|
||||||
|
return c;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
|
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||||
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||||
|
import java.time.ZonedDateTime;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.UUID;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.springframework.context.ApplicationEventPublisher;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
@Transactional(readOnly = true)
|
||||||
|
public class TestJobService {
|
||||||
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
|
private final DockerTrainService dockerTrainService;
|
||||||
|
private final ObjectMapper objectMapper;
|
||||||
|
private final ApplicationEventPublisher eventPublisher;
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||||
|
|
||||||
|
// 마스터 확인
|
||||||
|
modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
|
||||||
|
Map<String, Object> params = new java.util.LinkedHashMap<>();
|
||||||
|
params.put("jobType", "EVAL");
|
||||||
|
params.put("uuid", uuid);
|
||||||
|
params.put("epoch", epoch);
|
||||||
|
|
||||||
|
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||||
|
|
||||||
|
Long jobId =
|
||||||
|
modelTrainJobCoreService.createQueuedJob(
|
||||||
|
modelId, nextAttemptNo, params, ZonedDateTime.now());
|
||||||
|
|
||||||
|
// step2 시작으로 마킹
|
||||||
|
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||||
|
|
||||||
|
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||||
|
return jobId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public void cancel(Long modelId) {
|
||||||
|
|
||||||
|
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
|
||||||
|
Long jobId = master.getCurrentAttemptId();
|
||||||
|
if (jobId == null) {
|
||||||
|
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||||
|
}
|
||||||
|
|
||||||
|
var job =
|
||||||
|
modelTrainJobCoreService
|
||||||
|
.findById(jobId)
|
||||||
|
.orElseThrow(() -> new IllegalStateException("Job not found"));
|
||||||
|
|
||||||
|
String containerName = job.getContainerName();
|
||||||
|
|
||||||
|
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
|
||||||
|
if (containerName != null && !containerName.isBlank()) {
|
||||||
|
dockerTrainService.killContainer(containerName);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 상태 업데이트 (항상 수행)
|
||||||
|
modelTrainJobCoreService.markCanceled(jobId);
|
||||||
|
modelTrainMngCoreService.markStopped(modelId);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,10 +7,14 @@ import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
|||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.context.ApplicationEventPublisher;
|
import org.springframework.context.ApplicationEventPublisher;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
@@ -22,9 +26,14 @@ public class TrainJobService {
|
|||||||
|
|
||||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
|
private final DockerTrainService dockerTrainService;
|
||||||
private final ObjectMapper objectMapper;
|
private final ObjectMapper objectMapper;
|
||||||
private final ApplicationEventPublisher eventPublisher;
|
private final ApplicationEventPublisher eventPublisher;
|
||||||
|
|
||||||
|
// 학습 결과가 저장될 호스트 디렉토리
|
||||||
|
@Value("${train.docker.responseDir}")
|
||||||
|
private String responseDir;
|
||||||
|
|
||||||
public Long getModelIdByUuid(UUID uuid) {
|
public Long getModelIdByUuid(UUID uuid) {
|
||||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||||
}
|
}
|
||||||
@@ -36,6 +45,7 @@ public class TrainJobService {
|
|||||||
// 마스터 존재 확인(없으면 예외)
|
// 마스터 존재 확인(없으면 예외)
|
||||||
modelTrainMngCoreService.findModelById(modelId);
|
modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
|
||||||
|
// 파라미터 조회
|
||||||
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
|
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
|
||||||
|
|
||||||
if (trainRunRequest == null) {
|
if (trainRunRequest == null) {
|
||||||
@@ -46,6 +56,7 @@ public class TrainJobService {
|
|||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
||||||
|
paramsMap.put("jobType", "TRAIN");
|
||||||
|
|
||||||
Long jobId =
|
Long jobId =
|
||||||
modelTrainJobCoreService.createQueuedJob(
|
modelTrainJobCoreService.createQueuedJob(
|
||||||
@@ -57,16 +68,66 @@ public class TrainJobService {
|
|||||||
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
|
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
|
||||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||||
|
|
||||||
|
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||||
return jobId;
|
return jobId;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 재시작 버튼
|
* 재시작
|
||||||
*
|
*
|
||||||
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
|
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
|
||||||
*/
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public Long restart(Long modelId) {
|
public Long restart(Long modelId) {
|
||||||
|
return createNextAttempt(modelId, ResumeMode.NONE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 이어하기
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Transactional
|
||||||
|
public Long resume(Long modelId) {
|
||||||
|
return createNextAttempt(modelId, ResumeMode.REQUIRE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 중단
|
||||||
|
*
|
||||||
|
* <p>- job 상태 CANCELED - master 상태 STOPPED
|
||||||
|
*
|
||||||
|
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
|
||||||
|
*/
|
||||||
|
@Transactional
|
||||||
|
public void cancel(Long modelId) {
|
||||||
|
|
||||||
|
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
|
||||||
|
Long jobId = master.getCurrentAttemptId();
|
||||||
|
if (jobId == null) {
|
||||||
|
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||||
|
}
|
||||||
|
|
||||||
|
var job =
|
||||||
|
modelTrainJobCoreService
|
||||||
|
.findById(jobId)
|
||||||
|
.orElseThrow(() -> new IllegalStateException("Job not found"));
|
||||||
|
|
||||||
|
String containerName = job.getContainerName();
|
||||||
|
|
||||||
|
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
|
||||||
|
if (containerName != null && !containerName.isBlank()) {
|
||||||
|
dockerTrainService.killContainer(containerName);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 상태 업데이트 (항상 수행)
|
||||||
|
modelTrainJobCoreService.markCanceled(jobId);
|
||||||
|
modelTrainMngCoreService.markStopped(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
||||||
|
|
||||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||||
|
|
||||||
@@ -81,39 +142,72 @@ public class TrainJobService {
|
|||||||
|
|
||||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||||
|
|
||||||
|
// 이전 params_json 재사용 (재현성)
|
||||||
|
Map<String, Object> params = lastJob.getParamsJson();
|
||||||
|
if (params == null || params.isEmpty()) {
|
||||||
|
throw new IllegalStateException("이전 실행 params_json이 없습니다.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// mode에 따라 resume 옵션 주입/제거
|
||||||
|
Map<String, Object> nextParams = new java.util.LinkedHashMap<>(params);
|
||||||
|
|
||||||
|
if (mode == ResumeMode.NONE) {
|
||||||
|
// 이어하기 관련 키가 있다면 제거 (완전 새로 시작 보장)
|
||||||
|
nextParams.remove("resumeFrom");
|
||||||
|
nextParams.remove("resume");
|
||||||
|
} else if (mode == ResumeMode.REQUIRE) {
|
||||||
|
// 체크포인트 탐지해서 resumeFrom 세팅
|
||||||
|
String resumeFrom = findResumeFromOrNull(nextParams);
|
||||||
|
if (resumeFrom == null) {
|
||||||
|
throw new IllegalStateException("이어하기 체크포인트가 없습니다.");
|
||||||
|
}
|
||||||
|
nextParams.put("resumeFrom", resumeFrom);
|
||||||
|
nextParams.put("resume", true);
|
||||||
|
}
|
||||||
|
|
||||||
Long jobId =
|
Long jobId =
|
||||||
modelTrainJobCoreService.createQueuedJob(
|
modelTrainJobCoreService.createQueuedJob(
|
||||||
modelId,
|
modelId, nextAttemptNo, nextParams, ZonedDateTime.now());
|
||||||
nextAttemptNo,
|
|
||||||
lastJob.getParamsJson(), // Map<String,Object> 그대로 재사용
|
|
||||||
ZonedDateTime.now());
|
|
||||||
|
|
||||||
modelTrainMngCoreService.clearLastError(modelId);
|
modelTrainMngCoreService.clearLastError(modelId);
|
||||||
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
||||||
|
|
||||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||||
|
|
||||||
return jobId;
|
return jobId;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
private enum ResumeMode {
|
||||||
* 중단 버튼
|
NONE, // 새로 시작
|
||||||
*
|
REQUIRE // 이어하기
|
||||||
* <p>- job 상태 CANCELED - master 상태 STOPPED
|
}
|
||||||
*
|
|
||||||
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
|
|
||||||
*/
|
|
||||||
@Transactional
|
|
||||||
public void cancel(Long modelId) {
|
|
||||||
|
|
||||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
||||||
|
if (paramsJson == null) return null;
|
||||||
|
|
||||||
Long attemptId = master.getCurrentAttemptId();
|
Object out = paramsJson.get("outputFolder");
|
||||||
if (attemptId == null) {
|
if (out == null) return null;
|
||||||
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
|
||||||
|
String outputFolder = String.valueOf(out).trim(); // uuid
|
||||||
|
if (outputFolder.isEmpty()) return null;
|
||||||
|
|
||||||
|
// 호스트 기준 경로
|
||||||
|
Path outDir = Paths.get(responseDir, outputFolder);
|
||||||
|
|
||||||
|
Path last = outDir.resolve("last_checkpoint");
|
||||||
|
if (!Files.isRegularFile(last)) return null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
String ckptFile = Files.readString(last).trim(); // epoch_10.pth
|
||||||
|
if (ckptFile.isEmpty()) return null;
|
||||||
|
|
||||||
|
Path ckptHost = outDir.resolve(ckptFile);
|
||||||
|
if (!Files.isRegularFile(ckptHost)) return null;
|
||||||
|
|
||||||
|
// 컨테이너 경로 반환
|
||||||
|
return "/checkpoints/" + outputFolder + "/" + ckptFile;
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
modelTrainJobCoreService.markCanceled(attemptId);
|
|
||||||
modelTrainMngCoreService.markStopped(modelId);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package com.kamco.cd.training.train.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||||
@@ -27,53 +29,80 @@ public class TrainJobWorker {
|
|||||||
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
||||||
public void handle(ModelTrainJobQueuedEvent event) {
|
public void handle(ModelTrainJobQueuedEvent event) {
|
||||||
|
|
||||||
Long jobId = event.getJobId(); // record면 event.jobId()
|
Long jobId = event.getJobId();
|
||||||
|
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobDto job =
|
||||||
modelTrainJobCoreService
|
modelTrainJobCoreService
|
||||||
.findById(jobId)
|
.findById(jobId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||||
|
|
||||||
Long modelId = job.getModelId();
|
if (TrainStatusType.STOPPED.getId().equals(job.getStatusCd())) {
|
||||||
|
|
||||||
// enqueue에서 params_json 저장해놨으니 그걸로 TrainRunRequest 복원하는게 제일 일관적
|
|
||||||
TrainRunRequest req = toTrainRunRequest(job.getParamsJson());
|
|
||||||
// req가 null이면 실패 처리
|
|
||||||
if (req == null) {
|
|
||||||
modelTrainJobCoreService.markFailed(
|
|
||||||
jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)");
|
|
||||||
modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패");
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 컨테이너 이름은 "jobId 기반"으로 고정하는 게 cancel/restart에 유리
|
Long modelId = job.getModelId();
|
||||||
String containerName = "train-" + jobId; // prefix 쓰고싶으면 @Value 받아서 붙이면 됨
|
Map<String, Object> params = job.getParamsJson();
|
||||||
|
|
||||||
// logPath/lockedBy는 너 환경에 맞게
|
String jobType = params != null ? String.valueOf(params.get("jobType")) : null;
|
||||||
String logPath = null;
|
|
||||||
String lockedBy = "TRAIN_WORKER";
|
|
||||||
|
|
||||||
// RUNNING 표시
|
boolean isEval = "EVAL".equals(jobType);
|
||||||
modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy);
|
|
||||||
|
String containerName = (isEval ? "eval-" : "train-") + jobId;
|
||||||
|
|
||||||
|
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER");
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// DockerTrainService가 내부에서 컨테이너 이름을 랜덤으로 만들고 있어서
|
TrainRunResult result;
|
||||||
// markRunning에서 저장한 containerName과 실제 컨테이너명이 달라질 수 있음.
|
|
||||||
// 아래 "추천 수정" 참고.
|
if (isEval) {
|
||||||
TrainRunResult result = dockerTrainService.runTrainSync(req);
|
String uuid = String.valueOf(params.get("uuid"));
|
||||||
|
int epoch = (int) params.get("epoch");
|
||||||
|
|
||||||
|
EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null);
|
||||||
|
result = dockerTrainService.runEvalSync(evalReq, containerName);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
TrainRunRequest trainReq = toTrainRunRequest(params);
|
||||||
|
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelTrainJobDto latest =
|
||||||
|
modelTrainJobCoreService
|
||||||
|
.findById(jobId)
|
||||||
|
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||||
|
|
||||||
|
if (TrainStatusType.STOPPED.getId().equals(latest.getStatusCd())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (result.getExitCode() == 0) {
|
if (result.getExitCode() == 0) {
|
||||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||||
modelTrainMngCoreService.markSuccess(modelId); // 너 modelTrainMngCoreService에 있는 이름으로 맞춰
|
|
||||||
|
if (isEval) {
|
||||||
|
modelTrainMngCoreService.markStep2Success(modelId);
|
||||||
|
} else {
|
||||||
|
modelTrainMngCoreService.markStep1Success(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
modelTrainJobCoreService.markFailed(
|
modelTrainJobCoreService.markFailed(
|
||||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
|
||||||
|
if (isEval) {
|
||||||
|
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
||||||
|
} else {
|
||||||
|
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
modelTrainJobCoreService.markFailed(jobId, null, e.toString());
|
modelTrainJobCoreService.markFailed(jobId, null, e.toString());
|
||||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
|
||||||
|
if ("EVAL".equals(params.get("jobType"))) {
|
||||||
|
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
||||||
|
} else {
|
||||||
|
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user