diff --git a/src/main/java/com/kamco/cd/training/KamcoTrainingApplication.java b/src/main/java/com/kamco/cd/training/KamcoTrainingApplication.java index e14139c..0ece266 100644 --- a/src/main/java/com/kamco/cd/training/KamcoTrainingApplication.java +++ b/src/main/java/com/kamco/cd/training/KamcoTrainingApplication.java @@ -2,8 +2,10 @@ package com.kamco.cd.training; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.scheduling.annotation.EnableAsync; import org.springframework.scheduling.annotation.EnableScheduling; +@EnableAsync @SpringBootApplication @EnableScheduling public class KamcoTrainingApplication { diff --git a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java index 36b20f0..fa7f983 100644 --- a/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java +++ b/src/main/java/com/kamco/cd/training/model/dto/ModelTrainMngDto.java @@ -40,6 +40,7 @@ public class ModelTrainMngDto { private String statusCd; private String trainType; private String modelNo; + private Long currentAttemptId; public String getStatusName() { if (this.statusCd == null || this.statusCd.isBlank()) return null; diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java index 5e22404..2438c0d 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java @@ -12,6 +12,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq; import com.kamco.cd.training.postgres.core.HyperParamCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import com.kamco.cd.training.train.service.TrainJobService; 
import java.util.List; import java.util.UUID; import lombok.RequiredArgsConstructor; @@ -27,6 +28,7 @@ import org.springframework.transaction.annotation.Transactional; @Slf4j public class ModelTrainMngService { + private final TrainJobService trainJobService; private final ModelTrainMngCoreService modelTrainMngCoreService; private final HyperParamCoreService hyperParamCoreService; @@ -62,8 +64,8 @@ public class ModelTrainMngService { HyperParamDto.Basic hyper = new HyperParamDto.Basic(); // 전이 학습은 모델 선택 필수 - if (req.getTrainType().equals(TrainType.TRANSFER.getId())) { - if (req.getBeforeModelId() != null) { + if (TrainType.TRANSFER.getId().equals(req.getTrainType())) { + if (req.getBeforeModelId() == null) { throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요."); } } @@ -87,6 +89,11 @@ public class ModelTrainMngService { // 모델 config 저장 modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig()); + + // 저장 다 끝난 뒤에 job enqueue + if (Boolean.TRUE.equals(req.getIsStart())) { + trainJobService.enqueue(modelId); // job 저장 + 이벤트 발행(실행은 AFTER_COMMIT) + } } /** diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java new file mode 100644 index 0000000..f2440a6 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java @@ -0,0 +1,101 @@ +package com.kamco.cd.training.postgres.core; + +import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; +import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository; +import java.time.ZonedDateTime; +import java.util.Map; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +@Transactional(readOnly = true) +public class ModelTrainJobCoreService { + + 
private final ModelTrainJobRepository modelTrainJobRepository; + + public int findMaxAttemptNo(Long modelId) { + return modelTrainJobRepository.findMaxAttemptNo(modelId); + } + + public Optional findLatestByModelId(Long modelId) { + return modelTrainJobRepository.findLatestByModelId(modelId); + } + + public Optional findById(Long jobId) { + return modelTrainJobRepository.findById(jobId); + } + + /** QUEUED Job 생성 */ + @Transactional + public Long createQueuedJob( + Long modelId, int attemptNo, Map paramsJson, ZonedDateTime queuedDttm) { + + ModelTrainJobEntity job = new ModelTrainJobEntity(); + job.setModelId(modelId); + job.setAttemptNo(attemptNo); + job.setStatusCd("QUEUED"); + job.setParamsJson(paramsJson); + job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now()); + + modelTrainJobRepository.save(job); + return job.getId(); + } + + /** 실행 시작 처리 */ + @Transactional + public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) { + ModelTrainJobEntity job = + modelTrainJobRepository + .findById(jobId) + .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); + + job.setStatusCd("RUNNING"); + job.setContainerName(containerName); + job.setLogPath(logPath); + job.setStartedDttm(ZonedDateTime.now()); + job.setLockedDttm(ZonedDateTime.now()); + job.setLockedBy(lockedBy); + } + + /** 성공 처리 */ + @Transactional + public void markSuccess(Long jobId, int exitCode) { + ModelTrainJobEntity job = + modelTrainJobRepository + .findById(jobId) + .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); + + job.setStatusCd("SUCCESS"); + job.setExitCode(exitCode); + job.setFinishedDttm(ZonedDateTime.now()); + } + + /** 실패 처리 */ + @Transactional + public void markFailed(Long jobId, Integer exitCode, String errorMessage) { + ModelTrainJobEntity job = + modelTrainJobRepository + .findById(jobId) + .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); + + 
job.setStatusCd("FAILED"); + job.setExitCode(exitCode); + job.setErrorMessage(errorMessage); + job.setFinishedDttm(ZonedDateTime.now()); + } + + /** 취소 처리 */ + @Transactional + public void markCanceled(Long jobId) { + ModelTrainJobEntity job = + modelTrainJobRepository + .findById(jobId) + .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); + + job.setStatusCd("CANCELED"); + job.setFinishedDttm(ZonedDateTime.now()); + } +} diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index 4ccb9e4..d3028fb 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -23,6 +23,7 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository; import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository; import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository; import com.kamco.cd.training.postgres.repository.model.ModelMngRepository; +import com.kamco.cd.training.train.dto.TrainRunRequest; import java.time.ZonedDateTime; import java.util.List; import java.util.UUID; @@ -30,6 +31,7 @@ import lombok.RequiredArgsConstructor; import org.springframework.data.domain.Page; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; @Service @RequiredArgsConstructor @@ -213,6 +215,20 @@ public class ModelTrainMngCoreService { } } + /** + * uuid로 model id 조회 + * + * @param uuid + * @return + */ + public Long findModelIdByUuid(UUID uuid) { + ModelMasterEntity entity = + modelMngRepository + .findByUuid(uuid) + .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); + return entity.getId(); + } + /** * 모델학습 아이디로 config정보 조회 * @@ -245,4 
+261,101 @@ public class ModelTrainMngCoreService {
   public List getDatasetSelectG2G3List(DatasetReq req) {
     return datasetRepository.getDatasetSelectG2G3List(req);
   }
+
+  /**
+   * Loads a model master record by primary key and returns it as a DTO.
+   *
+   * @param id primary key of the model
+   * @return the DTO produced by the entity's toDto()
+   * @throws IllegalArgumentException if no model has this id
+   */
+  public ModelTrainMngDto.Basic findModelById(Long id) {
+    ModelMasterEntity entity =
+        modelMngRepository
+            .findById(id)
+            .orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
+    return entity.toDto();
+  }
+
+  /**
+   * Switches the master record to IN_PROGRESS and links the currently running job id. Per the
+   * original author, the UI, stop and status lookups all work off currentAttemptId.
+   */
+  @Transactional
+  public void markInProgress(Long modelId, Long jobId) {
+    ModelMasterEntity master =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
+
+    master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
+    master.setCurrentAttemptId(jobId);
+
+    // If needed, the start time can be recorded here as well.
+  }
+
+  /** Clears the last error message so a restart or new run carries no trace of an earlier error. */
+  @Transactional
+  public void clearLastError(Long modelId) {
+    ModelMasterEntity master =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
+
+    master.setLastError(null);
+  }
+
+  /** Marks the model as STOPPED (optional; intended to be wired in together with cancel). */
+  @Transactional
+  public void markStopped(Long modelId) {
+    ModelMasterEntity master =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
+
+    master.setStatusCd(TrainStatusType.STOPPED.getId());
+  }
+
+  /**
+   * Marks the model as COMPLETED (optional; intended to be called by the worker on success).
+   * Unlike markSuccess, this leaves lastError untouched.
+   */
+  @Transactional
+  public void markCompleted(Long modelId) {
+    ModelMasterEntity master =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
+
+    master.setStatusCd(TrainStatusType.COMPLETED.getId());
+  }
+
+  /** Marks the model as ERROR and stores the message (optional; called by the worker on failure). */
+  @Transactional
+  public void markError(Long modelId, String errorMessage) {
+    ModelMasterEntity master =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
+
+    master.setStatusCd(TrainStatusType.ERROR.getId());
+    master.setLastError(errorMessage);
+  }
+
+  /**
+   * Marks the model as COMPLETED and clears lastError.
+   *
+   * @throws IllegalArgumentException if no model has this id
+   */
+  @Transactional
+  public void markSuccess(Long modelId) {
+    ModelMasterEntity master =
+        modelMngRepository
+            .findById(modelId)
+            .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
+
+    // Mark the model status as completed
+    master.setStatusCd(TrainStatusType.COMPLETED.getId());
+
+    // (Optional) clear the last error message
+    master.setLastError(null);
+  }
+
+  /**
+   * Fetches the parameters needed to run training for a model.
+   *
+   * @param modelId primary key of the model
+   * @return whatever the repository query returns for this model
+   */
+  public TrainRunRequest findTrainRunRequest(Long modelId) {
+    return modelMngRepository.findTrainRunRequest(modelId);
+  }
 }
diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java
index 05311dd..61d9b75 100644
--- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java
+++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelMasterEntity.java
@@ -97,6 +97,12 @@ public class ModelMasterEntity {
   @Column(name = "step2_metric_save_yn")
   private Boolean step2MetricSaveYn;
+  @Column(name = "current_attempt_id")
+  private Long currentAttemptId;
+
+  @Column(name = "last_error")
+  private String lastError;
+
   public ModelTrainMngDto.Basic toDto() {
     return new ModelTrainMngDto.Basic(
         this.id,
@@ -111,6 +117,7 @@ public class ModelMasterEntity {
         this.step2State,
         this.statusCd,
         this.trainType,
-        this.modelNo);
+        this.modelNo,
+        this.currentAttemptId);
   }
 }
diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java
new file mode 100644
index 0000000..504d2c4
--- /dev/null
+++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java
@@ -0,0 +1,79 @@
+package 
com.kamco.cd.training.postgres.entity;
+
+import jakarta.persistence.Column;
+import jakarta.persistence.Entity;
+import jakarta.persistence.GeneratedValue;
+import jakarta.persistence.GenerationType;
+import jakarta.persistence.Id;
+import jakarta.persistence.Table;
+import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
+import java.time.ZonedDateTime;
+import java.util.Map;
+import lombok.Getter;
+import lombok.Setter;
+import org.hibernate.annotations.ColumnDefault;
+import org.hibernate.annotations.JdbcTypeCode;
+import org.hibernate.type.SqlTypes;
+
+/** JPA mapping of one training job row in table tb_model_train_job. */
+@Getter
+@Setter
+@Entity
+@Table(name = "tb_model_train_job")
+public class ModelTrainJobEntity {
+
+  // Database-generated primary key (identity column).
+  @Id
+  @GeneratedValue(strategy = GenerationType.IDENTITY)
+  @Column(name = "id", nullable = false)
+  private Long id;
+
+  // Id of the model this job belongs to (plain column, no mapped association).
+  @NotNull
+  @Column(name = "model_id", nullable = false)
+  private Long modelId;
+
+  // Attempt number of this job.
+  @NotNull
+  @Column(name = "attempt_no", nullable = false)
+  private Integer attemptNo;
+
+  // Job status code; at most 30 characters.
+  @Size(max = 30)
+  @NotNull
+  @Column(name = "status_cd", nullable = false, length = 30)
+  private String statusCd;
+
+  // Run parameters, stored in a JSON column.
+  @NotNull
+  @Column(name = "params_json", nullable = false)
+  @JdbcTypeCode(SqlTypes.JSON)
+  private Map paramsJson;
+
+  @Size(max = 200)
+  @Column(name = "container_name", length = 200)
+  private String containerName;
+
+  @Size(max = 500)
+  @Column(name = "log_path", length = 500)
+  private String logPath;
+
+  @Column(name = "exit_code")
+  private Integer exitCode;
+
+  // Limited to 2000 characters by both the validation constraint and the column length.
+  @Size(max = 2000)
+  @Column(name = "error_message", length = 2000)
+  private String errorMessage;
+
+  // Column default is now() at the schema level.
+  @ColumnDefault("now()")
+  @Column(name = "queued_dttm")
+  private ZonedDateTime queuedDttm;
+
+  @Column(name = "started_dttm")
+  private ZonedDateTime startedDttm;
+
+  @Column(name = "finished_dttm")
+  private ZonedDateTime finishedDttm;
+
+  @Column(name = "locked_dttm")
+  private ZonedDateTime lockedDttm;
+
+  @Size(max = 100)
+  @Column(name = "locked_by", length = 100)
+  private String lockedBy;
+}
diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryCustom.java index d664435..2dc94e0 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryCustom.java @@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.repository.model; import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; +import com.kamco.cd.training.train.dto.TrainRunRequest; import java.util.Optional; import java.util.UUID; import org.springframework.data.domain.Page; @@ -19,4 +20,6 @@ public interface ModelMngRepositoryCustom { Optional findByUuid(UUID uuid); Optional findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn); + + TrainRunRequest findTrainRunRequest(Long modelId); } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java index 12297e8..94a578e 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java @@ -1,10 +1,15 @@ package com.kamco.cd.training.postgres.repository.model; +import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity; +import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.postgres.entity.ModelMasterEntity; +import com.kamco.cd.training.train.dto.TrainRunRequest; import com.querydsl.core.BooleanBuilder; +import 
com.querydsl.core.types.Projections; +import com.querydsl.core.types.dsl.Expressions; import com.querydsl.jpa.impl.JPAQueryFactory; import java.util.List; import java.util.Optional; @@ -82,4 +87,60 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom { public Optional findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn) { return Optional.empty(); } + + @Override + public TrainRunRequest findTrainRunRequest(Long modelId) { + queryFactory + .select( + Projections.constructor( + TrainRunRequest.class, + modelMasterEntity.uuid, // datasetFolder + modelMasterEntity.uuid, // outputFolder + modelHyperParamEntity.inputSize, + modelHyperParamEntity.cropSize, + modelHyperParamEntity.batchSize, + modelHyperParamEntity.gpuIds, + modelHyperParamEntity.gpuCnt, + modelHyperParamEntity.learningRate, + modelHyperParamEntity.backbone, + modelHyperParamEntity.epochCnt, + modelHyperParamEntity.trainNumWorkers, + modelHyperParamEntity.valNumWorkers, + modelHyperParamEntity.testNumWorkers, + modelHyperParamEntity.trainShuffle, + modelHyperParamEntity.trainPersistent, + modelHyperParamEntity.valPersistent, + modelHyperParamEntity.dropPathRate, + modelHyperParamEntity.frozenStages, + modelHyperParamEntity.neckPolicy, + modelHyperParamEntity.classWeight, + modelHyperParamEntity.decoderChannels, + modelHyperParamEntity.weightDecay, + modelHyperParamEntity.layerDecayRate, + modelHyperParamEntity.ignoreIndex, + modelHyperParamEntity.ddpFindUnusedParams, + modelHyperParamEntity.numLayers, + modelHyperParamEntity.metrics, + modelHyperParamEntity.saveBest, + modelHyperParamEntity.saveBestRule, + modelHyperParamEntity.valInterval, + modelHyperParamEntity.logInterval, + modelHyperParamEntity.visInterval, + modelHyperParamEntity.rotProb, + modelHyperParamEntity.rotDegree, + modelHyperParamEntity.flipProb, + modelHyperParamEntity.exchangeProb, + modelHyperParamEntity.brightnessDelta, + modelHyperParamEntity.contrastRange, + modelHyperParamEntity.saturationRange, + 
modelHyperParamEntity.hueDelta, + Expressions.nullExpression(Integer.class))) + .from(modelMasterEntity) + .leftJoin(modelHyperParamEntity) + .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId)) + .leftJoin(modelConfigEntity) + .on(modelConfigEntity.model.id.eq(modelMasterEntity.id)) + .fetchOne(); + return null; + } } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepository.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepository.java new file mode 100644 index 0000000..0d04bb1 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepository.java @@ -0,0 +1,7 @@ +package com.kamco.cd.training.postgres.repository.train; + +import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; +import org.springframework.data.jpa.repository.JpaRepository; + +public interface ModelTrainJobRepository + extends JpaRepository, ModelTrainJobRepositoryCustom {} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java new file mode 100644 index 0000000..7b9cf1d --- /dev/null +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryCustom.java @@ -0,0 +1,12 @@ +package com.kamco.cd.training.postgres.repository.train; + +import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; +import java.util.Optional; + +public interface ModelTrainJobRepositoryCustom { + int findMaxAttemptNo(Long modelId); + + Optional findLatestByModelId(Long modelId); + + Optional pickQueuedForUpdate(); +} diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java new file mode 100644 index 0000000..9fa924a --- /dev/null +++ 
b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainJobRepositoryImpl.java @@ -0,0 +1,34 @@ +package com.kamco.cd.training.postgres.repository.train; + +import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; +import com.querydsl.jpa.impl.JPAQueryFactory; +import jakarta.persistence.EntityManager; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Repository; + +@Repository +@RequiredArgsConstructor +public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom { + + private final EntityManager em; + + private JPAQueryFactory queryFactory() { + return new JPAQueryFactory(em); + } + + @Override + public int findMaxAttemptNo(Long modelId) { + return 0; + } + + @Override + public Optional findLatestByModelId(Long modelId) { + return Optional.empty(); + } + + @Override + public Optional pickQueuedForUpdate() { + return Optional.empty(); + } +} diff --git a/src/main/java/com/kamco/cd/training/train/TrainApiController.java b/src/main/java/com/kamco/cd/training/train/TrainApiController.java new file mode 100644 index 0000000..8c8bb94 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/TrainApiController.java @@ -0,0 +1,48 @@ +package com.kamco.cd.training.train; + +import com.kamco.cd.training.config.api.ApiResponseDto; +import com.kamco.cd.training.train.service.TrainJobService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.responses.ApiResponses; +import io.swagger.v3.oas.annotations.tags.Tag; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import 
org.springframework.web.bind.annotation.RestController;
+
+/** REST endpoints under /api/train that queue training runs. */
+@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
+@RequiredArgsConstructor
+@RestController
+@RequestMapping("/api/train")
+public class TrainApiController {
+
+  private final TrainJobService trainJobService;
+
+  /**
+   * Resolves the model by uuid, enqueues a training job for it and answers "ok".
+   *
+   * <p>NOTE(review): the mapping below names no HTTP method, so every method (including GET)
+   * reaches this state-changing handler - consider restricting it to POST once callers are
+   * confirmed. The 400 description also reads as copied from a search API.
+   *
+   * @param uuid public identifier of the model to train
+   * @return an ok response carrying the string "ok"
+   */
+  @Operation(summary = "학습 실행", description = "학습 실행 API")
+  @ApiResponses(
+      value = {
+        @ApiResponse(
+            responseCode = "200",
+            description = "실행 성공",
+            content =
+                @Content(
+                    mediaType = "application/json",
+                    schema = @Schema(implementation = String.class))),
+        @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
+        @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
+      })
+  @RequestMapping("/run/{uuid}")
+  public ApiResponseDto run(
+      @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
+          @PathVariable
+          UUID uuid) {
+    Long modelId = trainJobService.getModelIdByUuid(uuid);
+    trainJobService.enqueue(modelId);
+    return ApiResponseDto.ok("ok");
+  }
+}
diff --git a/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobQueuedEvent.java b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobQueuedEvent.java
new file mode 100644
index 0000000..9fe58a5
--- /dev/null
+++ b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobQueuedEvent.java
@@ -0,0 +1,15 @@
+package com.kamco.cd.training.train.dto;
+
+/** Event object announcing that a training run has been queued; carries only the job id. */
+public class ModelTrainJobQueuedEvent {
+
+  // Id of the queued job; immutable once the event is created.
+  private final Long jobId;
+
+  /** Creates the event for the given job id. */
+  public ModelTrainJobQueuedEvent(Long jobId) {
+    this.jobId = jobId;
+  }
+
+  /** Returns the id of the queued job. */
+  public Long getJobId() {
+    return jobId;
+  }
+}
diff --git a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java
new file mode 100644
index 0000000..4fc7a70
--- /dev/null
+++ b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java
@@ -0,0 +1,82 @@
+package com.kamco.cd.training.train.dto;
+
+import 
lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * Parameters for one training run.
+ *
+ * <p>The generated all-args constructor takes the fields in declaration order, and the repository
+ * query that builds this object passes its columns positionally - keep the two in step when
+ * adding, removing or reordering fields.
+ */
+@Getter
+@Setter
+@NoArgsConstructor
+@AllArgsConstructor
+public class TrainRunRequest {
+
+  // ========================
+  // Basics
+  // ========================
+  private String datasetFolder;
+  private String outputFolder;
+  private String inputSize;
+  private String cropSize;
+  private Integer batchSize;
+  private String gpuIds;
+  private Integer gpus;
+  private Double learningRate;
+  private String backbone;
+  private Integer epochs;
+
+  // ========================
+  // Data
+  // ========================
+  private Integer trainNumWorkers;
+  private Integer valNumWorkers;
+  private Integer testNumWorkers;
+  private Boolean trainShuffle;
+  private Boolean trainPersistent;
+  private Boolean valPersistent;
+
+  // ========================
+  // Model Architecture
+  // ========================
+  private Double dropPathRate;
+  private Integer frozenStages;
+  private String neckPolicy;
+  private String classWeight;
+  private String decoderChannels;
+
+  // ========================
+  // Loss & Optimization
+  // ========================
+  private Double weightDecay;
+  private Double layerDecayRate;
+  private Integer ignoreIndex;
+  private Boolean ddpFindUnusedParams;
+  private Integer numLayers;
+
+  // ========================
+  // Evaluation
+  // ========================
+  private String metrics;
+  private String saveBest;
+  private String saveBestRule;
+  private Integer valInterval;
+  private Integer logInterval;
+  private Integer visInterval;
+
+  // ========================
+  // Augmentation
+  // ========================
+  private Double rotProb;
+  private String rotDegree;
+  private Double flipProb;
+  private Double exchangeProb;
+  private Integer brightnessDelta;
+  private String contrastRange;
+  private String saturationRange;
+  private Integer hueDelta;
+
+  // ========================
+  // Execution timeout
+  // ========================
+  private Integer timeoutSeconds;
+}
diff --git a/src/main/java/com/kamco/cd/training/train/dto/TrainRunResult.java b/src/main/java/com/kamco/cd/training/train/dto/TrainRunResult.java
new file mode 100644
index 0000000..96527bb
--- /dev/null
+++ b/src/main/java/com/kamco/cd/training/train/dto/TrainRunResult.java
@@ -0,0 +1,20 @@
+package com.kamco.cd.training.train.dto;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * Result object returned from a training run: the run's job id, container name, exit code,
+ * status string and collected log text.
+ */
+@Getter
+@Setter
+@NoArgsConstructor
+@AllArgsConstructor
+public class TrainRunResult {
+
+  private String jobId;
+  private String containerName;
+  private int exitCode;
+  private String status;
+  private String logs;
+}
diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java
new file mode 100644
index 0000000..cfd6364
--- /dev/null
+++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java
@@ -0,0 +1,230 @@
+package com.kamco.cd.training.train.service;
+
+import com.kamco.cd.training.train.dto.TrainRunRequest;
+import com.kamco.cd.training.train.dto.TrainRunResult;
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Service;
+
+@Service
+public class DockerTrainService {
+
+  // Name of the Docker image to run
+  @Value("${train.docker.image}")
+  private String image;
+
+  // Host directory holding the training request data
+  @Value("${train.docker.requestDir}")
+  private String requestDir;
+
+  // Host directory where training results are stored
+  @Value("${train.docker.responseDir}")
+  private String responseDir;
+
+  // Container name prefix
+  @Value("${train.docker.containerPrefix}")
+  private String containerPrefix;
+
+  // Shared-memory size (needed for large training runs)
+  
@Value("${train.docker.shmSize:16g}") + private String shmSize; + + // IPC host 사용 여부 + @Value("${train.docker.ipcHost:true}") + private boolean ipcHost; + + /** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */ + public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception { + + // 실행 식별용 jobId 생성 + String jobId = UUID.randomUUID().toString().substring(0, 8); + + // 컨테이너 이름 생성 (중복 방지 목적) + String containerName = containerPrefix + "-" + jobId; + + // docker run 명령어 조립 + List cmd = buildDockerRunCommand(containerName, req); + + // 프로세스 실행 + ProcessBuilder pb = new ProcessBuilder(cmd); + + // stderr를 stdout으로 합쳐서 한 스트림으로 처리 + pb.redirectErrorStream(true); + + Process p = pb.start(); + + // 실행 로그 수집 + StringBuilder log = new StringBuilder(); + + try (BufferedReader br = + new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { + + String line; + while ((line = br.readLine()) != null) { + log.append(line).append('\n'); + } + } + + // 지정된 timeout 내에 종료 대기 + int timeoutSeconds = 7200; // 기본 2시간 + boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS); + + if (!finished) { + // 타임아웃 발생 시 컨테이너 강제 제거 + killContainer(containerName); + + return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString()); + } + + // 종료 코드 확인 (0=정상) + int exit = p.exitValue(); + + return new TrainRunResult( + jobId, containerName, exit, exit == 0 ? 
"SUCCESS" : "FAILED", log.toString()); + } + + /** + * docker run 명령어 리스트 구성 - 환경변수 설정 - GPU 옵션 설정 - 볼륨 마운트 - 컨테이너 내부 python 실행 명령 구성 - 요청값이 + * null/blank면 해당 옵션은 "아예 생략" + */ + private List buildDockerRunCommand(String containerName, TrainRunRequest req) { + + List c = new ArrayList<>(); + + c.add("docker"); + c.add("run"); + + // 컨테이너 이름 지정 + c.add("--name"); + c.add(containerName); + + // 실행 종료 시 자동 삭제 + c.add("--rm"); + + // 환경변수 설정 + c.add("-e"); + c.add("OPENCV_LOG_LEVEL=ERROR"); + c.add("-e"); + c.add("NCCL_DEBUG=INFO"); + c.add("-e"); + c.add("NCCL_IB_DISABLE=1"); + c.add("-e"); + c.add("NCCL_P2P_DISABLE=0"); + c.add("-e"); + c.add("NCCL_SOCKET_IFNAME=eth0"); + + // GPU 전체 사용 + c.add("--gpus"); + c.add("all"); + + // IPC host 사용 여부 + if (ipcHost) { + c.add("--ipc=host"); + } + + // 공유메모리 설정 + c.add("--shm-size=" + shmSize); + + // 메모리 관련 ulimit 설정 + c.add("--ulimit"); + c.add("memlock=-1"); + c.add("--ulimit"); + c.add("stack=67108864"); + + // 요청/결과 디렉토리 볼륨 마운트 + c.add("-v"); + c.add(requestDir + ":/data"); + c.add("-v"); + c.add(responseDir + ":/checkpoints"); + + // 표준입력 유지 (-it 대신 -i만 사용) + c.add("-i"); + + // 사용할 이미지 + c.add(image); + + // ===== 컨테이너 내부 실행 명령 ===== + c.add("python"); + c.add("/workspace/change-detection-code/train_wrapper.py"); + + // ===== 기본 파라미터 ===== + addArg(c, "--dataset-folder", req.getDatasetFolder()); + addArg(c, "--output-folder", req.getOutputFolder()); + addArg(c, "--input-size", req.getInputSize()); + addArg(c, "--crop-size", req.getCropSize()); + addArg(c, "--batch-size", req.getBatchSize()); + addArg(c, "--gpu-ids", req.getGpuIds()); + // addArg(c, "--gpus", req.getGpus()); + addArg(c, "--lr", req.getLearningRate()); + addArg(c, "--backbone", req.getBackbone()); + addArg(c, "--epochs", req.getEpochs()); + + // ===== Data ===== + addArg(c, "--train-num-workers", req.getTrainNumWorkers()); + addArg(c, "--val-num-workers", req.getValNumWorkers()); + addArg(c, "--test-num-workers", req.getTestNumWorkers()); + addArg(c, 
"--train-shuffle", req.getTrainShuffle()); + addArg(c, "--train-persistent", req.getTrainPersistent()); + addArg(c, "--val-persistent", req.getValPersistent()); + + // ===== Model Architecture ===== + addArg(c, "--drop-path-rate", req.getDropPathRate()); + addArg(c, "--frozen-stages", req.getFrozenStages()); + addArg(c, "--neck-policy", req.getNeckPolicy()); + addArg(c, "--class-weight", req.getClassWeight()); + addArg(c, "--decoder-channels", req.getDecoderChannels()); + + // ===== Loss & Optimization ===== + addArg(c, "--weight-decay", req.getWeightDecay()); + addArg(c, "--layer-decay-rate", req.getLayerDecayRate()); + addArg(c, "--ignore-index", req.getIgnoreIndex()); + addArg(c, "--ddp-find-unused-params", req.getDdpFindUnusedParams()); + addArg(c, "--num-layers", req.getNumLayers()); + + // ===== Evaluation ===== + addArg(c, "--metrics", req.getMetrics()); + addArg(c, "--save-best", req.getSaveBest()); + addArg(c, "--save-best-rule", req.getSaveBestRule()); + addArg(c, "--val-interval", req.getValInterval()); + addArg(c, "--log-interval", req.getLogInterval()); + addArg(c, "--vis-interval", req.getVisInterval()); + + // ===== Augmentation ===== + addArg(c, "--rot-prob", req.getRotProb()); + addArg(c, "--rot-degree", req.getRotDegree()); + addArg(c, "--flip-prob", req.getFlipProb()); + addArg(c, "--exchange-prob", req.getExchangeProb()); + addArg(c, "--brightness-delta", req.getBrightnessDelta()); + addArg(c, "--contrast-range", req.getContrastRange()); + addArg(c, "--saturation-range", req.getSaturationRange()); + addArg(c, "--hue-delta", req.getHueDelta()); + + return c; + } + + /** 인자 추가(키 + 값) - null / blank면 아예 추가 안 함 */ + private void addArg(List c, String key, Object value) { + if (value == null) return; + String s = String.valueOf(value).trim(); + if (s.isEmpty()) return; + c.add(key); + c.add(s); + } + + /** 컨테이너 강제 종료 및 제거 */ + private void killContainer(String containerName) { + try { + new ProcessBuilder("docker", "rm", "-f", containerName) + 
.redirectErrorStream(true) + .start() + .waitFor(10, TimeUnit.SECONDS); + } catch (Exception ignored) { + } + } +} diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java new file mode 100644 index 0000000..b5e1bd2 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java @@ -0,0 +1,119 @@ +package com.kamco.cd.training.train.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.kamco.cd.training.common.enums.TrainStatusType; +import com.kamco.cd.training.model.dto.ModelTrainMngDto; +import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; +import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; +import com.kamco.cd.training.train.dto.TrainRunRequest; +import java.time.ZonedDateTime; +import java.util.Map; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +@Transactional(readOnly = true) +public class TrainJobService { + + private final ModelTrainJobCoreService modelTrainJobCoreService; + private final ModelTrainMngCoreService modelTrainMngCoreService; + private final ObjectMapper objectMapper; + private final ApplicationEventPublisher eventPublisher; + + public Long getModelIdByUuid(UUID uuid) { + return modelTrainMngCoreService.findModelIdByUuid(uuid); + } + + /** 실행 예약 (QUEUE 등록) */ + @Transactional + public Long enqueue(Long modelId) { + + // 마스터 존재 확인(없으면 예외) + modelTrainMngCoreService.findModelById(modelId); + + TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId); + + if (trainRunRequest == null) { + throw new IllegalArgumentException("Model not found: " + 
modelId); + } + + int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; + + @SuppressWarnings("unchecked") + Map paramsMap = objectMapper.convertValue(trainRunRequest, Map.class); + + Long jobId = + modelTrainJobCoreService.createQueuedJob( + modelId, nextAttemptNo, paramsMap, ZonedDateTime.now()); + + modelTrainMngCoreService.clearLastError(modelId); + modelTrainMngCoreService.markInProgress(modelId, jobId); + + // 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함) + eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); + + return jobId; + } + + /** + * 재시작 버튼 + * + *

- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성 + */ + @Transactional + public Long restart(Long modelId) { + + ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); + + if (TrainStatusType.IN_PROGRESS.getId().equals(master.getStatusCd())) { + throw new IllegalStateException("이미 진행중입니다."); + } + + var lastJob = + modelTrainJobCoreService + .findLatestByModelId(modelId) + .orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다.")); + + int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; + + Long jobId = + modelTrainJobCoreService.createQueuedJob( + modelId, + nextAttemptNo, + lastJob.getParamsJson(), // Map 그대로 재사용 + ZonedDateTime.now()); + + modelTrainMngCoreService.clearLastError(modelId); + modelTrainMngCoreService.markInProgress(modelId, jobId); + + eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); + + return jobId; + } + + /** + * 중단 버튼 + * + *

- job 상태 CANCELED - master 상태 STOPPED + * + *

※ 실제 docker stop은 Worker/Runner가 수행(운영 안정) + */ + @Transactional + public void cancel(Long modelId) { + + ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); + + Long attemptId = master.getCurrentAttemptId(); + if (attemptId == null) { + throw new IllegalStateException("실행중인 작업이 없습니다."); + } + + modelTrainJobCoreService.markCanceled(attemptId); + modelTrainMngCoreService.markStopped(modelId); + } +} diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java new file mode 100644 index 0000000..8acb348 --- /dev/null +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java @@ -0,0 +1,87 @@ +package com.kamco.cd.training.train.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; +import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; +import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; +import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; +import com.kamco.cd.training.train.dto.TrainRunRequest; +import com.kamco.cd.training.train.dto.TrainRunResult; +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Component; +import org.springframework.transaction.event.TransactionPhase; +import org.springframework.transaction.event.TransactionalEventListener; + +@Component +@RequiredArgsConstructor +public class TrainJobWorker { + + private final ModelTrainJobCoreService modelTrainJobCoreService; + private final ModelTrainMngCoreService modelTrainMngCoreService; + private final DockerTrainService dockerTrainService; + private final ObjectMapper objectMapper; + + @Async + @TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT) + public void handle(ModelTrainJobQueuedEvent event) { + + Long jobId = 
event.getJobId(); // record면 event.jobId() + + ModelTrainJobEntity job = + modelTrainJobCoreService + .findById(jobId) + .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); + + Long modelId = job.getModelId(); + + // enqueue에서 params_json 저장해놨으니 그걸로 TrainRunRequest 복원하는게 제일 일관적 + TrainRunRequest req = toTrainRunRequest(job.getParamsJson()); + // req가 null이면 실패 처리 + if (req == null) { + modelTrainJobCoreService.markFailed( + jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)"); + modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패"); + return; + } + + // 컨테이너 이름은 "jobId 기반"으로 고정하는 게 cancel/restart에 유리 + String containerName = "train-" + jobId; // prefix 쓰고싶으면 @Value 받아서 붙이면 됨 + + // logPath/lockedBy는 너 환경에 맞게 + String logPath = null; + String lockedBy = "TRAIN_WORKER"; + + // RUNNING 표시 + modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy); + + try { + // DockerTrainService가 내부에서 컨테이너 이름을 랜덤으로 만들고 있어서 + // markRunning에서 저장한 containerName과 실제 컨테이너명이 달라질 수 있음. + // 아래 "추천 수정" 참고. 
+ TrainRunResult result = dockerTrainService.runTrainSync(req); + + if (result.getExitCode() == 0) { + modelTrainJobCoreService.markSuccess(jobId, result.getExitCode()); + modelTrainMngCoreService.markSuccess(modelId); // 너 modelTrainMngCoreService에 있는 이름으로 맞춰 + } else { + modelTrainJobCoreService.markFailed( + jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs()); + modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode()); + } + + } catch (Exception e) { + modelTrainJobCoreService.markFailed(jobId, null, e.toString()); + modelTrainMngCoreService.markError(modelId, e.getMessage()); + } + } + + private TrainRunRequest toTrainRunRequest(Map paramsJson) { + if (paramsJson == null || paramsJson.isEmpty()) { + return null; + } + + return objectMapper.convertValue(paramsJson, TrainRunRequest.class); + } +} diff --git a/src/main/resources/application-dev.yml b/src/main/resources/application-dev.yml index 7d59b32..3631352 100644 --- a/src/main/resources/application-dev.yml +++ b/src/main/resources/application-dev.yml @@ -57,3 +57,12 @@ file: dataset-dir: /home/kcomu/data/request/ dataset-tmp-dir: ${file.dataset-dir}tmp/ + +train: + docker: + image: "kamco-cd-train:love_latest" + requestDir: "/home/kcomu/data/request" + responseDir: "/home/kcomu/data/response" + containerPrefix: "kamco-cd-train" + shmSize: "16g" + ipcHost: true diff --git a/src/main/resources/application-prod.yml b/src/main/resources/application-prod.yml index b179771..39f8aec 100644 --- a/src/main/resources/application-prod.yml +++ b/src/main/resources/application-prod.yml @@ -43,3 +43,12 @@ file: dataset-dir: /home/kcomu/data/request/ dataset-tmp-dir: ${file.dataset-dir}tmp/ + +train: + docker: + image: "kamco-cd-train:love_latest" + requestDir: "/home/kcomu/data/request" + responseDir: "/home/kcomu/data/response" + containerPrefix: "kamco-cd-train" + shmSize: "16g" + ipcHost: true