테스트 실행 추가

This commit is contained in:
2026-02-11 21:58:25 +09:00
parent 1249a80da5
commit 2f8bd1f98c
14 changed files with 670 additions and 98 deletions

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@@ -20,12 +21,12 @@ public class ModelTrainJobCoreService {
return modelTrainJobRepository.findMaxAttemptNo(modelId); return modelTrainJobRepository.findMaxAttemptNo(modelId);
} }
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) { public Optional<ModelTrainJobDto> findLatestByModelId(Long modelId) {
return modelTrainJobRepository.findLatestByModelId(modelId); return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto);
} }
public Optional<ModelTrainJobEntity> findById(Long jobId) { public Optional<ModelTrainJobDto> findById(Long jobId) {
return modelTrainJobRepository.findById(jobId); return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto);
} }
/** QUEUED Job 생성 */ /** QUEUED Job 생성 */
@@ -95,7 +96,7 @@ public class ModelTrainJobCoreService {
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("CANCELED"); job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now()); job.setFinishedDttm(ZonedDateTime.now());
} }
} }

View File

@@ -36,6 +36,7 @@ import org.springframework.transaction.annotation.Transactional;
@Service @Service
@RequiredArgsConstructor @RequiredArgsConstructor
public class ModelTrainMngCoreService { public class ModelTrainMngCoreService {
private final ModelMngRepository modelMngRepository; private final ModelMngRepository modelMngRepository;
private final ModelDatasetRepository modelDatasetRepository; private final ModelDatasetRepository modelDatasetRepository;
private final ModelDatasetMappRepository modelDatasetMapRepository; private final ModelDatasetMappRepository modelDatasetMapRepository;
@@ -323,7 +324,7 @@ public class ModelTrainMngCoreService {
master.setStatusCd(TrainStatusType.COMPLETED.getId()); master.setStatusCd(TrainStatusType.COMPLETED.getId());
} }
/** 오류 처리(옵션) - Worker가 실패 시 호출 */ /** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional @Transactional
public void markError(Long modelId, String errorMessage) { public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =
@@ -332,7 +333,25 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId()); master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep1State(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage); master.setLastError(errorMessage);
master.setUpdatedUid(userUtil.getId());
master.setUpdatedDttm(ZonedDateTime.now());
}
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep2State(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
master.setUpdatedUid(userUtil.getId());
master.setUpdatedDttm(ZonedDateTime.now());
} }
@Transactional @Transactional
@@ -358,4 +377,58 @@ public class ModelTrainMngCoreService {
public TrainRunRequest findTrainRunRequest(Long modelId) { public TrainRunRequest findTrainRunRequest(Long modelId) {
return modelMngRepository.findTrainRunRequest(modelId); return modelMngRepository.findTrainRunRequest(modelId);
} }
public void markStep1InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep1StrtDttm(ZonedDateTime.now());
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
public void markStep2InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep2StrtDttm(ZonedDateTime.now());
entity.setStep2State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
public void markStep1Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep1State(TrainStatusType.COMPLETED.getId());
entity.setStep1EndDttm(ZonedDateTime.now());
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
public void markStep2Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep2State(TrainStatusType.COMPLETED.getId());
entity.setStep2EndDttm(ZonedDateTime.now());
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
} }

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.entity; package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import jakarta.persistence.Column; import jakarta.persistence.Column;
import jakarta.persistence.Entity; import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue; import jakarta.persistence.GeneratedValue;
@@ -76,4 +77,19 @@ public class ModelTrainJobEntity {
@Size(max = 100) @Size(max = 100)
@Column(name = "locked_by", length = 100) @Column(name = "locked_by", length = 100)
private String lockedBy; private String lockedBy;
public ModelTrainJobDto toDto() {
return new ModelTrainJobDto(
this.id,
this.modelId,
this.attemptNo,
this.statusCd,
this.exitCode,
this.errorMessage,
this.containerName,
this.paramsJson,
this.queuedDttm,
this.startedDttm,
this.finishedDttm);
}
} }

View File

@@ -134,7 +134,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelHyperParamEntity.contrastRange, modelHyperParamEntity.contrastRange,
modelHyperParamEntity.saturationRange, modelHyperParamEntity.saturationRange,
modelHyperParamEntity.hueDelta, modelHyperParamEntity.hueDelta,
Expressions.nullExpression(Integer.class))) Expressions.nullExpression(Integer.class),
Expressions.nullExpression(String.class)))
.from(modelMasterEntity) .from(modelMasterEntity)
.leftJoin(modelHyperParamEntity) .leftJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId)) .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))

View File

@@ -7,6 +7,4 @@ public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId); int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId); Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> pickQueuedForUpdate();
} }

View File

@@ -1,34 +1,43 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import jakarta.persistence.EntityManager; import jakarta.persistence.EntityManager;
import java.util.Optional; import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
@Repository @Repository
@RequiredArgsConstructor
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom { public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
private final EntityManager em; private final JPAQueryFactory queryFactory;
private JPAQueryFactory queryFactory() { public ModelTrainJobRepositoryImpl(EntityManager em) {
return new JPAQueryFactory(em); this.queryFactory = new JPAQueryFactory(em);
} }
/** modelId의 attempt_no 최대값. (없으면 0) */
@Override @Override
public int findMaxAttemptNo(Long modelId) { public int findMaxAttemptNo(Long modelId) {
return 0; QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
Integer max =
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
return max != null ? max : 0;
} }
/**
* modelId의 최신 job 1건 (보통 id desc / queuedDttm desc 등) - attemptNo 기준으로도 가능하지만, 여기선 id desc가 가장
* 단순.
*/
@Override @Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) { public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
return Optional.empty(); QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
}
@Override ModelTrainJobEntity job =
public Optional<ModelTrainJobEntity> pickQueuedForUpdate() { queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
return Optional.empty();
return Optional.ofNullable(job);
} }
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.train; package com.kamco.cd.training.train;
import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.train.service.TestJobService;
import com.kamco.cd.training.train.service.TrainJobService; import com.kamco.cd.training.train.service.TrainJobService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
@@ -12,6 +13,7 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
@@ -22,6 +24,7 @@ import org.springframework.web.bind.annotation.RestController;
public class TrainApiController { public class TrainApiController {
private final TrainJobService trainJobService; private final TrainJobService trainJobService;
private final TestJobService testJobService;
@Operation(summary = "학습 실행", description = "학습 실행 API") @Operation(summary = "학습 실행", description = "학습 실행 API")
@ApiResponses( @ApiResponses(
@@ -36,7 +39,7 @@ public class TrainApiController {
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content), @ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@RequestMapping("/run/{uuid}") @PostMapping("/run/{uuid}")
public ApiResponseDto<String> run( public ApiResponseDto<String> run(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052") @Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
@PathVariable @PathVariable
@@ -45,4 +48,120 @@ public class TrainApiController {
trainJobService.enqueue(modelId); trainJobService.enqueue(modelId);
return ApiResponseDto.ok("ok"); return ApiResponseDto.ok("ok");
} }
@Operation(summary = "학습 재실행", description = "학습 재실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "재실행 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/restart/{uuid}")
public ApiResponseDto<String> restart(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
Long jobId = trainJobService.restart(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 이어하기", description = "학습 이어하기 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "이어하기 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/resume/{uuid}")
public ApiResponseDto<String> resume(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
Long jobId = trainJobService.resume(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 취소", description = "학습 취소 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "취소 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/cancel/{uuid}")
public ApiResponseDto<String> cancel(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.cancel(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "test 실행", description = "test 실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "test 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/test/run/{epoch}/{uuid}")
public ApiResponseDto<String> run(
@Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
testJobService.enqueue(modelId, uuid, epoch);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 취소", description = "학습 취소 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "취소 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/test/cancel/{uuid}")
public ApiResponseDto<String> cancelTest(
@Parameter(description = "uuid", example = "69c4e56c-e0bf-4742-9225-bba9aae39052")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
testJobService.cancel(modelId);
return ApiResponseDto.ok("ok");
}
} }

View File

@@ -0,0 +1,16 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
public class EvalRunRequest {
private String uuid;
private int epoch; // best_changed_fscore_epoch_1.pth
private Integer timeoutSeconds;
}

View File

@@ -0,0 +1,23 @@
package com.kamco.cd.training.train.dto;
import java.time.ZonedDateTime;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public class ModelTrainJobDto {
private Long id;
private Long modelId;
private Integer attemptNo;
private String statusCd;
private Integer exitCode;
private String errorMessage;
private String containerName;
private Map<String, Object> paramsJson;
private ZonedDateTime queuedDttm;
private ZonedDateTime startedDttm;
private ZonedDateTime finishedDttm;
}

View File

@@ -79,4 +79,5 @@ public class TrainRunRequest {
// 실행 타임아웃 // 실행 타임아웃
// ======================== // ========================
private Integer timeoutSeconds; private Integer timeoutSeconds;
private String resumeFrom;
} }

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult; import com.kamco.cd.training.train.dto.TrainRunResult;
import java.io.BufferedReader; import java.io.BufferedReader;
@@ -7,7 +8,6 @@ import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -40,53 +40,72 @@ public class DockerTrainService {
private boolean ipcHost; private boolean ipcHost;
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */ /** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
public TrainRunResult runTrainSync(TrainRunRequest req) throws Exception { public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
// 실행 식별용 jobId 생성
String jobId = UUID.randomUUID().toString().substring(0, 8);
// 컨테이너 이름 생성 (중복 방지 목적)
String containerName = containerPrefix + "-" + jobId;
// docker run 명령어 조립
List<String> cmd = buildDockerRunCommand(containerName, req); List<String> cmd = buildDockerRunCommand(containerName, req);
// 프로세스 실행
ProcessBuilder pb = new ProcessBuilder(cmd); ProcessBuilder pb = new ProcessBuilder(cmd);
// stderr를 stdout으로 합쳐서 한 스트림으로 처리
pb.redirectErrorStream(true); pb.redirectErrorStream(true);
Process p = pb.start(); Process p = pb.start();
// 실행 로그 수집 // 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
StringBuilder log = new StringBuilder(); StringBuilder log = new StringBuilder();
Thread logThread =
new Thread(
() -> {
try (BufferedReader br =
new BufferedReader(
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
synchronized (log) {
log.append(line).append('\n');
}
}
} catch (Exception ignored) {
}
},
"train-log-" + containerName);
try (BufferedReader br = logThread.setDaemon(true);
new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { logThread.start();
String line; int timeoutSeconds = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
while ((line = br.readLine()) != null) {
log.append(line).append('\n');
}
}
// 지정된 timeout 내에 종료 대기
int timeoutSeconds = 7200; // 기본 2시간
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS); boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
if (!finished) { if (!finished) {
// 타임아웃 발생 시 컨테이너 강제 제거 // docker run 프로세스도 같이 끊어야 readLine이 풀림
p.destroy();
if (!p.waitFor(2, TimeUnit.SECONDS)) {
p.destroyForcibly();
}
killContainer(containerName); killContainer(containerName);
return new TrainRunResult(jobId, containerName, -1, "TIMEOUT", log.toString()); String logs;
synchronized (log) {
logs = log.toString();
}
return new TrainRunResult(
null, // jobId (없으면 null)
containerName,
-1,
"TIMEOUT",
logs);
} }
// 종료 코드 확인 (0=정상)
int exit = p.exitValue(); int exit = p.exitValue();
return new TrainRunResult( // 로그 스레드가 마무리할 시간을 조금 줌(없어도 되지만 로그 누락 방지용)
jobId, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", log.toString()); logThread.join(500);
String logs;
synchronized (log) {
logs = log.toString();
}
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
} }
/** /**
@@ -205,6 +224,7 @@ public class DockerTrainService {
addArg(c, "--saturation-range", req.getSaturationRange()); addArg(c, "--saturation-range", req.getSaturationRange());
addArg(c, "--hue-delta", req.getHueDelta()); addArg(c, "--hue-delta", req.getHueDelta());
addArg(c, "--resume-from", req.getResumeFrom());
return c; return c;
} }
@@ -218,7 +238,7 @@ public class DockerTrainService {
} }
/** 컨테이너 강제 종료 및 제거 */ /** 컨테이너 강제 종료 및 제거 */
private void killContainer(String containerName) { public void killContainer(String containerName) {
try { try {
new ProcessBuilder("docker", "rm", "-f", containerName) new ProcessBuilder("docker", "rm", "-f", containerName)
.redirectErrorStream(true) .redirectErrorStream(true)
@@ -227,4 +247,100 @@ public class DockerTrainService {
} catch (Exception ignored) { } catch (Exception ignored) {
} }
} }
public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerEvalCommand(containerName, req);
ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectErrorStream(true);
Process p = pb.start();
StringBuilder log = new StringBuilder();
Thread logThread =
new Thread(
() -> {
try (BufferedReader br =
new BufferedReader(
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
synchronized (log) {
log.append(line).append('\n');
}
}
} catch (Exception ignored) {
}
});
logThread.setDaemon(true);
logThread.start();
int timeout = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
boolean finished = p.waitFor(timeout, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
killContainer(containerName);
String logs;
synchronized (log) {
logs = log.toString();
}
return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs);
}
int exit = p.exitValue();
logThread.join(500);
String logs;
synchronized (log) {
logs = log.toString();
}
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
}
private List<String> buildDockerEvalCommand(String containerName, EvalRunRequest req) {
String uuid = req.getUuid();
Integer epoch = req.getEpoch();
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth";
List<String> c = new ArrayList<>();
c.add("docker");
c.add("run");
c.add("--name");
c.add(containerName);
c.add("--rm");
c.add("--gpus");
c.add("all");
if (ipcHost) c.add("--ipc=host");
c.add("--shm-size=" + shmSize);
c.add("-v");
c.add(requestDir + ":/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");
c.add(image);
c.add("python");
c.add("/workspace/change-detection-code/run_evaluation_pipeline.py");
c.add("--dataset_dir");
c.add("/data/" + uuid);
c.add("--model");
c.add("/checkpoints/" + uuid + "/" + modelFile);
return c;
}
} }

View File

@@ -0,0 +1,76 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class TestJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher;
@Transactional
public Long enqueue(Long modelId, UUID uuid, int epoch) {
// 마스터 확인
modelTrainMngCoreService.findModelById(modelId);
Map<String, Object> params = new java.util.LinkedHashMap<>();
params.put("jobType", "EVAL");
params.put("uuid", uuid);
params.put("epoch", epoch);
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, params, ZonedDateTime.now());
// step2 시작으로 마킹
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
@Transactional
public void cancel(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
Long jobId = master.getCurrentAttemptId();
if (jobId == null) {
throw new IllegalStateException("실행중인 작업이 없습니다.");
}
var job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalStateException("Job not found"));
String containerName = job.getContainerName();
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
if (containerName != null && !containerName.isBlank()) {
dockerTrainService.killContainer(containerName);
}
// 2) 상태 업데이트 (항상 수행)
modelTrainJobCoreService.markCanceled(jobId);
modelTrainMngCoreService.markStopped(modelId);
}
}

View File

@@ -7,10 +7,14 @@ import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@@ -22,9 +26,14 @@ public class TrainJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService; private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper; private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher; private final ApplicationEventPublisher eventPublisher;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
public Long getModelIdByUuid(UUID uuid) { public Long getModelIdByUuid(UUID uuid) {
return modelTrainMngCoreService.findModelIdByUuid(uuid); return modelTrainMngCoreService.findModelIdByUuid(uuid);
} }
@@ -36,6 +45,7 @@ public class TrainJobService {
// 마스터 존재 확인(없으면 예외) // 마스터 존재 확인(없으면 예외)
modelTrainMngCoreService.findModelById(modelId); modelTrainMngCoreService.findModelById(modelId);
// 파라미터 조회
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId); TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
if (trainRunRequest == null) { if (trainRunRequest == null) {
@@ -46,6 +56,7 @@ public class TrainJobService {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class); Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
paramsMap.put("jobType", "TRAIN");
Long jobId = Long jobId =
modelTrainJobCoreService.createQueuedJob( modelTrainJobCoreService.createQueuedJob(
@@ -57,16 +68,66 @@ public class TrainJobService {
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함) // 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
return jobId; return jobId;
} }
/** /**
* 재시작 버튼 * 재시작
* *
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성 * <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
*/ */
@Transactional @Transactional
public Long restart(Long modelId) { public Long restart(Long modelId) {
return createNextAttempt(modelId, ResumeMode.NONE);
}
/**
* 이어하기
*
* @param modelId
* @return
*/
@Transactional
public Long resume(Long modelId) {
return createNextAttempt(modelId, ResumeMode.REQUIRE);
}
/**
* 중단
*
* <p>- job 상태 CANCELED - master 상태 STOPPED
*
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
*/
@Transactional
public void cancel(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
Long jobId = master.getCurrentAttemptId();
if (jobId == null) {
throw new IllegalStateException("실행중인 작업이 없습니다.");
}
var job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalStateException("Job not found"));
String containerName = job.getContainerName();
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
if (containerName != null && !containerName.isBlank()) {
dockerTrainService.killContainer(containerName);
}
// 2) 상태 업데이트 (항상 수행)
modelTrainJobCoreService.markCanceled(jobId);
modelTrainMngCoreService.markStopped(modelId);
}
private Long createNextAttempt(Long modelId, ResumeMode mode) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
@@ -81,39 +142,72 @@ public class TrainJobService {
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
// 이전 params_json 재사용 (재현성)
Map<String, Object> params = lastJob.getParamsJson();
if (params == null || params.isEmpty()) {
throw new IllegalStateException("이전 실행 params_json이 없습니다.");
}
// mode에 따라 resume 옵션 주입/제거
Map<String, Object> nextParams = new java.util.LinkedHashMap<>(params);
if (mode == ResumeMode.NONE) {
// 이어하기 관련 키가 있다면 제거 (완전 새로 시작 보장)
nextParams.remove("resumeFrom");
nextParams.remove("resume");
} else if (mode == ResumeMode.REQUIRE) {
// 체크포인트 탐지해서 resumeFrom 세팅
String resumeFrom = findResumeFromOrNull(nextParams);
if (resumeFrom == null) {
throw new IllegalStateException("이어하기 체크포인트가 없습니다.");
}
nextParams.put("resumeFrom", resumeFrom);
nextParams.put("resume", true);
}
Long jobId = Long jobId =
modelTrainJobCoreService.createQueuedJob( modelTrainJobCoreService.createQueuedJob(
modelId, modelId, nextAttemptNo, nextParams, ZonedDateTime.now());
nextAttemptNo,
lastJob.getParamsJson(), // Map<String,Object> 그대로 재사용
ZonedDateTime.now());
modelTrainMngCoreService.clearLastError(modelId); modelTrainMngCoreService.clearLastError(modelId);
modelTrainMngCoreService.markInProgress(modelId, jobId); modelTrainMngCoreService.markInProgress(modelId, jobId);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId; return jobId;
} }
/** private enum ResumeMode {
* 중단 버튼 NONE, // 새로 시작
* REQUIRE // 이어하기
* <p>- job 상태 CANCELED - master 상태 STOPPED }
*
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
*/
@Transactional
public void cancel(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); public String findResumeFromOrNull(Map<String, Object> paramsJson) {
if (paramsJson == null) return null;
Long attemptId = master.getCurrentAttemptId(); Object out = paramsJson.get("outputFolder");
if (attemptId == null) { if (out == null) return null;
throw new IllegalStateException("실행중인 작업이 없습니다.");
String outputFolder = String.valueOf(out).trim(); // uuid
if (outputFolder.isEmpty()) return null;
// 호스트 기준 경로
Path outDir = Paths.get(responseDir, outputFolder);
Path last = outDir.resolve("last_checkpoint");
if (!Files.isRegularFile(last)) return null;
try {
String ckptFile = Files.readString(last).trim(); // epoch_10.pth
if (ckptFile.isEmpty()) return null;
Path ckptHost = outDir.resolve(ckptFile);
if (!Files.isRegularFile(ckptHost)) return null;
// 컨테이너 경로 반환
return "/checkpoints/" + outputFolder + "/" + ckptFile;
} catch (Exception e) {
return null;
} }
modelTrainJobCoreService.markCanceled(attemptId);
modelTrainMngCoreService.markStopped(modelId);
} }
} }

View File

@@ -1,9 +1,11 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult; import com.kamco.cd.training.train.dto.TrainRunResult;
@@ -27,53 +29,80 @@ public class TrainJobWorker {
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT) @TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
public void handle(ModelTrainJobQueuedEvent event) { public void handle(ModelTrainJobQueuedEvent event) {
Long jobId = event.getJobId(); // record면 event.jobId() Long jobId = event.getJobId();
ModelTrainJobEntity job = ModelTrainJobDto job =
modelTrainJobCoreService modelTrainJobCoreService
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
Long modelId = job.getModelId(); if (TrainStatusType.STOPPED.getId().equals(job.getStatusCd())) {
// enqueue에서 params_json 저장해놨으니 그걸로 TrainRunRequest 복원하는게 제일 일관적
TrainRunRequest req = toTrainRunRequest(job.getParamsJson());
// req가 null이면 실패 처리
if (req == null) {
modelTrainJobCoreService.markFailed(
jobId, null, "TrainRunRequest 변환 실패 (params_json null/invalid)");
modelTrainMngCoreService.markError(modelId, "TrainRunRequest 변환 실패");
return; return;
} }
// 컨테이너 이름은 "jobId 기반"으로 고정하는 게 cancel/restart에 유리 Long modelId = job.getModelId();
String containerName = "train-" + jobId; // prefix 쓰고싶으면 @Value 받아서 붙이면 됨 Map<String, Object> params = job.getParamsJson();
// logPath/lockedBy는 너 환경에 맞게 String jobType = params != null ? String.valueOf(params.get("jobType")) : null;
String logPath = null;
String lockedBy = "TRAIN_WORKER";
// RUNNING 표시 boolean isEval = "EVAL".equals(jobType);
modelTrainJobCoreService.markRunning(jobId, containerName, logPath, lockedBy);
String containerName = (isEval ? "eval-" : "train-") + jobId;
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER");
try { try {
// DockerTrainService가 내부에서 컨테이너 이름을 랜덤으로 만들고 있어서 TrainRunResult result;
// markRunning에서 저장한 containerName과 실제 컨테이너명이 달라질 수 있음.
// 아래 "추천 수정" 참고. if (isEval) {
TrainRunResult result = dockerTrainService.runTrainSync(req); String uuid = String.valueOf(params.get("uuid"));
int epoch = (int) params.get("epoch");
EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null);
result = dockerTrainService.runEvalSync(evalReq, containerName);
} else {
TrainRunRequest trainReq = toTrainRunRequest(params);
result = dockerTrainService.runTrainSync(trainReq, containerName);
}
ModelTrainJobDto latest =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
if (TrainStatusType.STOPPED.getId().equals(latest.getStatusCd())) {
return;
}
if (result.getExitCode() == 0) { if (result.getExitCode() == 0) {
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode()); modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
modelTrainMngCoreService.markSuccess(modelId); // 너 modelTrainMngCoreService에 있는 이름으로 맞춰
if (isEval) {
modelTrainMngCoreService.markStep2Success(modelId);
} else {
modelTrainMngCoreService.markStep1Success(modelId);
}
} else { } else {
modelTrainJobCoreService.markFailed( modelTrainJobCoreService.markFailed(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs()); jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
if (isEval) {
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
} else {
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
}
} }
} catch (Exception e) { } catch (Exception e) {
modelTrainJobCoreService.markFailed(jobId, null, e.toString()); modelTrainJobCoreService.markFailed(jobId, null, e.toString());
modelTrainMngCoreService.markError(modelId, e.getMessage());
if ("EVAL".equals(params.get("jobType"))) {
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
} else {
modelTrainMngCoreService.markError(modelId, e.getMessage());
}
} }
} }