실행 오류 수정 #42

Merged
teddy merged 1 commits from feat/training_260202 into develop 2026-02-12 10:14:54 +09:00
7 changed files with 67 additions and 25 deletions
Showing only changes of commit 3367d0e7be - Show all commits

View File

@@ -31,6 +31,7 @@ import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@Service @Service
@@ -378,6 +379,13 @@ public class ModelTrainMngCoreService {
return modelMngRepository.findTrainRunRequest(modelId); return modelMngRepository.findTrainRunRequest(modelId);
} }
/**
* step1 진행중 처리
*
* @param modelId
* @param jobId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1InProgress(Long modelId, Long jobId) { public void markStep1InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
@@ -392,6 +400,12 @@ public class ModelTrainMngCoreService {
entity.setUpdatedUid(userUtil.getId()); entity.setUpdatedUid(userUtil.getId());
} }
/**
* step2 진행중 처리
*
* @param modelId
* @param jobId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2InProgress(Long modelId, Long jobId) { public void markStep2InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
@@ -406,6 +420,12 @@ public class ModelTrainMngCoreService {
entity.setUpdatedUid(userUtil.getId()); entity.setUpdatedUid(userUtil.getId());
} }
/**
* step1 완료처리
*
* @param modelId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1Success(Long modelId) { public void markStep1Success(Long modelId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository
@@ -419,6 +439,12 @@ public class ModelTrainMngCoreService {
entity.setUpdatedUid(userUtil.getId()); entity.setUpdatedUid(userUtil.getId());
} }
/**
* step2 완료처리
*
* @param modelId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2Success(Long modelId) { public void markStep2Success(Long modelId) {
ModelMasterEntity entity = ModelMasterEntity entity =
modelMngRepository modelMngRepository

View File

@@ -90,7 +90,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
@Override @Override
public TrainRunRequest findTrainRunRequest(Long modelId) { public TrainRunRequest findTrainRunRequest(Long modelId) {
queryFactory return queryFactory
.select( .select(
Projections.constructor( Projections.constructor(
TrainRunRequest.class, TrainRunRequest.class,
@@ -141,7 +141,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId)) .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
.leftJoin(modelConfigEntity) .leftJoin(modelConfigEntity)
.on(modelConfigEntity.model.id.eq(modelMasterEntity.id)) .on(modelConfigEntity.model.id.eq(modelMasterEntity.id))
.where(modelMasterEntity.id.eq(modelId))
.fetchOne(); .fetchOne();
return null;
} }
} }

View File

@@ -41,7 +41,7 @@ public class TrainApiController {
}) })
@PostMapping("/run/{uuid}") @PostMapping("/run/{uuid}")
public ApiResponseDto<String> run( public ApiResponseDto<String> run(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid); Long modelId = trainJobService.getModelIdByUuid(uuid);
@@ -64,7 +64,7 @@ public class TrainApiController {
}) })
@PostMapping("/restart/{uuid}") @PostMapping("/restart/{uuid}")
public ApiResponseDto<String> restart( public ApiResponseDto<String> restart(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid); Long modelId = trainJobService.getModelIdByUuid(uuid);
@@ -87,7 +87,7 @@ public class TrainApiController {
}) })
@PostMapping("/resume/{uuid}") @PostMapping("/resume/{uuid}")
public ApiResponseDto<String> resume( public ApiResponseDto<String> resume(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid); Long modelId = trainJobService.getModelIdByUuid(uuid);
@@ -110,7 +110,7 @@ public class TrainApiController {
}) })
@PostMapping("/cancel/{uuid}") @PostMapping("/cancel/{uuid}")
public ApiResponseDto<String> cancel( public ApiResponseDto<String> cancel(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid); Long modelId = trainJobService.getModelIdByUuid(uuid);
@@ -134,7 +134,7 @@ public class TrainApiController {
@PostMapping("/test/run/{epoch}/{uuid}") @PostMapping("/test/run/{epoch}/{uuid}")
public ApiResponseDto<String> run( public ApiResponseDto<String> run(
@Parameter(description = "best 에폭", example = "1") @PathVariable int epoch, @Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid); Long modelId = trainJobService.getModelIdByUuid(uuid);
@@ -157,7 +157,7 @@ public class TrainApiController {
}) })
@PostMapping("/test/cancel/{uuid}") @PostMapping("/test/cancel/{uuid}")
public ApiResponseDto<String> cancelTest( public ApiResponseDto<String> cancelTest(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") @Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid); Long modelId = trainJobService.getModelIdByUuid(uuid);

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.train.dto; package com.kamco.cd.training.train.dto;
import java.util.UUID;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
@@ -14,8 +15,8 @@ public class TrainRunRequest {
// ======================== // ========================
// 기본 // 기본
// ======================== // ========================
private String datasetFolder; private UUID datasetFolder;
private String outputFolder; private UUID outputFolder;
private String inputSize; private String inputSize;
private String cropSize; private String cropSize;
private Integer batchSize; private Integer batchSize;
@@ -80,4 +81,12 @@ public class TrainRunRequest {
// ======================== // ========================
private Integer timeoutSeconds; private Integer timeoutSeconds;
private String resumeFrom; private String resumeFrom;
/**
 * Returns the dataset folder identifier as text.
 *
 * <p>The backing field is a {@code UUID}. When it is unset this accessor returns {@code null}
 * instead of the literal string {@code "null"} that {@code String.valueOf(Object)} produces,
 * so a missing folder cannot be mistaken for a folder that is actually named "null".
 *
 * @return the UUID's string form, or {@code null} when no dataset folder is set
 */
public String getDatasetFolder() {
    return datasetFolder == null ? null : datasetFolder.toString();
}
/**
 * Returns the output folder identifier as text.
 *
 * <p>The backing field is a {@code UUID}. When it is unset this accessor returns {@code null}
 * instead of the literal string {@code "null"} that {@code String.valueOf(Object)} produces,
 * so a missing folder cannot be mistaken for a folder that is actually named "null".
 *
 * @return the UUID's string form, or {@code null} when no output folder is set
 */
public String getOutputFolder() {
    return outputFolder == null ? null : outputFolder.toString();
}
} }

View File

@@ -9,9 +9,11 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@Log4j2
@Service @Service
public class DockerTrainService { public class DockerTrainService {
@@ -44,6 +46,11 @@ public class DockerTrainService {
List<String> cmd = buildDockerRunCommand(containerName, req); List<String> cmd = buildDockerRunCommand(containerName, req);
log.info("=== Docker Train Command ===");
log.info("Container: {}", containerName);
log.info("Command: {}", String.join(" ", cmd));
log.info("================================");
ProcessBuilder pb = new ProcessBuilder(cmd); ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectErrorStream(true); pb.redirectErrorStream(true);
@@ -121,23 +128,11 @@ public class DockerTrainService {
// 컨테이너 이름 지정 // 컨테이너 이름 지정
c.add("--name"); c.add("--name");
c.add(containerName); c.add(containerName + "-" + req.getOutputFolder().substring(0, 8));
// 실행 종료 시 자동 삭제 // 실행 종료 시 자동 삭제
c.add("--rm"); c.add("--rm");
// 환경변수 설정
c.add("-e");
c.add("OPENCV_LOG_LEVEL=ERROR");
c.add("-e");
c.add("NCCL_DEBUG=INFO");
c.add("-e");
c.add("NCCL_IB_DISABLE=1");
c.add("-e");
c.add("NCCL_P2P_DISABLE=0");
c.add("-e");
c.add("NCCL_SOCKET_IFNAME=eth0");
// GPU 전체 사용 // GPU 전체 사용
c.add("--gpus"); c.add("--gpus");
c.add("all"); c.add("all");
@@ -156,6 +151,18 @@ public class DockerTrainService {
c.add("--ulimit"); c.add("--ulimit");
c.add("stack=67108864"); c.add("stack=67108864");
// 환경변수 설정
c.add("-e");
c.add("OPENCV_LOG_LEVEL=ERROR");
c.add("-e");
c.add("NCCL_DEBUG=INFO");
c.add("-e");
c.add("NCCL_IB_DISABLE=1");
c.add("-e");
c.add("NCCL_P2P_DISABLE=0");
c.add("-e");
c.add("NCCL_SOCKET_IFNAME=eth0");
// 요청/결과 디렉토리 볼륨 마운트 // 요청/결과 디렉토리 볼륨 마운트
c.add("-v"); c.add("-v");
c.add(requestDir + ":/data"); c.add(requestDir + ":/data");

View File

@@ -67,8 +67,6 @@ public class TrainJobService {
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함) // 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
return jobId; return jobId;
} }

View File

@@ -55,6 +55,7 @@ public class TrainJobWorker {
TrainRunResult result; TrainRunResult result;
if (isEval) { if (isEval) {
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
String uuid = String.valueOf(params.get("uuid")); String uuid = String.valueOf(params.get("uuid"));
int epoch = (int) params.get("epoch"); int epoch = (int) params.get("epoch");
@@ -62,6 +63,7 @@ public class TrainJobWorker {
result = dockerTrainService.runEvalSync(evalReq, containerName); result = dockerTrainService.runEvalSync(evalReq, containerName);
} else { } else {
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
TrainRunRequest trainReq = toTrainRunRequest(params); TrainRunRequest trainReq = toTrainRunRequest(params);
result = dockerTrainService.runTrainSync(trainReq, containerName); result = dockerTrainService.runTrainSync(trainReq, containerName);
} }