Compare commits
24 Commits
8c45b39dcc
...
feat/train
| Author | SHA1 | Date | |
|---|---|---|---|
| 265813e6f7 | |||
| 8190a6e9c8 | |||
| e9f8bb37fa | |||
| f3c822587f | |||
| 335f0dbb9b | |||
| 69eaba1a83 | |||
| 365ad81cad | |||
| 9dfa54fbf9 | |||
| 12f6bb7154 | |||
| aa3af4e9d0 | |||
| 7ca37bf1e4 | |||
| 901dde066d | |||
| cb0a38274a | |||
| b8194df9ae | |||
| 7c5f07683e | |||
| 159fb281d4 | |||
| 97192ff811 | |||
| 4f3fb675be | |||
| e6caea05b3 | |||
| fd63824edc | |||
| 8a44df26b8 | |||
| cb97c5e59e | |||
| 8f75b16dc6 | |||
| c2978e41c2 |
@@ -2,10 +2,8 @@ package com.kamco.cd.training;
|
|||||||
|
|
||||||
import org.springframework.boot.SpringApplication;
|
import org.springframework.boot.SpringApplication;
|
||||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||||
import org.springframework.scheduling.annotation.EnableAsync;
|
|
||||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
import org.springframework.scheduling.annotation.EnableScheduling;
|
||||||
|
|
||||||
@EnableAsync
|
|
||||||
@SpringBootApplication
|
@SpringBootApplication
|
||||||
@EnableScheduling
|
@EnableScheduling
|
||||||
public class KamcoTrainingApplication {
|
public class KamcoTrainingApplication {
|
||||||
|
|||||||
@@ -12,13 +12,14 @@ import lombok.Setter;
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class HyperParam {
|
public class HyperParam {
|
||||||
|
|
||||||
|
@Schema(description = "모델", example = "G1")
|
||||||
|
private ModelType model; // G1, G2, G3
|
||||||
|
|
||||||
// -------------------------
|
// -------------------------
|
||||||
// Important
|
// Important
|
||||||
// -------------------------
|
// -------------------------
|
||||||
|
|
||||||
@Schema(description = "모델", example = "large")
|
|
||||||
private ModelType model; // backbone
|
|
||||||
|
|
||||||
@Schema(description = "백본 네트워크", example = "large")
|
@Schema(description = "백본 네트워크", example = "large")
|
||||||
private String backbone; // backbone
|
private String backbone; // backbone
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package com.kamco.cd.training.common.enums;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.utils.enums.EnumType;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
public enum JobStatusType implements EnumType {
|
||||||
|
QUEUED("대기중"),
|
||||||
|
RUNNING("실행중"),
|
||||||
|
SUCCESS("성공"),
|
||||||
|
FAILED("실패"),
|
||||||
|
CANCELED("취소");
|
||||||
|
|
||||||
|
private final String desc;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getId() {
|
||||||
|
return name();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getText() {
|
||||||
|
return desc;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package com.kamco.cd.training.common.enums;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.utils.enums.EnumType;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
public enum JobType implements EnumType {
|
||||||
|
TRAIN("학습"),
|
||||||
|
TEST("테스트");
|
||||||
|
|
||||||
|
private final String desc;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getId() {
|
||||||
|
return name();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getText() {
|
||||||
|
return desc;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,7 +3,6 @@ package com.kamco.cd.training.common.utils;
|
|||||||
import static java.lang.String.CASE_INSENSITIVE_ORDER;
|
import static java.lang.String.CASE_INSENSITIVE_ORDER;
|
||||||
|
|
||||||
import com.jcraft.jsch.ChannelExec;
|
import com.jcraft.jsch.ChannelExec;
|
||||||
import com.jcraft.jsch.ChannelSftp;
|
|
||||||
import com.jcraft.jsch.JSch;
|
import com.jcraft.jsch.JSch;
|
||||||
import com.jcraft.jsch.Session;
|
import com.jcraft.jsch.Session;
|
||||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||||
@@ -720,18 +719,26 @@ public class FIleChecker {
|
|||||||
public static void unzip(String fileName, String destDirectory) throws IOException {
|
public static void unzip(String fileName, String destDirectory) throws IOException {
|
||||||
String zipFilePath = destDirectory + File.separator + fileName;
|
String zipFilePath = destDirectory + File.separator + fileName;
|
||||||
|
|
||||||
|
log.info("fileName : {}", fileName);
|
||||||
|
log.info("destDirectory : {}", destDirectory);
|
||||||
|
log.info("zipFilePath : {}", zipFilePath);
|
||||||
// zip 이름으로 폴더 생성 (확장자 제거)
|
// zip 이름으로 폴더 생성 (확장자 제거)
|
||||||
String folderName =
|
String folderName =
|
||||||
fileName.endsWith(".zip") ? fileName.substring(0, fileName.length() - 4) : fileName;
|
fileName.endsWith(".zip") ? fileName.substring(0, fileName.length() - 4) : fileName;
|
||||||
|
log.info("folderName : {}", folderName);
|
||||||
|
|
||||||
File destDir = new File(destDirectory, folderName);
|
File destDir = new File(destDirectory, folderName);
|
||||||
|
log.info("destDir : {}", destDir);
|
||||||
|
|
||||||
// 동일 폴더가 이미 있으면 삭제
|
// 동일 폴더가 이미 있으면 삭제
|
||||||
|
log.info("111 destDir.exists() : {}", destDir.exists());
|
||||||
if (destDir.exists()) {
|
if (destDir.exists()) {
|
||||||
deleteDirectoryRecursively(destDir.toPath());
|
deleteDirectoryRecursively(destDir.toPath());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.info("222 destDir.exists() : {}", destDir.exists());
|
||||||
if (!destDir.exists()) {
|
if (!destDir.exists()) {
|
||||||
|
log.info("mkdirs : {}", destDir.exists());
|
||||||
destDir.mkdirs();
|
destDir.mkdirs();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -787,92 +794,6 @@ public class FIleChecker {
|
|||||||
return destFile;
|
return destFile;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void uploadTo86(Path localFile) {
|
|
||||||
|
|
||||||
String host = "192.168.2.86";
|
|
||||||
int port = 22;
|
|
||||||
String username = "kcomu";
|
|
||||||
String password = "Kamco2025!";
|
|
||||||
|
|
||||||
String remoteDir = "/home/kcomu/data/request";
|
|
||||||
|
|
||||||
Session session = null;
|
|
||||||
ChannelSftp channel = null;
|
|
||||||
|
|
||||||
try {
|
|
||||||
JSch jsch = new JSch();
|
|
||||||
|
|
||||||
session = jsch.getSession(username, host, port);
|
|
||||||
session.setPassword(password);
|
|
||||||
|
|
||||||
Properties config = new Properties();
|
|
||||||
config.put("StrictHostKeyChecking", "no");
|
|
||||||
session.setConfig(config);
|
|
||||||
|
|
||||||
session.connect(10_000);
|
|
||||||
|
|
||||||
channel = (ChannelSftp) session.openChannel("sftp");
|
|
||||||
channel.connect(10_000);
|
|
||||||
|
|
||||||
// 목적지 디렉토리 이동
|
|
||||||
channel.cd(remoteDir);
|
|
||||||
|
|
||||||
// 업로드
|
|
||||||
channel.put(localFile.toString(), localFile.getFileName().toString());
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException("SFTP upload failed", e);
|
|
||||||
} finally {
|
|
||||||
if (channel != null) channel.disconnect();
|
|
||||||
if (session != null) session.disconnect();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void unzipOn86Server(String zipPath, String targetDir) {
|
|
||||||
|
|
||||||
String host = "192.168.2.86";
|
|
||||||
String user = "kcomu";
|
|
||||||
String password = "Kamco2025!";
|
|
||||||
|
|
||||||
Session session = null;
|
|
||||||
ChannelExec channel = null;
|
|
||||||
|
|
||||||
try {
|
|
||||||
JSch jsch = new JSch();
|
|
||||||
|
|
||||||
session = jsch.getSession(user, host, 22);
|
|
||||||
session.setPassword(password);
|
|
||||||
|
|
||||||
Properties config = new Properties();
|
|
||||||
config.put("StrictHostKeyChecking", "no");
|
|
||||||
session.setConfig(config);
|
|
||||||
|
|
||||||
session.connect(10_000);
|
|
||||||
|
|
||||||
String command = "unzip -o " + zipPath + " -d " + targetDir;
|
|
||||||
|
|
||||||
channel = (ChannelExec) session.openChannel("exec");
|
|
||||||
channel.setCommand(command);
|
|
||||||
channel.setErrStream(System.err);
|
|
||||||
|
|
||||||
InputStream in = channel.getInputStream();
|
|
||||||
channel.connect();
|
|
||||||
|
|
||||||
// 출력 읽기(선택)
|
|
||||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
|
|
||||||
while (br.readLine() != null) {
|
|
||||||
// 필요하면 로그
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
} finally {
|
|
||||||
if (channel != null) channel.disconnect();
|
|
||||||
if (session != null) session.disconnect();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static List<String> execCommandAndReadLines(String command) {
|
public static List<String> execCommandAndReadLines(String command) {
|
||||||
|
|
||||||
List<String> result = new ArrayList<>();
|
List<String> result = new ArrayList<>();
|
||||||
|
|||||||
23
src/main/java/com/kamco/cd/training/config/AsyncConfig.java
Normal file
23
src/main/java/com/kamco/cd/training/config/AsyncConfig.java
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package com.kamco.cd.training.config;
|
||||||
|
|
||||||
|
import java.util.concurrent.Executor;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.scheduling.annotation.EnableAsync;
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
|
|
||||||
|
@Configuration
|
||||||
|
@EnableAsync
|
||||||
|
public class AsyncConfig {
|
||||||
|
|
||||||
|
@Bean(name = "trainJobExecutor")
|
||||||
|
public Executor trainJobExecutor() {
|
||||||
|
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||||
|
executor.setCorePoolSize(4); // 동시에 4개 실행
|
||||||
|
executor.setMaxPoolSize(8); // 최대 8개
|
||||||
|
executor.setQueueCapacity(200); // 대기 큐
|
||||||
|
executor.setThreadNamePrefix("train-job-");
|
||||||
|
executor.initialize();
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -417,6 +417,7 @@ public class DatasetService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private String escape(String path) {
|
private String escape(String path) {
|
||||||
|
// 쉘 커맨드에서 안전하게 사용할 수 있도록 문자열을 작은따옴표로 감싸면서, 내부의 작은따옴표를 이스케이프 처리
|
||||||
return "'" + path.replace("'", "'\"'\"'") + "'";
|
return "'" + path.replace("'", "'\"'\"'") + "'";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -97,9 +97,9 @@ public class HyperParamApiController {
|
|||||||
String type,
|
String type,
|
||||||
@Parameter(description = "시작일", example = "2026-02-01") @RequestParam(required = false)
|
@Parameter(description = "시작일", example = "2026-02-01") @RequestParam(required = false)
|
||||||
LocalDate startDate,
|
LocalDate startDate,
|
||||||
@Parameter(description = "종료일", example = "2026-02-28") @RequestParam(required = false)
|
@Parameter(description = "종료일", example = "2026-03-31") @RequestParam(required = false)
|
||||||
LocalDate endDate,
|
LocalDate endDate,
|
||||||
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
|
@Parameter(description = "버전명", example = "G1_000019") @RequestParam(required = false)
|
||||||
String hyperVer,
|
String hyperVer,
|
||||||
@Parameter(description = "모델 타입 (G1, G2, G3 중 하나)", example = "G1")
|
@Parameter(description = "모델 타입 (G1, G2, G3 중 하나)", example = "G1")
|
||||||
@RequestParam(required = false)
|
@RequestParam(required = false)
|
||||||
@@ -142,7 +142,7 @@ public class HyperParamApiController {
|
|||||||
})
|
})
|
||||||
@DeleteMapping("/{uuid}")
|
@DeleteMapping("/{uuid}")
|
||||||
public ApiResponseDto<Void> deleteHyperParam(
|
public ApiResponseDto<Void> deleteHyperParam(
|
||||||
@Parameter(description = "하이퍼파라미터 uuid", example = "c3b5a285-8f68-42af-84f0-e6d09162deb5")
|
@Parameter(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
|
||||||
@PathVariable
|
@PathVariable
|
||||||
UUID uuid) {
|
UUID uuid) {
|
||||||
hyperParamService.deleteHyperParam(uuid);
|
hyperParamService.deleteHyperParam(uuid);
|
||||||
@@ -164,7 +164,7 @@ public class HyperParamApiController {
|
|||||||
})
|
})
|
||||||
@GetMapping("/{uuid}")
|
@GetMapping("/{uuid}")
|
||||||
public ApiResponseDto<HyperParamDto.Basic> getHyperParam(
|
public ApiResponseDto<HyperParamDto.Basic> getHyperParam(
|
||||||
@Parameter(description = "하이퍼파라미터 uuid", example = "c3b5a285-8f68-42af-84f0-e6d09162deb5")
|
@Parameter(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
|
||||||
@PathVariable
|
@PathVariable
|
||||||
UUID uuid) {
|
UUID uuid) {
|
||||||
return ApiResponseDto.ok(hyperParamService.getHyperParam(uuid));
|
return ApiResponseDto.ok(hyperParamService.getHyperParam(uuid));
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ public class HyperParamDto {
|
|||||||
private UUID uuid;
|
private UUID uuid;
|
||||||
private String hyperVer;
|
private String hyperVer;
|
||||||
@JsonFormatDttm private ZonedDateTime createdDttm;
|
@JsonFormatDttm private ZonedDateTime createdDttm;
|
||||||
|
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
||||||
|
private Integer totalUseCnt;
|
||||||
|
|
||||||
// -------------------------
|
// -------------------------
|
||||||
// Important
|
// Important
|
||||||
@@ -115,10 +117,7 @@ public class HyperParamDto {
|
|||||||
@JsonFormatDttm private ZonedDateTime createDttm;
|
@JsonFormatDttm private ZonedDateTime createDttm;
|
||||||
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
||||||
private String memo;
|
private String memo;
|
||||||
private Long m1UseCnt;
|
private Integer totalUseCnt;
|
||||||
private Long m2UseCnt;
|
|
||||||
private Long m3UseCnt;
|
|
||||||
private Long totalCnt;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ public class ModelTrainDetailDto {
|
|||||||
@JsonFormatDttm private ZonedDateTime step2EndDttm;
|
@JsonFormatDttm private ZonedDateTime step2EndDttm;
|
||||||
private String statusCd;
|
private String statusCd;
|
||||||
private String trainType;
|
private String trainType;
|
||||||
|
private UUID beforeUuid;
|
||||||
|
|
||||||
public String getStatusName() {
|
public String getStatusName() {
|
||||||
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ public class ModelTrainMngService {
|
|||||||
// 모델 config 저장
|
// 모델 config 저장
|
||||||
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
||||||
|
|
||||||
// 임시파일 생성
|
// 데이터셋 임시파일 생성
|
||||||
trainJobService.createTmpFile(modelUuid);
|
trainJobService.createTmpFile(modelUuid);
|
||||||
return modelUuid;
|
return modelUuid;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ public class HyperParamCoreService {
|
|||||||
|
|
||||||
ModelHyperParamEntity entity = new ModelHyperParamEntity();
|
ModelHyperParamEntity entity = new ModelHyperParamEntity();
|
||||||
entity.setHyperVer(firstVersion);
|
entity.setHyperVer(firstVersion);
|
||||||
|
|
||||||
applyHyperParam(entity, createReq);
|
applyHyperParam(entity, createReq);
|
||||||
|
|
||||||
// user
|
// user
|
||||||
@@ -104,7 +103,7 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository.getHyperparamByType(model).stream()
|
hyperParamRepository.getHyperParamByType(model).stream()
|
||||||
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||||
.findFirst()
|
.findFirst()
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
@@ -172,7 +171,7 @@ public class HyperParamCoreService {
|
|||||||
} else {
|
} else {
|
||||||
entity.setCropSize("256,256");
|
entity.setCropSize("256,256");
|
||||||
}
|
}
|
||||||
// entity.setCropSize(src.getCropSize());
|
entity.setCropSize(src.getCropSize());
|
||||||
|
|
||||||
// Important
|
// Important
|
||||||
entity.setModelType(model); // 20250212 modeltype추가
|
entity.setModelType(model); // 20250212 modeltype추가
|
||||||
|
|||||||
@@ -57,6 +57,12 @@ public class ModelTrainDetailCoreService {
|
|||||||
return modelDetailRepository.getModelDetailSummary(uuid);
|
return modelDetailRepository.getModelDetailSummary(uuid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 하이퍼 파리미터 요약정보
|
||||||
|
*
|
||||||
|
* @param uuid 모델마스터 uuid
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
|
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
|
||||||
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
package com.kamco.cd.training.postgres.core;
|
package com.kamco.cd.training.postgres.core;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
import java.time.ZonedDateTime;
|
import java.time.ZonedDateTime;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.log4j.Log4j2;
|
import lombok.extern.log4j.Log4j2;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
@@ -52,11 +56,16 @@ public class ModelTrainJobCoreService {
|
|||||||
/** 실행 시작 처리 */
|
/** 실행 시작 처리 */
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markRunning(
|
public void markRunning(
|
||||||
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
|
Long jobId,
|
||||||
|
String containerName,
|
||||||
|
String logPath,
|
||||||
|
String lockedBy,
|
||||||
|
Integer totalEpoch,
|
||||||
|
String jobType) {
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
modelTrainJobRepository
|
modelTrainJobRepository
|
||||||
.findById(jobId)
|
.findById(jobId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
job.setStatusCd("RUNNING");
|
job.setStatusCd("RUNNING");
|
||||||
job.setContainerName(containerName);
|
job.setContainerName(containerName);
|
||||||
@@ -64,32 +73,44 @@ public class ModelTrainJobCoreService {
|
|||||||
job.setStartedDttm(ZonedDateTime.now());
|
job.setStartedDttm(ZonedDateTime.now());
|
||||||
job.setLockedDttm(ZonedDateTime.now());
|
job.setLockedDttm(ZonedDateTime.now());
|
||||||
job.setLockedBy(lockedBy);
|
job.setLockedBy(lockedBy);
|
||||||
|
job.setJobType(jobType);
|
||||||
|
|
||||||
if (totalEpoch != null) {
|
if (totalEpoch != null) {
|
||||||
job.setTotalEpoch(totalEpoch);
|
job.setTotalEpoch(totalEpoch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 성공 처리 */
|
/**
|
||||||
|
* 성공 처리
|
||||||
|
*
|
||||||
|
* @param jobId
|
||||||
|
* @param exitCode
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markSuccess(Long jobId, int exitCode) {
|
public void markSuccess(Long jobId, int exitCode) {
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
modelTrainJobRepository
|
modelTrainJobRepository
|
||||||
.findById(jobId)
|
.findById(jobId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
job.setStatusCd("SUCCESS");
|
job.setStatusCd("SUCCESS");
|
||||||
job.setExitCode(exitCode);
|
job.setExitCode(exitCode);
|
||||||
job.setFinishedDttm(ZonedDateTime.now());
|
job.setFinishedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 실패 처리 */
|
/**
|
||||||
|
* 실패 처리
|
||||||
|
*
|
||||||
|
* @param jobId
|
||||||
|
* @param exitCode
|
||||||
|
* @param errorMessage
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
modelTrainJobRepository
|
modelTrainJobRepository
|
||||||
.findById(jobId)
|
.findById(jobId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
job.setStatusCd("FAILED");
|
job.setStatusCd("FAILED");
|
||||||
job.setExitCode(exitCode);
|
job.setExitCode(exitCode);
|
||||||
@@ -99,13 +120,35 @@ public class ModelTrainJobCoreService {
|
|||||||
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
|
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 중단됨 처리
|
||||||
|
*
|
||||||
|
* @param jobId
|
||||||
|
* @param exitCode
|
||||||
|
* @param errorMessage
|
||||||
|
*/
|
||||||
|
@Transactional
|
||||||
|
public void markPaused(Long jobId, Integer exitCode, String errorMessage) {
|
||||||
|
ModelTrainJobEntity job =
|
||||||
|
modelTrainJobRepository
|
||||||
|
.findById(jobId)
|
||||||
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
|
job.setStatusCd("STOPPED");
|
||||||
|
job.setExitCode(exitCode);
|
||||||
|
job.setErrorMessage(errorMessage);
|
||||||
|
job.setFinishedDttm(ZonedDateTime.now());
|
||||||
|
|
||||||
|
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
|
||||||
|
}
|
||||||
|
|
||||||
/** 취소 처리 */
|
/** 취소 처리 */
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markCanceled(Long jobId) {
|
public void markCanceled(Long jobId) {
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
modelTrainJobRepository
|
modelTrainJobRepository
|
||||||
.findById(jobId)
|
.findById(jobId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
job.setStatusCd("STOPPED");
|
job.setStatusCd("STOPPED");
|
||||||
job.setFinishedDttm(ZonedDateTime.now());
|
job.setFinishedDttm(ZonedDateTime.now());
|
||||||
@@ -116,7 +159,7 @@ public class ModelTrainJobCoreService {
|
|||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
modelTrainJobRepository
|
modelTrainJobRepository
|
||||||
.findByContainerName(containerName)
|
.findByContainerName(containerName)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
job.setCurrentEpoch(epoch);
|
job.setCurrentEpoch(epoch);
|
||||||
|
|
||||||
@@ -126,4 +169,19 @@ public class ModelTrainJobCoreService {
|
|||||||
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
|
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
|
||||||
modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch);
|
modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 실행중인 학습이 있는지 조회
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public List<ModelTrainJobDto> findRunningJobs() {
|
||||||
|
List<ModelTrainJobEntity> entity = modelTrainJobRepository.findRunningJobs();
|
||||||
|
|
||||||
|
if (entity == null || entity.isEmpty()) {
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
return entity.stream().map(ModelTrainJobEntity::toDto).toList();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ public class ModelTrainMngCoreService {
|
|||||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||||
hyperParamEntity =
|
hyperParamEntity =
|
||||||
hyperParamRepository.getHyperparamByType(modelType).stream()
|
hyperParamRepository.getHyperParamByType(modelType).stream()
|
||||||
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||||
.findFirst()
|
.findFirst()
|
||||||
.orElse(null);
|
.orElse(null);
|
||||||
@@ -104,6 +104,12 @@ public class ModelTrainMngCoreService {
|
|||||||
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
||||||
throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND);
|
throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND);
|
||||||
}
|
}
|
||||||
|
// 하이퍼 파라미터 사용 횟수 업데이트
|
||||||
|
hyperParamEntity.setTotalUseCnt(
|
||||||
|
hyperParamEntity.getTotalUseCnt() == null ? 1 : hyperParamEntity.getTotalUseCnt() + 1);
|
||||||
|
|
||||||
|
// 최근 사용일시 업데이트
|
||||||
|
hyperParamEntity.setLastUsedDttm(ZonedDateTime.now());
|
||||||
|
|
||||||
String modelVer =
|
String modelVer =
|
||||||
String.join(
|
String.join(
|
||||||
@@ -384,7 +390,12 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
/**
|
||||||
|
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @param errorMessage
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markError(Long modelId, String errorMessage) {
|
public void markError(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -399,7 +410,12 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setUpdatedDttm(ZonedDateTime.now());
|
master.setUpdatedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
/**
|
||||||
|
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @param errorMessage
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStep2Error(Long modelId, String errorMessage) {
|
public void markStep2Error(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
|
|||||||
@@ -192,10 +192,10 @@ public class ModelHyperParamEntity {
|
|||||||
@Column(name = "save_best_rule", nullable = false, length = 10)
|
@Column(name = "save_best_rule", nullable = false, length = 10)
|
||||||
private String saveBestRule = "greater";
|
private String saveBestRule = "greater";
|
||||||
|
|
||||||
/** Default: 10 */
|
/** Default: 1 */
|
||||||
@NotNull
|
@NotNull
|
||||||
@Column(name = "val_interval", nullable = false)
|
@Column(name = "val_interval", nullable = false)
|
||||||
private Integer valInterval = 10;
|
private Integer valInterval = 1;
|
||||||
|
|
||||||
/** Default: 400 */
|
/** Default: 400 */
|
||||||
@NotNull
|
@NotNull
|
||||||
@@ -303,15 +303,6 @@ public class ModelHyperParamEntity {
|
|||||||
@Column(name = "last_used_dttm")
|
@Column(name = "last_used_dttm")
|
||||||
private ZonedDateTime lastUsedDttm;
|
private ZonedDateTime lastUsedDttm;
|
||||||
|
|
||||||
@Column(name = "m1_use_cnt")
|
|
||||||
private Long m1UseCnt = 0L;
|
|
||||||
|
|
||||||
@Column(name = "m2_use_cnt")
|
|
||||||
private Long m2UseCnt = 0L;
|
|
||||||
|
|
||||||
@Column(name = "m3_use_cnt")
|
|
||||||
private Long m3UseCnt = 0L;
|
|
||||||
|
|
||||||
@Column(name = "model_type")
|
@Column(name = "model_type")
|
||||||
@Enumerated(EnumType.STRING)
|
@Enumerated(EnumType.STRING)
|
||||||
private ModelType modelType;
|
private ModelType modelType;
|
||||||
@@ -319,12 +310,17 @@ public class ModelHyperParamEntity {
|
|||||||
@Column(name = "default_param")
|
@Column(name = "default_param")
|
||||||
private Boolean isDefault = false;
|
private Boolean isDefault = false;
|
||||||
|
|
||||||
|
@Column(name = "total_use_cnt")
|
||||||
|
private Integer totalUseCnt = 0;
|
||||||
|
|
||||||
public HyperParamDto.Basic toDto() {
|
public HyperParamDto.Basic toDto() {
|
||||||
return new HyperParamDto.Basic(
|
return new HyperParamDto.Basic(
|
||||||
this.modelType,
|
this.modelType,
|
||||||
this.uuid,
|
this.uuid,
|
||||||
this.hyperVer,
|
this.hyperVer,
|
||||||
this.createdDttm,
|
this.createdDttm,
|
||||||
|
this.lastUsedDttm,
|
||||||
|
this.totalUseCnt,
|
||||||
// -------------------------
|
// -------------------------
|
||||||
// Important
|
// Important
|
||||||
// -------------------------
|
// -------------------------
|
||||||
|
|||||||
@@ -83,6 +83,9 @@ public class ModelTrainJobEntity {
|
|||||||
@Column(name = "current_epoch")
|
@Column(name = "current_epoch")
|
||||||
private Integer currentEpoch;
|
private Integer currentEpoch;
|
||||||
|
|
||||||
|
@Column(name = "job_type")
|
||||||
|
private String jobType;
|
||||||
|
|
||||||
public ModelTrainJobDto toDto() {
|
public ModelTrainJobDto toDto() {
|
||||||
return new ModelTrainJobDto(
|
return new ModelTrainJobDto(
|
||||||
this.id,
|
this.id,
|
||||||
|
|||||||
@@ -29,9 +29,28 @@ public interface HyperParamRepositoryCustom {
|
|||||||
|
|
||||||
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
|
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 하이퍼 파라미터 상세조회
|
||||||
|
*
|
||||||
|
* @param uuid
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
|
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 하이퍼 파라미터 목록 조회
|
||||||
|
*
|
||||||
|
* @param model
|
||||||
|
* @param req
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
||||||
|
|
||||||
List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
/**
|
||||||
|
* 하이퍼 파라미터 모델타입으로 조회
|
||||||
|
*
|
||||||
|
* @param modelType
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
|||||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||||
import com.querydsl.core.BooleanBuilder;
|
import com.querydsl.core.BooleanBuilder;
|
||||||
import com.querydsl.core.types.Projections;
|
import com.querydsl.core.types.Projections;
|
||||||
import com.querydsl.core.types.dsl.NumberExpression;
|
|
||||||
import com.querydsl.jpa.impl.JPAQuery;
|
import com.querydsl.jpa.impl.JPAQuery;
|
||||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||||
import java.time.ZoneId;
|
import java.time.ZoneId;
|
||||||
@@ -82,7 +81,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
queryFactory
|
queryFactory
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid)))
|
.where(modelHyperParamEntity.uuid.eq(uuid))
|
||||||
.fetchOne());
|
.fetchOne());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,10 +90,12 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
Pageable pageable = req.toPageable();
|
Pageable pageable = req.toPageable();
|
||||||
|
|
||||||
BooleanBuilder builder = new BooleanBuilder();
|
BooleanBuilder builder = new BooleanBuilder();
|
||||||
|
|
||||||
|
builder.and(modelHyperParamEntity.delYn.isFalse());
|
||||||
|
|
||||||
if (model != null) {
|
if (model != null) {
|
||||||
builder.and(modelHyperParamEntity.modelType.eq(model));
|
builder.and(modelHyperParamEntity.modelType.eq(model));
|
||||||
}
|
}
|
||||||
builder.and(modelHyperParamEntity.delYn.isFalse());
|
|
||||||
|
|
||||||
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
||||||
// 버전
|
// 버전
|
||||||
@@ -118,13 +119,6 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NumberExpression<Long> totalUseCnt =
|
|
||||||
modelHyperParamEntity
|
|
||||||
.m1UseCnt
|
|
||||||
.coalesce(0L)
|
|
||||||
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
|
|
||||||
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
|
|
||||||
|
|
||||||
JPAQuery<HyperParamDto.List> query =
|
JPAQuery<HyperParamDto.List> query =
|
||||||
queryFactory
|
queryFactory
|
||||||
.select(
|
.select(
|
||||||
@@ -136,10 +130,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
modelHyperParamEntity.createdDttm,
|
modelHyperParamEntity.createdDttm,
|
||||||
modelHyperParamEntity.lastUsedDttm,
|
modelHyperParamEntity.lastUsedDttm,
|
||||||
modelHyperParamEntity.memo,
|
modelHyperParamEntity.memo,
|
||||||
modelHyperParamEntity.m1UseCnt,
|
modelHyperParamEntity.totalUseCnt))
|
||||||
modelHyperParamEntity.m2UseCnt,
|
|
||||||
modelHyperParamEntity.m3UseCnt,
|
|
||||||
totalUseCnt.as("totalUseCnt")))
|
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(builder);
|
.where(builder);
|
||||||
|
|
||||||
@@ -164,8 +155,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
asc
|
asc
|
||||||
? modelHyperParamEntity.lastUsedDttm.asc()
|
? modelHyperParamEntity.lastUsedDttm.asc()
|
||||||
: modelHyperParamEntity.lastUsedDttm.desc());
|
: modelHyperParamEntity.lastUsedDttm.desc());
|
||||||
|
case "totalUseCnt" ->
|
||||||
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
|
query.orderBy(
|
||||||
|
asc
|
||||||
|
? modelHyperParamEntity.totalUseCnt.asc()
|
||||||
|
: modelHyperParamEntity.totalUseCnt.desc());
|
||||||
|
|
||||||
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
|
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
|
||||||
}
|
}
|
||||||
@@ -187,7 +181,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
public List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType) {
|
||||||
return queryFactory
|
return queryFactory
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
|
|||||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||||
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
|
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
|
||||||
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
|
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
|
||||||
|
import com.querydsl.core.types.Expression;
|
||||||
import com.querydsl.core.types.Projections;
|
import com.querydsl.core.types.Projections;
|
||||||
import com.querydsl.jpa.JPAExpressions;
|
import com.querydsl.jpa.JPAExpressions;
|
||||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||||
@@ -59,6 +60,13 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DetailSummary getModelDetailSummary(UUID uuid) {
|
public DetailSummary getModelDetailSummary(UUID uuid) {
|
||||||
|
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
|
||||||
|
|
||||||
|
Expression<UUID> beforeModelUuid =
|
||||||
|
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
|
||||||
|
.from(beforeModel)
|
||||||
|
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
|
||||||
|
|
||||||
return queryFactory
|
return queryFactory
|
||||||
.select(
|
.select(
|
||||||
Projections.constructor(
|
Projections.constructor(
|
||||||
@@ -70,7 +78,8 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
|||||||
modelMasterEntity.step1StrtDttm,
|
modelMasterEntity.step1StrtDttm,
|
||||||
modelMasterEntity.step2EndDttm,
|
modelMasterEntity.step2EndDttm,
|
||||||
modelMasterEntity.statusCd,
|
modelMasterEntity.statusCd,
|
||||||
modelMasterEntity.trainType))
|
modelMasterEntity.trainType,
|
||||||
|
beforeModelUuid))
|
||||||
.from(modelMasterEntity)
|
.from(modelMasterEntity)
|
||||||
.where(modelMasterEntity.uuid.eq(uuid))
|
.where(modelMasterEntity.uuid.eq(uuid))
|
||||||
.fetchOne();
|
.fetchOne();
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.train;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
|
||||||
public interface ModelTrainJobRepositoryCustom {
|
public interface ModelTrainJobRepositoryCustom {
|
||||||
@@ -11,4 +12,6 @@ public interface ModelTrainJobRepositoryCustom {
|
|||||||
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
|
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
|
||||||
|
|
||||||
void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch);
|
void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch);
|
||||||
|
|
||||||
|
List<ModelTrainJobEntity> findRunningJobs();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.train;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity;
|
import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity;
|
||||||
|
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.common.enums.JobStatusType;
|
||||||
|
import com.kamco.cd.training.common.enums.JobType;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||||
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
|
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
|
||||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||||
import jakarta.persistence.EntityManager;
|
import jakarta.persistence.EntityManager;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import org.springframework.stereotype.Repository;
|
import org.springframework.stereotype.Repository;
|
||||||
|
|
||||||
@@ -21,7 +25,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
|||||||
/** modelId의 attempt_no 최대값. (없으면 0) */
|
/** modelId의 attempt_no 최대값. (없으면 0) */
|
||||||
@Override
|
@Override
|
||||||
public int findMaxAttemptNo(Long modelId) {
|
public int findMaxAttemptNo(Long modelId) {
|
||||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
QModelTrainJobEntity j = modelTrainJobEntity;
|
||||||
|
|
||||||
Integer max =
|
Integer max =
|
||||||
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
|
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
|
||||||
@@ -35,7 +39,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
|||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
QModelTrainJobEntity j = modelTrainJobEntity;
|
||||||
|
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
|
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
|
||||||
@@ -45,7 +49,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
|
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
|
||||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
QModelTrainJobEntity j = modelTrainJobEntity;
|
||||||
|
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
queryFactory
|
queryFactory
|
||||||
@@ -78,4 +82,18 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
|||||||
.values(modelId, nextAttemptNo, jobId, epoch)
|
.values(modelId, nextAttemptNo, jobId, epoch)
|
||||||
.execute();
|
.execute();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<ModelTrainJobEntity> findRunningJobs() {
|
||||||
|
return queryFactory
|
||||||
|
.select(modelTrainJobEntity)
|
||||||
|
.from(modelTrainJobEntity)
|
||||||
|
.where(
|
||||||
|
modelTrainJobEntity
|
||||||
|
.statusCd
|
||||||
|
.eq(JobStatusType.RUNNING.getId())
|
||||||
|
.and(modelTrainJobEntity.jobType.eq(JobType.TRAIN.getId())))
|
||||||
|
.orderBy(modelTrainJobEntity.id.desc())
|
||||||
|
.fetch();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ public class TrainApiController {
|
|||||||
})
|
})
|
||||||
@PostMapping("/create-tmp/{uuid}")
|
@PostMapping("/create-tmp/{uuid}")
|
||||||
public ApiResponseDto<UUID> createTmpFile(
|
public ApiResponseDto<UUID> createTmpFile(
|
||||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
@Parameter(description = "model uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||||
@PathVariable
|
@PathVariable
|
||||||
UUID uuid) {
|
UUID uuid) {
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,17 @@ import java.io.IOException;
|
|||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.log4j.Log4j2;
|
import lombok.extern.log4j.Log4j2;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/** 학습실행 파일 하드링크 */
|
||||||
@Service
|
@Service
|
||||||
@Log4j2
|
@Log4j2
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@@ -42,6 +46,10 @@ public class DataSetCountersService {
|
|||||||
|
|
||||||
// tmp
|
// tmp
|
||||||
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
|
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
|
||||||
|
|
||||||
|
// 차이나는거
|
||||||
|
diffMergedRequestsVsTmp(uids, tmpPath);
|
||||||
|
|
||||||
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
|
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
|
||||||
allLogs
|
allLogs
|
||||||
.append(counters2.prints(basic.getRequestPath(), "TMP"))
|
.append(counters2.prints(basic.getRequestPath(), "TMP"))
|
||||||
@@ -163,4 +171,58 @@ public class DataSetCountersService {
|
|||||||
test + test2);
|
test + test2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Set<String> listTifRelative(Path root) throws IOException {
|
||||||
|
if (!Files.isDirectory(root)) return Set.of();
|
||||||
|
|
||||||
|
try (var stream = Files.walk(root)) {
|
||||||
|
return stream
|
||||||
|
.filter(Files::isRegularFile)
|
||||||
|
.filter(p -> p.getFileName().toString().toLowerCase().endsWith(".tif"))
|
||||||
|
.map(p -> root.relativize(p).toString().replace("\\", "/"))
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Set<String> listTifFileNameOnly(Path root) throws IOException {
|
||||||
|
if (!Files.isDirectory(root)) return Set.of();
|
||||||
|
|
||||||
|
try (var stream = Files.walk(root)) {
|
||||||
|
return stream
|
||||||
|
.filter(Files::isRegularFile)
|
||||||
|
.filter(p -> p.getFileName().toString().toLowerCase().endsWith(".tif"))
|
||||||
|
.map(p -> p.getFileName().toString()) // 파일명만
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void diffMergedRequestsVsTmp(List<String> uids, Path tmpRoot) throws IOException {
|
||||||
|
|
||||||
|
// 1) 요청 uids 전체를 합친 tif "파일명" 집합
|
||||||
|
Set<String> reqAll = new HashSet<>();
|
||||||
|
for (String uid : uids) {
|
||||||
|
Path reqRoot = Path.of(requestDir, uid);
|
||||||
|
|
||||||
|
// ★합본 tmp는 보통 폴더 구조가 바뀌므로 "상대경로" 비교보다 파일명 비교가 먼저 유용합니다.
|
||||||
|
reqAll.addAll(listTifFileNameOnly(reqRoot));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) tmp tif 파일명 집합
|
||||||
|
Set<String> tmpAll = listTifFileNameOnly(tmpRoot);
|
||||||
|
|
||||||
|
Set<String> missing = new HashSet<>(reqAll);
|
||||||
|
missing.removeAll(tmpAll);
|
||||||
|
|
||||||
|
Set<String> extra = new HashSet<>(tmpAll);
|
||||||
|
extra.removeAll(reqAll);
|
||||||
|
|
||||||
|
log.info("==== MERGED DIFF (filename-based) ====");
|
||||||
|
log.info("request(all uids) tif = {}", reqAll.size());
|
||||||
|
log.info("tmp tif = {}", tmpAll.size());
|
||||||
|
log.info("missing = {}", missing.size());
|
||||||
|
log.info("extra = {}", extra.size());
|
||||||
|
|
||||||
|
missing.stream().sorted().limit(50).forEach(f -> log.warn("[MISSING] {}", f));
|
||||||
|
extra.stream().sorted().limit(50).forEach(f -> log.warn("[EXTRA] {}", f));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,14 @@ public class DockerTrainService {
|
|||||||
|
|
||||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
|
|
||||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
/**
|
||||||
|
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
|
||||||
|
*
|
||||||
|
* @param req
|
||||||
|
* @param containerName
|
||||||
|
* @return
|
||||||
|
* @throws Exception
|
||||||
|
*/
|
||||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||||
|
|
||||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||||
@@ -267,8 +274,7 @@ public class DockerTrainService {
|
|||||||
addArg(c, "--input-size", req.getInputSize());
|
addArg(c, "--input-size", req.getInputSize());
|
||||||
addArg(c, "--crop-size", req.getCropSize());
|
addArg(c, "--crop-size", req.getCropSize());
|
||||||
addArg(c, "--batch-size", req.getBatchSize());
|
addArg(c, "--batch-size", req.getBatchSize());
|
||||||
addArg(c, "--gpu-ids", req.getGpuIds());
|
addArg(c, "--gpu-ids", req.getGpuIds()); // null
|
||||||
// addArg(c, "--gpus", req.getGpus());
|
|
||||||
addArg(c, "--lr", req.getLearningRate());
|
addArg(c, "--lr", req.getLearningRate());
|
||||||
addArg(c, "--backbone", req.getBackbone());
|
addArg(c, "--backbone", req.getBackbone());
|
||||||
addArg(c, "--epochs", req.getEpochs());
|
addArg(c, "--epochs", req.getEpochs());
|
||||||
@@ -342,6 +348,14 @@ public class DockerTrainService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
|
||||||
|
*
|
||||||
|
* @param containerName
|
||||||
|
* @param req
|
||||||
|
* @return
|
||||||
|
* @throws Exception
|
||||||
|
*/
|
||||||
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
|
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
|
||||||
|
|
||||||
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
||||||
|
|||||||
@@ -0,0 +1,449 @@
|
|||||||
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
|
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||||
|
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||||
|
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||||
|
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||||
|
import java.io.BufferedReader;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.nio.file.DirectoryStream;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.log4j.Log4j2;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.context.annotation.Profile;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 서버 재기동 시 "RUNNING 상태로 남아있는 학습 Job"을 복구(정리)하기 위한 서비스.
|
||||||
|
*
|
||||||
|
* <p>상황 예시: - 서버가 강제 재기동/장애로 내려감 - DB 상에서는 job_state가 RUNNING(진행중)으로 남아있음 - 실제 docker 컨테이너는: 1) 아직
|
||||||
|
* 살아있거나(running=true) 2) 종료되었거나(exited) 3) --rm 옵션으로 인해 컨테이너가 이미 삭제되어 존재하지 않을 수 있음
|
||||||
|
*
|
||||||
|
* <p>이 클래스는 ApplicationReadyEvent(스프링 부팅 완료) 시점에 실행되어, DB의 RUNNING 잡들을 조회한 뒤 컨테이너 상태를 점검하고,
|
||||||
|
* SUCCESS/FAILED 처리를 수행합니다.
|
||||||
|
*/
|
||||||
|
@Profile("!local")
|
||||||
|
@Component
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
@Log4j2
|
||||||
|
public class JobRecoveryOnStartupService {
|
||||||
|
|
||||||
|
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||||
|
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response
|
||||||
|
*
|
||||||
|
* <p>컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다.
|
||||||
|
*/
|
||||||
|
@Value("${train.docker.responseDir}")
|
||||||
|
private String responseDir;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 스프링 부팅 완료 시점(빈 생성/초기화 모두 끝난 뒤)에 복구 로직 실행.
|
||||||
|
*
|
||||||
|
* <p>@Transactional: - recover() 메서드 전체가 하나의 트랜잭션으로 감싸집니다. - Job 하나씩 처리하다가 예외가 발생하면 전체 롤백이 될 수
|
||||||
|
* 있으므로 "잡 단위로 확실히 커밋"이 필요하면 (권장) 잡 단위로 분리 트랜잭션(REQUIRES_NEW) 고려하세요.
|
||||||
|
*/
|
||||||
|
// @EventListener(ApplicationReadyEvent.class)
|
||||||
|
@Transactional
|
||||||
|
public void recover() {
|
||||||
|
|
||||||
|
// 1) DB에서 "RUNNING(진행중) 상태"로 남아있는 job 목록을 조회
|
||||||
|
List<ModelTrainJobDto> runningJobs = modelTrainJobCoreService.findRunningJobs();
|
||||||
|
|
||||||
|
// 실행중 job이 없으면 할 일 없음
|
||||||
|
if (runningJobs == null || runningJobs.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 각 job에 대해 docker 컨테이너 상태를 확인하고, 상태에 따라 조치
|
||||||
|
for (ModelTrainJobDto job : runningJobs) {
|
||||||
|
String containerName = job.getContainerName();
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 2-1) docker inspect로 컨테이너 상태 조회
|
||||||
|
DockerInspectState state = inspectContainer(containerName);
|
||||||
|
|
||||||
|
// 3) 컨테이너가 "없음"
|
||||||
|
// - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음
|
||||||
|
// - 즉 "컨테이너 없음"이 무조건 실패는 아님
|
||||||
|
if (!state.exists()) {
|
||||||
|
log.warn(
|
||||||
|
"[RECOVERY] container missing. try file-based reconcile. container={}",
|
||||||
|
containerName);
|
||||||
|
|
||||||
|
// 3-1) 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정"
|
||||||
|
OutputResult out = probeOutputs(job);
|
||||||
|
|
||||||
|
// 3-2) 산출물이 충분하면 성공 처리
|
||||||
|
if (out.completed()) {
|
||||||
|
log.info("[RECOVERY] outputs look completed. mark SUCCESS. jobId={}", job.getId());
|
||||||
|
modelTrainJobCoreService.markSuccess(job.getId(), 0);
|
||||||
|
markStepSuccessByJobType(job);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// 3-3) 산출물이 부족하면 실패 처리(운영 정책에 따라 "유예"도 가능)
|
||||||
|
log.warn(
|
||||||
|
"[RECOVERY] outputs incomplete. mark FAILED. jobId={} reason={}",
|
||||||
|
job.getId(),
|
||||||
|
out.reason());
|
||||||
|
|
||||||
|
modelTrainJobCoreService.markFailed(
|
||||||
|
job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
|
||||||
|
|
||||||
|
markStepErrorByJobType(job, out.reason());
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 컨테이너는 존재하고, 아직 running=true
|
||||||
|
// - 서버만 재기동됐고 컨테이너는 그대로 살아있는 케이스
|
||||||
|
// - 이 경우 DB를 건드리면 오히려 꼬일 수 있으니 RUNNING 유지
|
||||||
|
if (state.running()) {
|
||||||
|
log.info("[RECOVERY] container still running. container={}", containerName);
|
||||||
|
try {
|
||||||
|
ProcessBuilder pb = new ProcessBuilder("docker", "stop", "-t", "20", containerName);
|
||||||
|
pb.redirectErrorStream(true);
|
||||||
|
|
||||||
|
Process p = pb.start();
|
||||||
|
|
||||||
|
boolean finished = p.waitFor(30, TimeUnit.SECONDS);
|
||||||
|
if (!finished) {
|
||||||
|
p.destroyForcibly();
|
||||||
|
throw new IOException("docker stop timeout");
|
||||||
|
}
|
||||||
|
|
||||||
|
int code = p.exitValue();
|
||||||
|
if (code != 0) {
|
||||||
|
throw new IOException("docker stop failed. exit=" + code);
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"[RECOVERY] container stopped (will be auto removed by --rm). container={}",
|
||||||
|
containerName);
|
||||||
|
|
||||||
|
// 여기서 상태를 PAUSED로 바꿔도 되고
|
||||||
|
modelTrainJobCoreService.markPaused(job.getId(), -1, "AUTO_STOP_FAILED_ON_RESTART");
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("[RECOVERY] docker stop failed. container={}", containerName, e);
|
||||||
|
|
||||||
|
modelTrainJobCoreService.markFailed(job.getId(), -1, "AUTO_STOP_FAILED_ON_RESTART");
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5) 컨테이너는 존재하지만 running=false
|
||||||
|
// - exited / dead 등의 상태
|
||||||
|
Integer exitCode = state.exitCode();
|
||||||
|
String status = state.status();
|
||||||
|
|
||||||
|
// 5-1) exitCode=0이면 정상 종료로 간주 → SUCCESS 처리
|
||||||
|
if (exitCode != null && exitCode == 0) {
|
||||||
|
log.info("[RECOVERY] container exited(0). mark SUCCESS. container={}", containerName);
|
||||||
|
modelTrainJobCoreService.markSuccess(job.getId(), 0);
|
||||||
|
markStepSuccessByJobType(job);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// 5-2) exitCode != 0 이거나 null이면 실패로 간주 → FAILED 처리
|
||||||
|
log.warn(
|
||||||
|
"[RECOVERY] container exited non-zero. mark FAILED. container={} status={} exitCode={}",
|
||||||
|
containerName,
|
||||||
|
status,
|
||||||
|
exitCode);
|
||||||
|
|
||||||
|
modelTrainJobCoreService.markFailed(
|
||||||
|
job.getId(), exitCode, "SERVER_RESTART_CONTAINER_EXIT_NONZERO");
|
||||||
|
|
||||||
|
markStepErrorByJobType(job, "exit=" + exitCode + " status=" + status);
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
// 6) docker inspect 자체가 실패한 경우
|
||||||
|
// - docker 데몬 문제/권한 문제/일시적 오류 가능
|
||||||
|
// - 운영 정책에 따라 "바로 실패" 대신 "유예" 처리도 고려 가능
|
||||||
|
log.error("[RECOVERY] container inspect failed. container={}", containerName, e);
|
||||||
|
|
||||||
|
modelTrainJobCoreService.markFailed(
|
||||||
|
job.getId(), -1, "SERVER_RESTART_CONTAINER_INSPECT_ERROR");
|
||||||
|
|
||||||
|
markStepErrorByJobType(job, "inspect-error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* jobType에 따라 학습 관리 테이블의 "성공 단계"를 업데이트.
|
||||||
|
*
|
||||||
|
* <p>예: - jobType == "EVAL" → step2(평가 단계) 성공 - 그 외 → step1(학습 단계) 성공
|
||||||
|
*/
|
||||||
|
private void markStepSuccessByJobType(ModelTrainJobDto job) {
|
||||||
|
Map<String, Object> params = job.getParamsJson();
|
||||||
|
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
|
||||||
|
if (isEval) {
|
||||||
|
modelTrainMngCoreService.markStep2Success(job.getModelId());
|
||||||
|
} else {
|
||||||
|
modelTrainMngCoreService.markStep1Success(job.getModelId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* jobType에 따라 학습 관리 테이블의 "에러 단계"를 업데이트.
|
||||||
|
*
|
||||||
|
* <p>예: - jobType == "EVAL" → step2(평가 단계) 에러 - 그 외 → step1 혹은 전체 에러
|
||||||
|
*/
|
||||||
|
private void markStepErrorByJobType(ModelTrainJobDto job, String msg) {
|
||||||
|
Map<String, Object> params = job.getParamsJson();
|
||||||
|
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
|
||||||
|
if (isEval) {
|
||||||
|
modelTrainMngCoreService.markStep2Error(job.getModelId(), msg);
|
||||||
|
} else {
|
||||||
|
modelTrainMngCoreService.markError(job.getModelId(), msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* docker inspect를 사용해서 컨테이너 상태를 조회합니다.
|
||||||
|
*
|
||||||
|
* <p>사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}}
|
||||||
|
*
|
||||||
|
* <p>예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137"
|
||||||
|
*
|
||||||
|
* <p>주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초
|
||||||
|
* 타임아웃을 둠
|
||||||
|
*/
|
||||||
|
private DockerInspectState inspectContainer(String containerName)
|
||||||
|
throws IOException, InterruptedException {
|
||||||
|
|
||||||
|
ProcessBuilder pb =
|
||||||
|
new ProcessBuilder(
|
||||||
|
"docker",
|
||||||
|
"inspect",
|
||||||
|
"-f",
|
||||||
|
"{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}",
|
||||||
|
containerName);
|
||||||
|
|
||||||
|
// stderr를 stdout으로 합쳐서 한 스트림으로 읽기(에러 메시지도 함께 받음)
|
||||||
|
pb.redirectErrorStream(true);
|
||||||
|
|
||||||
|
Process p = pb.start();
|
||||||
|
|
||||||
|
// inspect 출력은 1줄이면 충분하므로 readLine()만 수행
|
||||||
|
String output;
|
||||||
|
try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
|
||||||
|
output = br.readLine();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 무한대기 방지: 5초 내에 종료되지 않으면 강제 종료
|
||||||
|
boolean finished = p.waitFor(5, TimeUnit.SECONDS);
|
||||||
|
if (!finished) {
|
||||||
|
p.destroyForcibly();
|
||||||
|
throw new IOException("docker inspect timeout");
|
||||||
|
}
|
||||||
|
|
||||||
|
// docker inspect 자체의 프로세스 exit code
|
||||||
|
int code = p.exitValue();
|
||||||
|
|
||||||
|
// 실패(코드 !=0) 또는 출력이 없으면 "컨테이너 없음"으로 간주
|
||||||
|
if (code != 0 || output == null || output.isBlank()) {
|
||||||
|
return DockerInspectState.missing();
|
||||||
|
}
|
||||||
|
|
||||||
|
// "status running exitCode" 형태로 split
|
||||||
|
String[] parts = output.trim().split("\\s+");
|
||||||
|
|
||||||
|
// status: running/exited/dead 등
|
||||||
|
String status = parts.length > 0 ? parts[0] : "unknown";
|
||||||
|
|
||||||
|
// running: true/false
|
||||||
|
boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]);
|
||||||
|
|
||||||
|
// exitCode: 정수 파싱(파싱 실패하면 null)
|
||||||
|
Integer exitCode = null;
|
||||||
|
if (parts.length > 2) {
|
||||||
|
try {
|
||||||
|
exitCode = Integer.parseInt(parts[2]);
|
||||||
|
} catch (Exception ignore) {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new DockerInspectState(true, running, exitCode, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* docker inspect 결과를 담는 레코드.
|
||||||
|
*
|
||||||
|
* <p>exists: - true : docker inspect 성공 (컨테이너 존재) - false : 컨테이너 없음(또는 inspect 실패를 missing으로 간주)
|
||||||
|
*/
|
||||||
|
private record DockerInspectState(
|
||||||
|
boolean exists, boolean running, Integer exitCode, String status) {
|
||||||
|
static DockerInspectState missing() {
|
||||||
|
return new DockerInspectState(false, false, null, "missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================================
|
||||||
|
// 컨테이너가 "없을 때" 파일 기반으로 완료/미완료를 판정하는 로직
|
||||||
|
// ============================================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다.
|
||||||
|
*
|
||||||
|
* <p>판정 규칙(보수적으로 설계): 1) total_epoch가 paramsJson에 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수
|
||||||
|
* >= total_epoch 이어야 함 3) *.pth 파일이 total_epoch 이상 존재하거나, best*.pth(또는 *best*.pth)가 존재해야 함
|
||||||
|
*
|
||||||
|
* <p>왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를
|
||||||
|
* 실패로 오판할 수 있음.
|
||||||
|
*/
|
||||||
|
private OutputResult probeOutputs(ModelTrainJobDto job) {
|
||||||
|
try {
|
||||||
|
Path outDir = resolveOutputDir(job);
|
||||||
|
if (outDir == null || !Files.isDirectory(outDir)) {
|
||||||
|
return new OutputResult(false, "output-dir-missing");
|
||||||
|
}
|
||||||
|
|
||||||
|
Integer totalEpoch = extractTotalEpoch(job).orElse(null);
|
||||||
|
if (totalEpoch == null || totalEpoch <= 0) {
|
||||||
|
return new OutputResult(false, "total-epoch-missing");
|
||||||
|
}
|
||||||
|
|
||||||
|
Path valCsv = outDir.resolve("val.csv");
|
||||||
|
if (!Files.exists(valCsv)) {
|
||||||
|
return new OutputResult(false, "val.csv-missing");
|
||||||
|
}
|
||||||
|
|
||||||
|
long lines = countNonHeaderLines(valCsv);
|
||||||
|
|
||||||
|
// “같아야 완료” 정책
|
||||||
|
if (lines == totalEpoch) {
|
||||||
|
return new OutputResult(true, "ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
return new OutputResult(
|
||||||
|
false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch);
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("[RECOVERY] probeOutputs error. jobId={}", job.getId(), e);
|
||||||
|
return new OutputResult(false, "probe-error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* responseDir 아래에서 job 산출물 디렉토리를 찾습니다.
|
||||||
|
*
|
||||||
|
* <p>가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다.
|
||||||
|
*
|
||||||
|
* <p>현재 기본 탐색 순서: 1) {responseDir}/{jobId} 2) {responseDir}/{modelId} 3)
|
||||||
|
* {responseDir}/{containerName} 4) 마지막 fallback: responseDir 자체
|
||||||
|
*
|
||||||
|
* <p>추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId)
|
||||||
|
* 형태를 1순위로 두세요.
|
||||||
|
*/
|
||||||
|
private Path resolveOutputDir(ModelTrainJobDto job) {
|
||||||
|
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId());
|
||||||
|
|
||||||
|
Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics");
|
||||||
|
|
||||||
|
return Files.isDirectory(base) ? base : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* paramsJson에서 total_epoch 값을 추출합니다.
|
||||||
|
*
|
||||||
|
* <p>키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase)
|
||||||
|
*
|
||||||
|
* <p>예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...}
|
||||||
|
*/
|
||||||
|
private Optional<Integer> extractTotalEpoch(ModelTrainJobDto job) {
|
||||||
|
Map<String, Object> params = job.getParamsJson();
|
||||||
|
if (params == null) return Optional.empty();
|
||||||
|
|
||||||
|
Object v = params.get("total_epoch");
|
||||||
|
if (v == null) v = params.get("totalEpoch");
|
||||||
|
if (v == null) return Optional.empty();
|
||||||
|
|
||||||
|
try {
|
||||||
|
return Optional.of(Integer.parseInt(String.valueOf(v)));
|
||||||
|
} catch (Exception ignore) {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다.
|
||||||
|
*
|
||||||
|
* <p>가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록)
|
||||||
|
*
|
||||||
|
* <p>주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외
|
||||||
|
*/
|
||||||
|
private long countNonHeaderLines(Path csv) throws IOException {
|
||||||
|
try (Stream<String> lines = Files.lines(csv, StandardCharsets.UTF_8)) {
|
||||||
|
return lines.skip(1).filter(s -> s != null && !s.isBlank()).count();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 디렉토리에서 glob 패턴에 맞는 파일 수를 셉니다.
|
||||||
|
*
|
||||||
|
* <p>예: - "*.pth" - "best*.pth"
|
||||||
|
*/
|
||||||
|
private long countFilesByGlob(Path dir, String glob) throws IOException {
|
||||||
|
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
|
||||||
|
long cnt = 0;
|
||||||
|
for (Path p : ds) {
|
||||||
|
if (Files.isRegularFile(p)) cnt++;
|
||||||
|
}
|
||||||
|
return cnt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 디렉토리에서 glob 패턴에 맞는 파일이 "하나라도" 존재하는지 체크합니다. */
|
||||||
|
private boolean existsByGlob(Path dir, String glob) throws IOException {
|
||||||
|
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
|
||||||
|
return ds.iterator().hasNext();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================================
|
||||||
|
// probeOutputs() 결과 객체
|
||||||
|
// ============================================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 컨테이너가 없을 때(responseDir 기반) 완료 여부 판정 결과.
|
||||||
|
*
|
||||||
|
* <p>completed: - true : 산출물이 완료로 보임(성공 처리 가능) - false : 산출물이 부족/불명확(실패 또는 유예 판단)
|
||||||
|
*
|
||||||
|
* <p>reason: - 실패/미완료 사유(로그/DB 메시지로 남기기 용도)
|
||||||
|
*/
|
||||||
|
private static final class OutputResult {
|
||||||
|
|
||||||
|
private final boolean completed;
|
||||||
|
private final String reason;
|
||||||
|
|
||||||
|
private OutputResult(boolean completed, String reason) {
|
||||||
|
this.completed = completed;
|
||||||
|
this.reason = reason;
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean completed() {
|
||||||
|
return completed;
|
||||||
|
}
|
||||||
|
|
||||||
|
String reason() {
|
||||||
|
return reason;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -48,20 +48,8 @@ public class ModelTestMetricsJobService {
|
|||||||
@Value("${file.pt-path}")
|
@Value("${file.pt-path}")
|
||||||
private String ptPathDir;
|
private String ptPathDir;
|
||||||
|
|
||||||
/**
|
/** 결과 csv 파일 정보 등록 */
|
||||||
* 실행중인 profile
|
public void findTestValidMetricCsvFiles() {
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private boolean isLocalProfile() {
|
|
||||||
return "local".equalsIgnoreCase(profile);
|
|
||||||
}
|
|
||||||
|
|
||||||
// @Scheduled(cron = "0 * * * * *")
|
|
||||||
public void findTestValidMetricCsvFiles() throws IOException {
|
|
||||||
// if (isLocalProfile()) {
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
List<ResponsePathDto> modelIds =
|
List<ResponsePathDto> modelIds =
|
||||||
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
||||||
|
|||||||
@@ -36,20 +36,8 @@ public class ModelTrainMetricsJobService {
|
|||||||
@Value("${train.docker.responseDir}")
|
@Value("${train.docker.responseDir}")
|
||||||
private String responseDir;
|
private String responseDir;
|
||||||
|
|
||||||
/**
|
/** 결과 csv 파일 정보 등록 */
|
||||||
* 실행중인 profile
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private boolean isLocalProfile() {
|
|
||||||
return "local".equalsIgnoreCase(profile);
|
|
||||||
}
|
|
||||||
|
|
||||||
// @Scheduled(cron = "0 * * * * *")
|
|
||||||
public void findTrainValidMetricCsvFiles() {
|
public void findTrainValidMetricCsvFiles() {
|
||||||
// if (isLocalProfile()) {
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
List<ResponsePathDto> modelIds =
|
List<ResponsePathDto> modelIds =
|
||||||
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
||||||
|
|||||||
@@ -23,6 +23,14 @@ public class TestJobService {
|
|||||||
private final ApplicationEventPublisher eventPublisher;
|
private final ApplicationEventPublisher eventPublisher;
|
||||||
private final DataSetCountersService dataSetCounters;
|
private final DataSetCountersService dataSetCounters;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 실행 예약 (QUEUE 등록)
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @param uuid
|
||||||
|
* @param epoch
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||||
|
|
||||||
@@ -58,6 +66,11 @@ public class TestJobService {
|
|||||||
return jobId;
|
return jobId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 취소
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void cancel(Long modelId) {
|
public void cancel(Long modelId) {
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ public class TmpDatasetService {
|
|||||||
* @param uid 임시폴더 uuid
|
* @param uid 임시폴더 uuid
|
||||||
* @param type train, val, test
|
* @param type train, val, test
|
||||||
* @param links tif pull path
|
* @param links tif pull path
|
||||||
|
* @return
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
public void buildTmpDatasetHardlink(String uid, String type, List<ModelTrainLinkDto> links)
|
public void buildTmpDatasetHardlink(String uid, String type, List<ModelTrainLinkDto> links)
|
||||||
|
|||||||
@@ -47,7 +47,12 @@ public class TrainJobService {
|
|||||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 실행 예약 (QUEUE 등록) */
|
/**
|
||||||
|
* 실행 예약 (QUEUE 등록)
|
||||||
|
*
|
||||||
|
* @param modelId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public Long enqueue(Long modelId) {
|
public Long enqueue(Long modelId) {
|
||||||
|
|
||||||
@@ -139,6 +144,13 @@ public class TrainJobService {
|
|||||||
modelTrainMngCoreService.markStopped(modelId);
|
modelTrainMngCoreService.markStopped(modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습 이어하기
|
||||||
|
*
|
||||||
|
* @param modelId 모델 id
|
||||||
|
* @param mode NONE 새로 시작, REQUIRE 이어하기
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
||||||
|
|
||||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||||
@@ -189,6 +201,12 @@ public class TrainJobService {
|
|||||||
REQUIRE // 이어하기
|
REQUIRE // 이어하기
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
|
||||||
|
*
|
||||||
|
* @param paramsJson
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
||||||
if (paramsJson == null) return null;
|
if (paramsJson == null) return null;
|
||||||
|
|
||||||
@@ -230,6 +248,12 @@ public class TrainJobService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
|
||||||
|
*
|
||||||
|
* @param modelUuid
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public UUID createTmpFile(UUID modelUuid) {
|
public UUID createTmpFile(UUID modelUuid) {
|
||||||
UUID tmpUuid = UUID.randomUUID();
|
UUID tmpUuid = UUID.randomUUID();
|
||||||
@@ -242,6 +266,8 @@ public class TrainJobService {
|
|||||||
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
|
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// 데이터셋 심볼링크 생성
|
||||||
|
// String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||||
// train path
|
// train path
|
||||||
List<ModelTrainLinkDto> trainList = modelTrainMngCoreService.findDatasetTrainPath(modelId);
|
List<ModelTrainLinkDto> trainList = modelTrainMngCoreService.findDatasetTrainPath(modelId);
|
||||||
// validation path
|
// validation path
|
||||||
@@ -272,14 +298,8 @@ public class TrainJobService {
|
|||||||
e);
|
e);
|
||||||
|
|
||||||
// 런타임 예외로 래핑하되, 메시지에 핵심 정보 포함
|
// 런타임 예외로 래핑하되, 메시지에 핵심 정보 포함
|
||||||
throw new IllegalStateException(
|
throw new CustomApiException(
|
||||||
"tmp dataset build failed: modelUuid="
|
"INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "임시 데이터셋 생성에 실패했습니다.");
|
||||||
+ modelUuid
|
|
||||||
+ ", modelId="
|
|
||||||
+ modelId
|
|
||||||
+ ", tmpRaw="
|
|
||||||
+ raw,
|
|
||||||
e);
|
|
||||||
}
|
}
|
||||||
return modelUuid;
|
return modelUuid;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
|
|||||||
import org.springframework.transaction.event.TransactionPhase;
|
import org.springframework.transaction.event.TransactionPhase;
|
||||||
import org.springframework.transaction.event.TransactionalEventListener;
|
import org.springframework.transaction.event.TransactionalEventListener;
|
||||||
|
|
||||||
|
/** job 실행 */
|
||||||
@Log4j2
|
@Log4j2
|
||||||
@Component
|
@Component
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@@ -29,10 +30,12 @@ public class TrainJobWorker {
|
|||||||
private final DockerTrainService dockerTrainService;
|
private final DockerTrainService dockerTrainService;
|
||||||
private final ObjectMapper objectMapper;
|
private final ObjectMapper objectMapper;
|
||||||
|
|
||||||
@Async
|
@Async("trainJobExecutor")
|
||||||
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
||||||
public void handle(ModelTrainJobQueuedEvent event) {
|
public void handle(ModelTrainJobQueuedEvent event) {
|
||||||
|
|
||||||
|
log.info("[JOB] thread={}, jobId={}", Thread.currentThread().getName(), event.getJobId());
|
||||||
|
|
||||||
Long jobId = event.getJobId();
|
Long jobId = event.getJobId();
|
||||||
|
|
||||||
ModelTrainJobDto job =
|
ModelTrainJobDto job =
|
||||||
@@ -54,6 +57,8 @@ public class TrainJobWorker {
|
|||||||
String containerName =
|
String containerName =
|
||||||
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
||||||
|
|
||||||
|
String type = isEval ? "TEST" : "TRAIN";
|
||||||
|
|
||||||
Integer totalEpoch = null;
|
Integer totalEpoch = null;
|
||||||
if (params.containsKey("totalEpoch")) {
|
if (params.containsKey("totalEpoch")) {
|
||||||
if (params.get("totalEpoch") != null) {
|
if (params.get("totalEpoch") != null) {
|
||||||
@@ -61,12 +66,15 @@ public class TrainJobWorker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
|
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
|
||||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
|
// 실행 시작 처리
|
||||||
|
modelTrainJobCoreService.markRunning(
|
||||||
|
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
|
||||||
log.info("[JOB] markRunning done jobId={}", jobId);
|
log.info("[JOB] markRunning done jobId={}", jobId);
|
||||||
try {
|
try {
|
||||||
TrainRunResult result;
|
TrainRunResult result;
|
||||||
|
|
||||||
if (isEval) {
|
if (isEval) {
|
||||||
|
// step2 진행중 처리
|
||||||
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||||
String uuid = String.valueOf(params.get("uuid"));
|
String uuid = String.valueOf(params.get("uuid"));
|
||||||
int epoch = (int) params.get("epoch");
|
int epoch = (int) params.get("epoch");
|
||||||
@@ -81,11 +89,13 @@ public class TrainJobWorker {
|
|||||||
evalReq.setOutputFolder(outputFolder);
|
evalReq.setOutputFolder(outputFolder);
|
||||||
log.info("[JOB] selected test epoch={}", epoch);
|
log.info("[JOB] selected test epoch={}", epoch);
|
||||||
|
|
||||||
|
// 도커 실행 후 로그 수집
|
||||||
result = dockerTrainService.runEvalSync(containerName, evalReq);
|
result = dockerTrainService.runEvalSync(containerName, evalReq);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
// step1 진행중 처리
|
||||||
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||||
TrainRunRequest trainReq = toTrainRunRequest(params);
|
TrainRunRequest trainReq = toTrainRunRequest(params);
|
||||||
|
// 도커 실행 후 로그 수집
|
||||||
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,24 +109,31 @@ public class TrainJobWorker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (result.getExitCode() == 0) {
|
if (result.getExitCode() == 0) {
|
||||||
|
// 성공 처리
|
||||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||||
|
|
||||||
if (isEval) {
|
if (isEval) {
|
||||||
|
// step2 완료처리
|
||||||
modelTrainMngCoreService.markStep2Success(modelId);
|
modelTrainMngCoreService.markStep2Success(modelId);
|
||||||
|
// 결과 csv 파일 정보 등록
|
||||||
modelTestMetricsJobService.findTestValidMetricCsvFiles();
|
modelTestMetricsJobService.findTestValidMetricCsvFiles();
|
||||||
} else {
|
} else {
|
||||||
modelTrainMngCoreService.markStep1Success(modelId);
|
modelTrainMngCoreService.markStep1Success(modelId);
|
||||||
|
// 결과 csv 파일 정보 등록
|
||||||
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
|
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
String failMsg = result.getStatus() + "\n" + result.getLogs();
|
String failMsg = result.getStatus() + "\n" + result.getLogs();
|
||||||
|
// 실패 처리
|
||||||
modelTrainJobCoreService.markFailed(
|
modelTrainJobCoreService.markFailed(
|
||||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||||
|
|
||||||
if (isEval) {
|
if (isEval) {
|
||||||
|
// 오류 정보 등록
|
||||||
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
||||||
} else {
|
} else {
|
||||||
|
// 오류 정보 등록
|
||||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -125,8 +142,10 @@ public class TrainJobWorker {
|
|||||||
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
|
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
|
||||||
|
|
||||||
if ("EVAL".equals(params.get("jobType"))) {
|
if ("EVAL".equals(params.get("jobType"))) {
|
||||||
|
// 오류 정보 등록
|
||||||
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
||||||
} else {
|
} else {
|
||||||
|
// 오류 정보 등록
|
||||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user