44 Commits

Author SHA1 Message Date
6a2deff93b Merge pull request '모델학습관리 > 목록 API 메모,작성자 추가로 인한 수정' (#118) from feat/training_260202 into develop
Reviewed-on: #118
2026-02-19 15:35:03 +09:00
b0a99afcd3 Merge pull request '모델학습 2단계 패키징 시작,종료일시,상태 로직 추가' (#117) from feat/training_260202 into develop
Reviewed-on: #117
2026-02-19 14:44:12 +09:00
eedf72d7aa Merge pull request '공통코드 common-code 로 prefix 변경' (#116) from feat/training_260202 into develop
Reviewed-on: #116
2026-02-19 11:39:23 +09:00
25e9941464 Merge pull request '로그관리 로직 커밋' (#115) from feat/training_260202 into develop
Reviewed-on: #115
2026-02-19 11:16:40 +09:00
a0da0392cf Merge pull request '압축해제 시, 동일 폴더가 있으면 삭제 후 재업로드' (#114) from feat/training_260202 into develop
Reviewed-on: #114
2026-02-18 16:37:13 +09:00
a3ebee12b5 Merge pull request 'feat/training_260202' (#113) from feat/training_260202 into develop
Reviewed-on: #113
2026-02-18 16:29:55 +09:00
c5b14ca09d Merge pull request '업로드 시 uid로 중복체크 -> 삭제인 row는 제외하기' (#112) from feat/training_260202 into develop
Reviewed-on: #112
2026-02-18 15:40:38 +09:00
44b3b857b1 Merge pull request 'feat/training_260202' (#111) from feat/training_260202 into develop
Reviewed-on: #111
2026-02-18 15:29:25 +09:00
63124455fd Merge pull request '1단계 실행 시, 시작시간 update 추가' (#110) from feat/training_260202 into develop
Reviewed-on: #110
2026-02-18 13:06:33 +09:00
e75ea8d8a5 Merge pull request '하이퍼 파라미터 수정' (#109) from feat/training_260202 into develop
Reviewed-on: #109
2026-02-13 15:00:50 +09:00
31ac4209c3 Merge pull request '하이퍼 파라미터 수정' (#108) from feat/training_260202 into develop
Reviewed-on: #108
2026-02-13 14:47:19 +09:00
df09935789 Merge pull request '하이퍼 파라미터 수정' (#107) from feat/training_260202 into develop
Reviewed-on: #107
2026-02-13 14:42:54 +09:00
bb15b1b0f2 Merge pull request '하이퍼 파라미터 수정' (#106) from feat/training_260202 into develop
Reviewed-on: #106
2026-02-13 14:38:01 +09:00
f4d491ed94 Merge pull request '하이퍼 파라미터 수정' (#105) from feat/training_260202 into develop
Reviewed-on: #105
2026-02-13 14:31:13 +09:00
96cb7d2f23 Merge pull request '사용가능 용량 API 수정' (#104) from feat/training_260202 into develop
Reviewed-on: #104
2026-02-13 14:19:26 +09:00
cc6305b0df Merge pull request '이어하기 수정' (#103) from feat/training_260202 into develop
Reviewed-on: #103
2026-02-13 14:08:51 +09:00
3916b13876 Merge pull request '이어하기 수정' (#102) from feat/training_260202 into develop
Reviewed-on: #102
2026-02-13 14:04:52 +09:00
ee4a06df30 Merge pull request '이어하기 수정' (#101) from feat/training_260202 into develop
Reviewed-on: #101
2026-02-13 13:57:48 +09:00
bb5ff7c3cd Merge pull request '이어하기 로그 수정' (#100) from feat/training_260202 into develop
Reviewed-on: #100
2026-02-13 13:27:08 +09:00
312a96dda1 Merge pull request '트랜젝션처리 임시폴더 uid업데이트' (#99) from feat/training_260202 into develop
Reviewed-on: #99
2026-02-13 13:24:10 +09:00
e38231e06d Merge pull request '트랜젝션처리 임시폴더 uid업데이트' (#98) from feat/training_260202 into develop
Reviewed-on: #98
2026-02-13 13:17:41 +09:00
bf6e45d706 Merge pull request 'feat/training_260202' (#97) from feat/training_260202 into develop
Reviewed-on: #97
2026-02-13 12:53:32 +09:00
7fa8921a25 Merge pull request 'flush 추가해보기' (#96) from feat/training_260202 into develop
Reviewed-on: #96
2026-02-13 12:47:11 +09:00
7c940351d9 Merge pull request '트랜젝션처리 임시폴더 uid업데이트' (#95) from feat/training_260202 into develop
Reviewed-on: #95
2026-02-13 12:39:21 +09:00
6b834da912 Merge pull request 'feat/training_260202' (#94) from feat/training_260202 into develop
Reviewed-on: #94
2026-02-13 12:25:26 +09:00
25aaa97d65 Merge pull request '트랜젝션처리 임시폴더 uid업데이트' (#93) from feat/training_260202 into develop
Reviewed-on: #93
2026-02-13 12:21:11 +09:00
da9d47ae4a Merge pull request '주석 처리' (#92) from feat/training_260202 into develop
Reviewed-on: #92
2026-02-13 12:10:46 +09:00
7d6a77bf2a Merge pull request '트랜젝션처리 임시폴더 uid업데이트' (#91) from feat/training_260202 into develop
Reviewed-on: #91
2026-02-13 12:01:01 +09:00
26828d0968 add log 2026-02-13 12:00:54 +09:00
e2dbae15c0 Merge pull request '트랜젝션처리 임시폴더 uid업데이트' (#90) from feat/training_260202 into develop
Reviewed-on: #90
2026-02-13 11:58:55 +09:00
b246034632 Merge pull request '트랜젝션처리 임시폴더 uid업데이트' (#89) from feat/training_260202 into develop
Reviewed-on: #89
2026-02-13 11:58:25 +09:00
687ea82d78 Merge pull request 'feat/training_260202' (#88) from feat/training_260202 into develop
Reviewed-on: #88
2026-02-13 10:51:07 +09:00
4ac0f19908 Merge pull request '파일 count 기능 추가' (#87) from feat/training_260202 into develop
Reviewed-on: #87
2026-02-13 10:38:48 +09:00
9e5e7595eb Merge pull request '학습실행 step1 할 때 best epoch 업데이트' (#86) from feat/training_260202 into develop
Reviewed-on: #86
2026-02-13 10:18:26 +09:00
9cd9274e99 Merge pull request '학습데이터 목록 파일 단위 MB 나오게 하기' (#85) from feat/training_260202 into develop
Reviewed-on: #85
2026-02-13 09:48:08 +09:00
5d82f3ecfe Merge pull request 'tmp 파일 링크 수정' (#84) from feat/training_260202 into develop
Reviewed-on: #84
2026-02-13 09:11:05 +09:00
2ce249ab33 Merge pull request 'tmp 파일 링크 수정' (#83) from feat/training_260202 into develop
Reviewed-on: #83
2026-02-13 08:44:37 +09:00
e34bf68de0 Merge pull request 'tmp 파일 링크 수정' (#82) from feat/training_260202 into develop
Reviewed-on: #82
2026-02-13 08:33:58 +09:00
862bda0cb9 Merge pull request '이어하기 수정' (#81) from feat/training_260202 into develop
Reviewed-on: #81
2026-02-12 23:02:12 +09:00
90f7b17d07 Merge pull request '학습데이터 다운로드 파일 정보 API 추가' (#80) from feat/training_260202 into develop
Reviewed-on: #80
2026-02-12 22:58:47 +09:00
2128baa46a Merge pull request 'feat/training_260202' (#79) from feat/training_260202 into develop
Reviewed-on: #79
2026-02-12 22:26:26 +09:00
875c30f467 Merge pull request 'feat/training_260202' (#78) from feat/training_260202 into develop
Reviewed-on: #78
2026-02-12 21:52:16 +09:00
2b29cd1ac6 Merge pull request '파라미터 변경' (#77) from feat/training_260202 into develop
Reviewed-on: #77
2026-02-12 21:30:18 +09:00
9206fff5d0 Merge pull request 'feat/training_260202' (#76) from feat/training_260202 into develop
Reviewed-on: #76
2026-02-12 21:17:33 +09:00
54 changed files with 247 additions and 2028 deletions

View File

@@ -2,8 +2,10 @@ package com.kamco.cd.training;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
@EnableAsync
@SpringBootApplication
@EnableScheduling
public class KamcoTrainingApplication {

View File

@@ -23,8 +23,7 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
private final UserDetailsService userDetailsService;
private static final AntPathMatcher PATH_MATCHER = new AntPathMatcher();
private static final String[] EXCLUDE_PATHS = {
// "/api/auth/signin", "/api/auth/refresh", "/api/auth/logout", "/api/members/*/password"
"/api/auth/signin", "/api/auth/refresh", "/api/auth/logout"
"/api/auth/signin", "/api/auth/refresh", "/api/auth/logout", "/api/members/*/password"
};
@Override

View File

@@ -12,14 +12,13 @@ import lombok.Setter;
@AllArgsConstructor
@NoArgsConstructor
public class HyperParam {
@Schema(description = "모델", example = "G1")
private ModelType model; // G1, G2, G3
// -------------------------
// Important
// -------------------------
@Schema(description = "모델", example = "large")
private ModelType model; // backbone
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone

View File

@@ -1,27 +0,0 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
// Lifecycle status of a background job, exposed through the shared EnumType
// contract: getId() returns the constant name, getText() the Korean display label.
public enum JobStatusType implements EnumType {
QUEUED("대기중"), // waiting to run
RUNNING("실행중"), // currently executing
SUCCESS("성공"), // finished successfully
FAILED("실패"), // finished with an error
CANCELED("취소"); // cancelled before completion
private final String desc; // Korean display label for this status
@Override
public String getId() {
// Stable identifier: the enum constant's own name.
return name();
}
@Override
public String getText() {
// Human-readable (Korean) label for UI display.
return desc;
}
}

View File

@@ -1,24 +0,0 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
// Kind of background job, exposed through the shared EnumType contract:
// getId() returns the constant name, getText() the Korean display label.
public enum JobType implements EnumType {
TRAIN("학습"), // training job
TEST("테스트"); // test/evaluation job
private final String desc; // Korean display label for this job type
@Override
public String getId() {
// Stable identifier: the enum constant's own name.
return name();
}
@Override
public String getText() {
// Human-readable (Korean) label for UI display.
return desc;
}
}

View File

@@ -3,6 +3,7 @@ package com.kamco.cd.training.common.utils;
import static java.lang.String.CASE_INSENSITIVE_ORDER;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import com.kamco.cd.training.common.exception.CustomApiException;
@@ -719,26 +720,18 @@ public class FIleChecker {
public static void unzip(String fileName, String destDirectory) throws IOException {
String zipFilePath = destDirectory + File.separator + fileName;
log.info("fileName : {}", fileName);
log.info("destDirectory : {}", destDirectory);
log.info("zipFilePath : {}", zipFilePath);
// zip 이름으로 폴더 생성 (확장자 제거)
String folderName =
fileName.endsWith(".zip") ? fileName.substring(0, fileName.length() - 4) : fileName;
log.info("folderName : {}", folderName);
File destDir = new File(destDirectory, folderName);
log.info("destDir : {}", destDir);
// 동일 폴더가 이미 있으면 삭제
log.info("111 destDir.exists() : {}", destDir.exists());
if (destDir.exists()) {
deleteDirectoryRecursively(destDir.toPath());
}
log.info("222 destDir.exists() : {}", destDir.exists());
if (!destDir.exists()) {
log.info("mkdirs : {}", destDir.exists());
destDir.mkdirs();
}
@@ -794,6 +787,92 @@ public class FIleChecker {
return destFile;
}
public static void uploadTo86(Path localFile) {
    // Uploads a single local file via SFTP to the fixed ingest directory on the
    // 192.168.2.86 server; the remote file keeps the local file's name.
    //
    // SECURITY(review): host/credentials were hard-coded in source. They now fall
    // back to the original values for backward compatibility but should be supplied
    // via the environment variables below (or a secrets store) and the literal
    // defaults removed. Disabling strict host key checking leaves the transfer open
    // to MITM — TODO: provision a known_hosts entry instead.
    String host = System.getenv().getOrDefault("KAMCO_SFTP_HOST", "192.168.2.86");
    int port = Integer.parseInt(System.getenv().getOrDefault("KAMCO_SFTP_PORT", "22"));
    String username = System.getenv().getOrDefault("KAMCO_SFTP_USER", "kcomu");
    String password = System.getenv().getOrDefault("KAMCO_SFTP_PASSWORD", "Kamco2025!");
    String remoteDir = System.getenv().getOrDefault("KAMCO_SFTP_REMOTE_DIR", "/home/kcomu/data/request");
    Session session = null;
    ChannelSftp channel = null;
    try {
        JSch jsch = new JSch();
        session = jsch.getSession(username, host, port);
        session.setPassword(password);
        Properties config = new Properties();
        config.put("StrictHostKeyChecking", "no"); // see SECURITY note above
        session.setConfig(config);
        session.connect(10_000); // 10s connect timeout
        channel = (ChannelSftp) session.openChannel("sftp");
        channel.connect(10_000);
        // Change to the destination directory, then upload under the original name.
        channel.cd(remoteDir);
        channel.put(localFile.toString(), localFile.getFileName().toString());
    } catch (Exception e) {
        // Preserve the cause so callers can see the underlying JSch/IO failure.
        throw new RuntimeException("SFTP upload failed", e);
    } finally {
        // Always tear down the channel and session, even on failure.
        if (channel != null) channel.disconnect();
        if (session != null) session.disconnect();
    }
}
public static void unzipOn86Server(String zipPath, String targetDir) {
    // Runs `unzip -o <zipPath> -d <targetDir>` on the 192.168.2.86 server over SSH.
    //
    // SECURITY(review): host/credentials were hard-coded; the env-var fallbacks below
    // keep backward compatibility but the literal defaults should be removed.
    // StrictHostKeyChecking=no is MITM-prone — TODO: provision known_hosts.
    String host = System.getenv().getOrDefault("KAMCO_SFTP_HOST", "192.168.2.86");
    String user = System.getenv().getOrDefault("KAMCO_SFTP_USER", "kcomu");
    String password = System.getenv().getOrDefault("KAMCO_SFTP_PASSWORD", "Kamco2025!");
    Session session = null;
    ChannelExec channel = null;
    try {
        JSch jsch = new JSch();
        session = jsch.getSession(user, host, 22);
        session.setPassword(password);
        Properties config = new Properties();
        config.put("StrictHostKeyChecking", "no");
        session.setConfig(config);
        session.connect(10_000);
        // Single-quote both paths so spaces and shell metacharacters cannot be
        // interpreted by the remote shell (fixes command injection through
        // zipPath/targetDir); embedded single quotes use the '"'"' idiom.
        String quotedZip = "'" + zipPath.replace("'", "'\"'\"'") + "'";
        String quotedDir = "'" + targetDir.replace("'", "'\"'\"'") + "'";
        String command = "unzip -o " + quotedZip + " -d " + quotedDir;
        channel = (ChannelExec) session.openChannel("exec");
        channel.setCommand(command);
        channel.setErrStream(System.err);
        InputStream in = channel.getInputStream();
        channel.connect();
        // Drain stdout so the remote command can finish and the channel can close.
        try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
            while (br.readLine() != null) {
                // output intentionally discarded; add logging here if needed
            }
        }
        // Wait for the channel to close so the exit status is valid, then fail loudly
        // on a non-zero exit — previously a failed unzip was silently ignored.
        while (!channel.isClosed()) {
            Thread.sleep(50);
        }
        int exit = channel.getExitStatus();
        if (exit != 0) {
            throw new RuntimeException("remote unzip failed with exit status " + exit);
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    } finally {
        // Always tear down the channel and session, even on failure.
        if (channel != null) channel.disconnect();
        if (session != null) session.disconnect();
    }
}
public static List<String> execCommandAndReadLines(String command) {
List<String> result = new ArrayList<>();

View File

@@ -1,23 +0,0 @@
package com.kamco.cd.training.config;
import java.util.concurrent.Executor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
@Configuration
@EnableAsync
public class AsyncConfig {
    /**
     * Thread pool that runs training jobs asynchronously; injected by the bean
     * name {@code trainJobExecutor}.
     */
    @Bean(name = "trainJobExecutor")
    public Executor trainJobExecutor() {
        ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
        pool.setThreadNamePrefix("train-job-");
        pool.setCorePoolSize(4); // up to 4 jobs run concurrently
        pool.setMaxPoolSize(8); // burst capacity of 8 threads
        pool.setQueueCapacity(200); // additional jobs wait in this queue
        pool.initialize();
        return pool;
    }
}

View File

@@ -76,13 +76,13 @@ public class SecurityConfig {
"/api/auth/logout",
"/swagger-ui/**",
"/v3/api-docs/**",
"/api/members/*/password",
"/api/upload/chunk-upload-dataset",
"/api/upload/chunk-upload-complete",
"/download_progress_test.html",
"/api/models/download/**")
.permitAll()
.requestMatchers("/api/members/*/password")
.authenticated()
// default
.anyRequest()
.authenticated())

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.dataset.dto;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.kamco.cd.training.common.enums.LearnDataRegister;
import com.kamco.cd.training.common.enums.LearnDataType;
import com.kamco.cd.training.common.enums.ModelType;
@@ -233,6 +234,7 @@ public class DatasetDto {
@Getter
@Setter
@NoArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class SelectDataSet {
private String modelNo; // G1, G2, G3 모델 타입
@@ -315,183 +317,6 @@ public class DatasetDto {
}
}
@Schema(name = "SelectTransferDataSet", description = "전이학습 데이터셋 선택 리스트")
@Getter
@Setter
@NoArgsConstructor
// Row DTO for the transfer-learning dataset selection list: each row pairs the
// currently selected dataset with the dataset of the previous ("before") model
// that training is transferred from.
public static class SelectTransferDataSet {
// ----- current dataset -----
private String modelNo; // model type code: G1, G2, G3
private Long datasetId;
private UUID uuid;
private String dataType;
private String title;
private Long roundNo;
private Integer compareYyyy;
private Integer targetYyyy;
private String memo;
@JsonIgnore private Long classCount; // raw count; surfaced via wasteCnt/landCoverCnt below
private Integer buildingCnt;
private Integer containerCnt;
private String dataTypeName; // display name resolved from dataType via getDataTypeName()
private Long wasteCnt; // set from classCount when modelNo matches ModelType.G2
private Long landCoverCnt; // set from classCount when modelNo matches ModelType.G3
// ----- previous ("before") dataset -----
private String beforeModelNo; // model type code: G1, G2, G3
private Long beforeDatasetId;
private UUID beforeUuid;
private String beforeDataType;
private String beforeTitle;
private Long beforeRoundNo;
private Integer beforeCompareYyyy;
private Integer beforeTargetYyyy;
private String beforeMemo;
@JsonIgnore private Long beforeClassCount; // raw count; surfaced via beforeWasteCnt/beforeLandCoverCnt
private Integer beforeBuildingCnt;
private Integer beforeContainerCnt;
private String beforeDataTypeName; // display name resolved from beforeDataType
private Long beforeWasteCnt; // set from beforeClassCount when beforeModelNo matches ModelType.G2
private Long beforeLandCoverCnt; // set from beforeClassCount when beforeModelNo matches ModelType.G3
// Constructor variant for class-count-based models: maps classCount into
// wasteCnt (G2) or landCoverCnt (G3) for both current and previous datasets.
public SelectTransferDataSet(
// current
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Long classCount,
// previous (before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Long beforeClassCount) {
// current
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.classCount = classCount;
if (modelNo != null && modelNo.equals(ModelType.G2.getId())) {
this.wasteCnt = classCount;
} else if (modelNo != null && modelNo.equals(ModelType.G3.getId())) {
this.landCoverCnt = classCount;
}
// previous (before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeClassCount = beforeClassCount;
if (beforeModelNo != null && beforeModelNo.equals(ModelType.G2.getId())) {
this.beforeWasteCnt = beforeClassCount;
} else if (beforeModelNo != null && beforeModelNo.equals(ModelType.G3.getId())) {
this.beforeLandCoverCnt = beforeClassCount;
}
}
// Constructor variant for building/container-count models: takes explicit
// buildingCnt/containerCnt instead of a single classCount.
public SelectTransferDataSet(
// current
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Integer buildingCnt,
Integer containerCnt,
// previous (before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Integer beforeBuildingCnt,
Integer beforeContainerCnt) {
// current
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.buildingCnt = buildingCnt;
this.containerCnt = containerCnt;
// previous (before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeBuildingCnt = beforeBuildingCnt;
this.beforeContainerCnt = beforeContainerCnt;
}
// Resolves a LearnDataType code to its display text; null when the code is unknown.
public String getDataTypeName(String groupTitleCd) {
LearnDataType type = Enums.fromId(LearnDataType.class, groupTitleCd);
return type == null ? null : type.getText();
}
// "compare-target" year range of the current dataset, e.g. "2023-2024".
public String getYear() {
return this.compareYyyy + "-" + this.targetYyyy;
}
// "compare-target" year range of the previous dataset; null when either year is missing.
public String getBeforeYear() {
if (this.beforeCompareYyyy == null || this.beforeTargetYyyy == null) {
return null;
}
return this.beforeCompareYyyy + "-" + this.beforeTargetYyyy;
}
}
@Getter
@Setter
@NoArgsConstructor

View File

@@ -417,7 +417,6 @@ public class DatasetService {
}
private String escape(String path) {
// Wrap the path in single quotes so it can be used safely in a shell command,
// escaping any embedded single quotes with the '"'"' idiom.
return "'" + path.replace("'", "'\"'\"'") + "'";
}

View File

@@ -97,9 +97,9 @@ public class HyperParamApiController {
String type,
@Parameter(description = "시작일", example = "2026-02-01") @RequestParam(required = false)
LocalDate startDate,
@Parameter(description = "종료일", example = "2026-03-31") @RequestParam(required = false)
@Parameter(description = "종료일", example = "2026-02-28") @RequestParam(required = false)
LocalDate endDate,
@Parameter(description = "버전명", example = "G1_000019") @RequestParam(required = false)
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
String hyperVer,
@Parameter(description = "모델 타입 (G1, G2, G3 중 하나)", example = "G1")
@RequestParam(required = false)
@@ -142,7 +142,7 @@ public class HyperParamApiController {
})
@DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteHyperParam(
@Parameter(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
@Parameter(description = "하이퍼파라미터 uuid", example = "c3b5a285-8f68-42af-84f0-e6d09162deb5")
@PathVariable
UUID uuid) {
hyperParamService.deleteHyperParam(uuid);
@@ -164,7 +164,7 @@ public class HyperParamApiController {
})
@GetMapping("/{uuid}")
public ApiResponseDto<HyperParamDto.Basic> getHyperParam(
@Parameter(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
@Parameter(description = "하이퍼파라미터 uuid", example = "c3b5a285-8f68-42af-84f0-e6d09162deb5")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(hyperParamService.getHyperParam(uuid));

View File

@@ -29,8 +29,6 @@ public class HyperParamDto {
private UUID uuid;
private String hyperVer;
@JsonFormatDttm private ZonedDateTime createdDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Integer totalUseCnt;
// -------------------------
// Important
@@ -117,7 +115,10 @@ public class HyperParamDto {
@JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private String memo;
private Integer totalUseCnt;
private Long m1UseCnt;
private Long m2UseCnt;
private Long m3UseCnt;
private Long totalCnt;
}
@Getter

View File

@@ -1,6 +1,5 @@
package com.kamco.cd.training.log.dto;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDate;
@@ -78,7 +77,6 @@ public class ErrorLogDto {
}
}
@CodeExpose
public enum LogErrorLevel implements EnumType {
WARNING("Warning"),
ERROR("Error"),

View File

@@ -11,7 +11,6 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.model.service.ModelTrainDetailService;
import com.kamco.cd.training.model.service.ModelTrainMngService;
import io.swagger.v3.oas.annotations.Operation;
@@ -134,26 +133,26 @@ public class ModelTrainDetailApiController {
return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid));
}
// @Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
// @ApiResponses(
// value = {
// @ApiResponse(
// responseCode = "200",
// description = "조회 성공",
// content =
// @Content(
// mediaType = "application/json",
// schema = @Schema(implementation = TransferDetailDto.class))),
// @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
// @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
// })
// @GetMapping("/transfer/detail/{uuid}")
// public ApiResponseDto<TransferDetailDto> getTransferDetail(
// @Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
// @PathVariable
// UUID uuid) {
// return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
// }
@Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/transfer/detail/{uuid}")
// Model-selection details for the transfer-learning run configuration of the
// model identified by {uuid}; thin delegate to ModelTrainDetailService.
public ApiResponseDto<TransferDetailDto> getTransferDetail(
@Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
}
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API")
@ApiResponses(
@@ -305,25 +304,4 @@ public class ModelTrainDetailApiController {
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainFileInfo(uuid));
}
@Operation(summary = "모델관리 > 모델별 진행 상황", description = "모델관리 > 모델별 진행 상황 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelProgressStepDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/progress/{uuid}")
// Per-step progress timeline for the model identified by {uuid}; thin delegate
// to ModelTrainDetailService.
public ApiResponseDto<List<ModelProgressStepDto>> findModelTrainProgressInfo(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.findModelTrainProgressInfo(uuid));
}
}

View File

@@ -155,9 +155,7 @@ public class ModelTrainMngApiController {
return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req));
}
@Operation(
summary = "모델학습 1단계/2단계 실행중인 것이 있는지 count",
description = "모델학습 1단계/2단계 실행중인 것이 있는지 count")
@Operation(summary = "모델학습 1단계 실행중인 것이 있는지 count", description = "모델학습 1단계 실행중인 것이 있는지 count")
@ApiResponses(
value = {
@ApiResponse(

View File

@@ -20,25 +20,4 @@ public class ModelConfigDto {
private Float testPercent;
private String memo;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
// Run configuration for transfer learning: the current model's config values
// paired with the corresponding values of the previous ("before") model that
// training is transferred from.
public static class TransferBasic {
// current model's config
private Long configId;
private Long modelId;
private Integer epochCount;
private Float trainPercent;
private Float validationPercent;
private Float testPercent;
private String memo;
// previous ("before") model's config
private Long beforeConfigId;
private Long beforeModelId;
private Integer beforeEpochCount;
private Float beforeTrainPercent;
private Float beforeValidationPercent;
private Float beforeTestPercent;
private String beforeMemo;
}
}

View File

@@ -6,7 +6,7 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.utils.enums.Enums;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.Duration;
import java.time.ZonedDateTime;
@@ -35,7 +35,6 @@ public class ModelTrainDetailDto {
@JsonFormatDttm private ZonedDateTime step2EndDttm;
private String statusCd;
private String trainType;
private UUID beforeUuid;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
@@ -177,10 +176,9 @@ public class ModelTrainDetailDto {
@NoArgsConstructor
@AllArgsConstructor
public static class TransferDetailDto {
private ModelConfigDto.TransferBasic etcConfig;
private ModelConfigDto.Basic etcConfig;
private TransferHyperSummary modelTrainHyper;
private List<SelectTransferDataSet> modelTrainDataset;
// private List<SelectDataSet> beforeTrainDataset;
private List<SelectDataSet> modelTrainDataset;
}
@Getter

View File

@@ -47,8 +47,6 @@ public class ModelTrainMngDto {
private ZonedDateTime packingStrtDttm;
private ZonedDateTime packingEndDttm;
private Long beforeModelId;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
try {
@@ -251,7 +249,6 @@ public class ModelTrainMngDto {
private String memo;
private String userNm;
private UUID beforeUuid;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
@@ -315,16 +312,4 @@ public class ModelTrainMngDto {
return formatDuration(this.packingStrtDttm, this.packingEndDttm);
}
}
@Getter
@Builder
@AllArgsConstructor
// One entry in a model's progress timeline: a step number, its status string,
// start/end timestamps, and an error flag.
public static class ModelProgressStepDto {
private int step; // step number — presumably 1-based; confirm against producer
private String status;
@JsonFormatDttm private ZonedDateTime startTime;
@JsonFormatDttm private ZonedDateTime endTime;
private boolean isError; // error flag for this step
}
}

View File

@@ -1,8 +1,7 @@
package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
@@ -15,7 +14,6 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetric
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.util.ArrayList;
@@ -73,11 +71,11 @@ public class ModelTrainDetailService {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
// config 정보 조회
ModelConfigDto.TransferBasic configInfo = mngCoreService.findModelTransferConfigByModelId(uuid);
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
// 하이파라미터 정보 조회
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectTransferDataSet> dataSets = new ArrayList<>();
List<SelectDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();
List<Long> datasetIds = new ArrayList<>();
@@ -90,37 +88,12 @@ public class ModelTrainDetailService {
datasetReq.setIds(datasetIds);
datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals(ModelType.G1.getId())) {
dataSets = mngCoreService.getDatasetTransferSelectG1List(modelInfo.getId());
if (modelInfo.getModelNo().equals("G1")) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
} else {
dataSets =
mngCoreService.getDatasetTransferSelectG2G3List(
modelInfo.getId(), modelInfo.getModelNo());
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq);
}
// DatasetReq beforeDatasetReq = new DatasetReq();
// List<Long> beforeDatasetIds = new ArrayList<>();
// List<SelectDataSet> beforeDataSets = new ArrayList<>();
//
// Long beforeModelId = modelInfo.getBeforeModelId();
// if (beforeModelId != null) {
// Basic beforeInfo = modelTrainDetailCoreService.findByModelBeforeId(beforeModelId);
// List<MappingDataset> beforeDatasets =
// modelTrainDetailCoreService.getByModelMappingDataset(beforeInfo.getUuid());
//
// for (MappingDataset before : beforeDatasets) {
// beforeDatasetIds.add(before.getDatasetId());
// }
// beforeDatasetReq.setIds(beforeDatasetIds);
// beforeDatasetReq.setModelNo(modelInfo.getModelNo());
//
// if (beforeInfo.getModelNo().equals(ModelType.G1.getId())) {
// beforeDataSets = mngCoreService.getDatasetSelectG1List(beforeDatasetReq);
// } else {
// beforeDataSets = mngCoreService.getDatasetSelectG2G3List(beforeDatasetReq);
// }
// }
TransferDetailDto transferDetailDto = new TransferDetailDto();
transferDetailDto.setEtcConfig(configInfo);
transferDetailDto.setModelTrainHyper(hyperSummary);
@@ -148,8 +121,4 @@ public class ModelTrainDetailService {
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
return modelTrainDetailCoreService.getModelTrainFileInfo(uuid);
}
// Per-step progress timeline for the model identified by uuid; thin delegate
// to the core service.
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
return modelTrainDetailCoreService.findModelTrainProgressInfo(uuid);
}
}

View File

@@ -2,7 +2,6 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
@@ -94,7 +93,7 @@ public class ModelTrainMngService {
// 모델 config 저장
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
// 데이터셋 임시파일 생성
// 임시파일 생성
trainJobService.createTmpFile(modelUuid);
return modelUuid;
}
@@ -116,7 +115,7 @@ public class ModelTrainMngService {
* @return
*/
public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (req.getModelNo().equals(ModelType.G1.getId())) {
if (req.getModelNo().equals("G1")) {
return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else {
return modelTrainMngCoreService.getDatasetSelectG2G3List(req);

View File

@@ -34,6 +34,7 @@ public class HyperParamCoreService {
ModelHyperParamEntity entity = new ModelHyperParamEntity();
entity.setHyperVer(firstVersion);
applyHyperParam(entity, createReq);
// user
@@ -103,7 +104,7 @@ public class HyperParamCoreService {
*/
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
ModelHyperParamEntity entity =
hyperParamRepository.getHyperParamByType(model).stream()
hyperParamRepository.getHyperparamByType(model).stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst()
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
@@ -171,7 +172,7 @@ public class HyperParamCoreService {
} else {
entity.setCropSize("256,256");
}
entity.setCropSize(src.getCropSize());
// entity.setCropSize(src.getCropSize());
// Important
entity.setModelType(model); // 20250212 modeltype추가

View File

@@ -38,12 +38,10 @@ public class ModelTestMetricsJobCoreService {
return modelTestMetricsJobRepository.findModelTestFileNames(modelId);
}
@Transactional
public void updatePackingStart(Long modelId, ZonedDateTime now) {
modelTestMetricsJobRepository.updatePackingStart(modelId, now);
}
@Transactional
public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) {
modelTestMetricsJobRepository.updatePackingEnd(modelId, now, failSuccState);
}

View File

@@ -14,7 +14,6 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDetailRepository;
@@ -57,12 +56,6 @@ public class ModelTrainDetailCoreService {
return modelDetailRepository.getModelDetailSummary(uuid);
}
/**
* 하이퍼 파리미터 요약정보
*
* @param uuid 모델마스터 uuid
* @return
*/
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
return modelDetailRepository.getByModelHyperParamSummary(uuid);
}
@@ -109,13 +102,4 @@ public class ModelTrainDetailCoreService {
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
return modelDetailRepository.getModelTrainFileInfo(uuid);
}
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
return modelDetailRepository.findModelTrainProgressInfo(uuid);
}
public Basic findByModelBeforeId(Long beforeModelId) {
ModelMasterEntity entity = modelDetailRepository.findByModelBeforeId(beforeModelId);
return entity.toDto();
}
}

View File

@@ -1,18 +1,14 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -56,16 +52,11 @@ public class ModelTrainJobCoreService {
/** 실행 시작 처리 */
@Transactional
public void markRunning(
Long jobId,
String containerName,
String logPath,
String lockedBy,
Integer totalEpoch,
String jobType) {
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("RUNNING");
job.setContainerName(containerName);
@@ -73,44 +64,32 @@ public class ModelTrainJobCoreService {
job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy);
job.setJobType(jobType);
if (totalEpoch != null) {
job.setTotalEpoch(totalEpoch);
}
}
/**
* 성공 처리
*
* @param jobId
* @param exitCode
*/
/** 성공 처리 */
@Transactional
public void markSuccess(Long jobId, int exitCode) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("SUCCESS");
job.setExitCode(exitCode);
job.setFinishedDttm(ZonedDateTime.now());
}
/**
* 실패 처리
*
* @param jobId
* @param exitCode
* @param errorMessage
*/
/** 실패 처리 */
@Transactional
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("FAILED");
job.setExitCode(exitCode);
@@ -120,35 +99,13 @@ public class ModelTrainJobCoreService {
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
}
/**
* 중단됨 처리
*
* @param jobId
* @param exitCode
* @param errorMessage
*/
@Transactional
public void markPaused(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("STOPPED");
job.setExitCode(exitCode);
job.setErrorMessage(errorMessage);
job.setFinishedDttm(ZonedDateTime.now());
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
}
/** 취소 처리 */
@Transactional
public void markCanceled(Long jobId) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now());
@@ -159,29 +116,10 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job =
modelTrainJobRepository
.findByContainerName(containerName)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
job.setCurrentEpoch(epoch);
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
}
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch);
}
/**
* 실행중인 학습이 있는지 조회
*
* @return
*/
public List<ModelTrainJobDto> findRunningJobs() {
List<ModelTrainJobEntity> entity = modelTrainJobRepository.findRunningJobs();
if (entity == null || entity.isEmpty()) {
return Collections.emptyList();
}
return entity.stream().map(ModelTrainJobEntity::toDto).toList();
}
}

View File

@@ -8,7 +8,6 @@ import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
@@ -24,7 +23,6 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.ArrayList;
@@ -90,7 +88,7 @@ public class ModelTrainMngCoreService {
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity =
hyperParamRepository.getHyperParamByType(modelType).stream()
hyperParamRepository.getHyperparamByType(modelType).stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst()
.orElse(null);
@@ -104,12 +102,6 @@ public class ModelTrainMngCoreService {
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND);
}
// 하이퍼 파라미터 사용 횟수 업데이트
hyperParamEntity.setTotalUseCnt(
hyperParamEntity.getTotalUseCnt() == null ? 1 : hyperParamEntity.getTotalUseCnt() + 1);
// 최근 사용일시 업데이트
hyperParamEntity.setLastUsedDttm(ZonedDateTime.now());
String modelVer =
String.join(
@@ -278,13 +270,6 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
public ModelConfigDto.TransferBasic findModelTransferConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository
.findModelTransferConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
/**
* 데이터셋 G1 목록
*
@@ -295,16 +280,6 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG1List(req);
}
/**
* 전이학습 데이터셋 G1 목록
*
* @param modelId 모델 Id
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
return datasetRepository.getDatasetTransferSelectG1List(modelId);
}
/**
* 데이터셋 G2, G3 목록
*
@@ -315,18 +290,6 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG2G3List(req);
}
/**
* 전이학습 데이터셋 G2, G3 목록
*
* @param modelId 모델 Id
* @param modelNo G2, G3
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
return datasetRepository.getDatasetTransferSelectG2G3List(modelId, modelNo);
}
/**
* 모델관리 조회
*
@@ -390,12 +353,7 @@ public class ModelTrainMngCoreService {
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/**
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
@@ -410,12 +368,7 @@ public class ModelTrainMngCoreService {
master.setUpdatedDttm(ZonedDateTime.now());
}
/**
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master =
@@ -566,34 +519,4 @@ public class ModelTrainMngCoreService {
public Long findModelStep1InProgressCnt() {
return modelMngRepository.findModelStep1InProgressCnt();
}
/**
* train 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
return modelDatasetMapRepository.findDatasetTrainPath(modelId);
}
/**
* validation 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
return modelDatasetMapRepository.findDatasetValPath(modelId);
}
/**
* test 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
return modelDatasetMapRepository.findDatasetTestPath(modelId);
}
}

View File

@@ -192,10 +192,10 @@ public class ModelHyperParamEntity {
@Column(name = "save_best_rule", nullable = false, length = 10)
private String saveBestRule = "greater";
/** Default: 1 */
/** Default: 10 */
@NotNull
@Column(name = "val_interval", nullable = false)
private Integer valInterval = 1;
private Integer valInterval = 10;
/** Default: 400 */
@NotNull
@@ -303,6 +303,15 @@ public class ModelHyperParamEntity {
@Column(name = "last_used_dttm")
private ZonedDateTime lastUsedDttm;
@Column(name = "m1_use_cnt")
private Long m1UseCnt = 0L;
@Column(name = "m2_use_cnt")
private Long m2UseCnt = 0L;
@Column(name = "m3_use_cnt")
private Long m3UseCnt = 0L;
@Column(name = "model_type")
@Enumerated(EnumType.STRING)
private ModelType modelType;
@@ -310,17 +319,12 @@ public class ModelHyperParamEntity {
@Column(name = "default_param")
private Boolean isDefault = false;
@Column(name = "total_use_cnt")
private Integer totalUseCnt = 0;
public HyperParamDto.Basic toDto() {
return new HyperParamDto.Basic(
this.modelType,
this.uuid,
this.hyperVer,
this.createdDttm,
this.lastUsedDttm,
this.totalUseCnt,
// -------------------------
// Important
// -------------------------

View File

@@ -140,7 +140,6 @@ public class ModelMasterEntity {
this.requestPath,
this.packingState,
this.packingStrtDttm,
this.packingEndDttm,
this.beforeModelId);
this.packingEndDttm);
}
}

View File

@@ -1,42 +0,0 @@
package com.kamco.cd.training.postgres.entity;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import java.time.OffsetDateTime;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_test_training_run")
public class ModelTestTrainingRunEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "tsr_id", nullable = false)
private Long id;
@NotNull
@Column(name = "model_id", nullable = false)
private Long modelId;
@Column(name = "attempt_no")
private Integer attemptNo;
@Column(name = "job_id")
private Long jobId;
@Column(name = "epoch")
private Integer epoch;
@ColumnDefault("now()")
@Column(name = "created_dttm")
private OffsetDateTime createdDttm;
}

View File

@@ -83,9 +83,6 @@ public class ModelTrainJobEntity {
@Column(name = "current_epoch")
private Integer currentEpoch;
@Column(name = "job_type")
private String jobType;
public ModelTrainJobDto toDto() {
return new ModelTrainJobDto(
this.id,

View File

@@ -4,7 +4,6 @@ import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity;
import java.util.List;
import java.util.Optional;
@@ -18,10 +17,6 @@ public interface DatasetRepositoryCustom {
List<SelectDataSet> getDatasetSelectG1List(DatasetReq req);
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId);
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(Long modelId, String modelNo);
List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req);
Long getDatasetMaxStage(int compareYyyy, int targetYyyy);

View File

@@ -1,20 +1,14 @@
package com.kamco.cd.training.postgres.repository.dataset;
import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.CaseBuilder;
@@ -148,103 +142,6 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch();
}
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum(),
// ===== before (join으로) =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum()))
.from(modelMasterEntity)
// ===== 현재 dataset join =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before 모델 join =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId))
.groupBy(
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
@@ -308,116 +205,6 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch();
}
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
// before join용
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
BooleanBuilder builder = new BooleanBuilder();
NumberExpression<Long> wasteCnt =
datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> elseCnt =
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> selectedCnt = ModelType.G2.getId().equals(modelNo) ? wasteCnt : elseCnt;
// before도 동일 로직으로 cnt 계산
NumberExpression<Long> beforeWasteCnt =
beforeObj.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> beforeElseCnt =
new CaseBuilder()
.when(beforeObj.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> beforeSelectedCnt =
ModelType.G2.getId().equals(modelNo) ? beforeWasteCnt : beforeElseCnt;
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo, // modelNo 파라미터 사용 (req.getModelNo() 제거)
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
selectedCnt, // classCount 자리에 들어가는 cnt (Long)
// ===== before =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
beforeSelectedCnt))
.from(modelMasterEntity)
// ===== 현재 dataset =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before dataset =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId).and(builder))
// sum() 때문에 groupBy 필요
.groupBy(
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override
public Long getDatasetMaxStage(int compareYyyy, int targetYyyy) {
return queryFactory

View File

@@ -29,28 +29,9 @@ public interface HyperParamRepositoryCustom {
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
/**
* 하이퍼 파라미터 상세조회
*
* @param uuid
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
/**
* 하이퍼 파라미터 목록 조회
*
* @param model
* @param req
* @return
*/
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
/**
* 하이퍼 파라미터 모델타입으로 조회
*
* @param modelType
* @return
*/
List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType);
List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
}

View File

@@ -9,6 +9,7 @@ import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.NumberExpression;
import com.querydsl.jpa.impl.JPAQuery;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.ZoneId;
@@ -81,7 +82,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.uuid.eq(uuid))
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid)))
.fetchOne());
}
@@ -90,12 +91,10 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
builder.and(modelHyperParamEntity.delYn.isFalse());
if (model != null) {
builder.and(modelHyperParamEntity.modelType.eq(model));
}
builder.and(modelHyperParamEntity.delYn.isFalse());
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
// 버전
@@ -119,6 +118,13 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
}
}
NumberExpression<Long> totalUseCnt =
modelHyperParamEntity
.m1UseCnt
.coalesce(0L)
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
JPAQuery<HyperParamDto.List> query =
queryFactory
.select(
@@ -130,7 +136,10 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
modelHyperParamEntity.createdDttm,
modelHyperParamEntity.lastUsedDttm,
modelHyperParamEntity.memo,
modelHyperParamEntity.totalUseCnt))
modelHyperParamEntity.m1UseCnt,
modelHyperParamEntity.m2UseCnt,
modelHyperParamEntity.m3UseCnt,
totalUseCnt.as("totalUseCnt")))
.from(modelHyperParamEntity)
.where(builder);
@@ -155,11 +164,8 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
asc
? modelHyperParamEntity.lastUsedDttm.asc()
: modelHyperParamEntity.lastUsedDttm.desc());
case "totalUseCnt" ->
query.orderBy(
asc
? modelHyperParamEntity.totalUseCnt.asc()
: modelHyperParamEntity.totalUseCnt.desc());
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
}
@@ -181,7 +187,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
}
@Override
public List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType) {
public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
return queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)

View File

@@ -5,6 +5,4 @@ import java.util.Optional;
public interface ModelConfigRepositoryCustom {
Optional<ModelConfigDto.Basic> findModelConfigByModelId(Long modelId);
Optional<ModelConfigDto.TransferBasic> findModelTransferConfigByModelId(Long modelId);
}

View File

@@ -1,12 +1,8 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.model.dto.ModelConfigDto.Basic;
import com.kamco.cd.training.model.dto.ModelConfigDto.TransferBasic;
import com.kamco.cd.training.postgres.entity.QModelConfigEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.Optional;
@@ -38,44 +34,4 @@ public class ModelConfigRepositoryImpl implements ModelConfigRepositoryCustom {
.where(modelConfigEntity.model.id.eq(modelId))
.fetchOne());
}
@Override
public Optional<TransferBasic> findModelTransferConfigByModelId(Long modelId) {
QModelConfigEntity beforeConfig = new QModelConfigEntity("beforeConfig");
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
return Optional.ofNullable(
queryFactory
.select(
Projections.constructor(
TransferBasic.class,
// ===== 현재 =====
modelConfigEntity.id,
modelConfigEntity.model.id,
modelConfigEntity.epochCount,
modelConfigEntity.trainPercent,
modelConfigEntity.validationPercent,
modelConfigEntity.testPercent,
modelConfigEntity.memo,
// ===== before =====
beforeConfig.id,
beforeConfig.model.id,
beforeConfig.epochCount,
beforeConfig.trainPercent,
beforeConfig.validationPercent,
beforeConfig.testPercent,
beforeConfig.memo))
.from(modelConfigEntity)
.innerJoin(modelConfigEntity.model, modelMasterEntity)
// before 모델 조인
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeConfig)
.on(beforeConfig.model.id.eq(beforeMaster.id))
.where(modelMasterEntity.id.eq(modelId))
.fetchOne());
}
}

View File

@@ -1,15 +1,8 @@
package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import java.util.List;
public interface ModelDatasetMappRepositoryCustom {
List<ModelDatasetMappEntity> findByModelUid(Long modelId);
List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId);
List<ModelTrainLinkDto> findDatasetValPath(Long modelId);
List<ModelTrainLinkDto> findDatasetTestPath(Long modelId);
}

View File

@@ -1,16 +1,8 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QDatasetEntity.datasetEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetTestObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetValObjEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import lombok.RequiredArgsConstructor;
@@ -30,136 +22,4 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
.where(modelDatasetMappEntity.modelUid.eq(modelId))
.fetch();
}
@Override
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
QDatasetObjEntity datasetObjEntity = QDatasetObjEntity.datasetObjEntity;
return queryFactory
.select(
Projections.constructor(
ModelTrainLinkDto.class,
modelMasterEntity.id,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelDatasetMappEntity.datasetUid,
datasetObjEntity.targetClassCd,
datasetObjEntity.comparePath,
datasetObjEntity.targetPath,
datasetObjEntity.labelPath,
datasetObjEntity.geojsonPath,
datasetEntity.uid))
.from(modelMasterEntity)
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(datasetEntity)
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
.leftJoin(datasetObjEntity)
.on(
datasetObjEntity
.datasetUid
.eq(modelDatasetMappEntity.datasetUid)
.and(
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId))
.fetch();
}
@Override
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
QDatasetValObjEntity datasetValObjEntity = QDatasetValObjEntity.datasetValObjEntity;
return queryFactory
.select(
Projections.constructor(
ModelTrainLinkDto.class,
modelMasterEntity.id,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelDatasetMappEntity.datasetUid,
datasetValObjEntity.targetClassCd,
datasetValObjEntity.comparePath,
datasetValObjEntity.targetPath,
datasetValObjEntity.labelPath,
datasetValObjEntity.geojsonPath,
datasetEntity.uid))
.from(modelMasterEntity)
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(datasetEntity)
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
.leftJoin(datasetValObjEntity)
.on(
datasetValObjEntity
.datasetUid
.eq(modelDatasetMappEntity.datasetUid)
.and(
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetValObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetValObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId))
.fetch();
}
/**
 * Loads train-link rows (dataset object paths) for a model from the TEST dataset-object table.
 *
 * <p>Same shape as the train/validation variants: the left join to the test objects is
 * filtered by model type — G1 keeps CONTAINER/BUILDING, G2 keeps WASTE, G3 keeps every class
 * (class codes are compared upper-cased).
 *
 * @param modelId model master id
 * @return one DTO per matching dataset object; object columns may be null because of the left
 *     joins
 */
@Override
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
  // Q-type is local because the test-object table is only queried here.
  QDatasetTestObjEntity datasetTestObjEntity = QDatasetTestObjEntity.datasetTestObjEntity;
  return queryFactory
      .select(
          Projections.constructor(
              ModelTrainLinkDto.class,
              modelMasterEntity.id,
              modelMasterEntity.trainType,
              modelMasterEntity.modelNo,
              modelDatasetMappEntity.datasetUid,
              datasetTestObjEntity.targetClassCd,
              datasetTestObjEntity.comparePath,
              datasetTestObjEntity.targetPath,
              datasetTestObjEntity.labelPath,
              datasetTestObjEntity.geojsonPath,
              datasetEntity.uid))
      .from(modelMasterEntity)
      .leftJoin(modelDatasetMappEntity)
      .on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
      .leftJoin(datasetEntity)
      .on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
      .leftJoin(datasetTestObjEntity)
      // ON clause: objects of the mapped dataset, restricted per model type
      // (G1 -> CONTAINER/BUILDING, G2 -> WASTE, G3 -> no class restriction).
      .on(
          datasetTestObjEntity
              .datasetUid
              .eq(modelDatasetMappEntity.datasetUid)
              .and(
                  modelMasterEntity
                      .modelNo
                      .eq(ModelType.G1.getId())
                      .and(datasetTestObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
                      .or(
                          modelMasterEntity
                              .modelNo
                              .eq(ModelType.G2.getId())
                              .and(datasetTestObjEntity.targetClassCd.upper().eq("WASTE")))
                      .or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
      .where(modelMasterEntity.id.eq(modelId))
      .fetch();
}
}

View File

@@ -9,7 +9,6 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import java.util.List;
import java.util.Optional;
@@ -38,8 +37,4 @@ public interface ModelDetailRepositoryCustom {
ModelBestEpoch getModelTrainBestEpoch(UUID uuid);
ModelFileInfo getModelTrainFileInfo(UUID uuid);
List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid);
ModelMasterEntity findByModelBeforeId(Long beforeModelId);
}

View File

@@ -19,15 +19,12 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.JPAExpressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@@ -60,13 +57,6 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
@Override
public DetailSummary getModelDetailSummary(UUID uuid) {
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
Expression<UUID> beforeModelUuid =
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
.from(beforeModel)
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
return queryFactory
.select(
Projections.constructor(
@@ -78,8 +68,7 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
modelMasterEntity.step1StrtDttm,
modelMasterEntity.step2EndDttm,
modelMasterEntity.statusCd,
modelMasterEntity.trainType,
beforeModelUuid))
modelMasterEntity.trainType))
.from(modelMasterEntity)
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
@@ -298,78 +287,4 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
}
/**
 * Builds the progress-step timeline for a model identified by UUID.
 *
 * <p>Step 0 (ready) is always present; steps 1-3 are appended only once they have started or
 * have a recorded state.
 *
 * @param uuid model master uuid
 * @return ordered step DTOs, or an empty list when no model matches the uuid
 */
@Override
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
  ModelMasterEntity entity = findByModelByUUID(uuid);
  if (entity == null) {
    return List.of();
  }
  List<ModelProgressStepDto> steps = new ArrayList<>();
  // Step 0: ready state — anchored at the row's creation time, never an error.
  steps.add(
      ModelProgressStepDto.builder()
          .step(0)
          .status(TrainStatusType.READY.getId())
          .startTime(entity.getCreatedDttm())
          .endTime(null)
          .isError(false)
          .build());
  // Step 1: train/validation run — included once it has a start time and has left READY.
  boolean step1Active =
      entity.getStep1StrtDttm() != null
          && !TrainStatusType.READY.getId().equals(entity.getStep1State());
  if (step1Active) {
    steps.add(
        ModelProgressStepDto.builder()
            .step(1)
            .status(entity.getStep1State())
            .startTime(entity.getStep1StrtDttm())
            .endTime(entity.getStep1EndDttm())
            .isError(TrainStatusType.ERROR.getId().equals(entity.getStep1State()))
            .build());
  }
  // Step 2: test run — included once any step-2 state has been recorded.
  boolean step2Done = entity.getStep2State() != null;
  if (step2Done) {
    steps.add(
        ModelProgressStepDto.builder()
            .step(2)
            .status(entity.getStep2State())
            .startTime(entity.getStep2StrtDttm())
            .endTime(entity.getStep2EndDttm())
            .isError(TrainStatusType.ERROR.getId().equals(entity.getStep2State()))
            .build());
  }
  // Step 3: packaging — included once any packaging state has been recorded.
  boolean step3Done = entity.getPackingState() != null;
  if (step3Done) {
    steps.add(
        ModelProgressStepDto.builder()
            .step(3)
            .status(entity.getPackingState())
            .startTime(entity.getPackingStrtDttm())
            .endTime(entity.getPackingEndDttm())
            .isError(TrainStatusType.ERROR.getId().equals(entity.getPackingState()))
            .build());
  }
  return steps;
}
/**
 * Finds the model-master row whose id equals {@code beforeModelId}.
 *
 * @param beforeModelId id of the predecessor model
 * @return the matching entity, or null when none exists
 */
@Override
public ModelMasterEntity findByModelBeforeId(Long beforeModelId) {
  return queryFactory
      .selectFrom(modelMasterEntity)
      .where(modelMasterEntity.id.eq(beforeModelId))
      .fetchOne();
}
}

View File

@@ -9,10 +9,8 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -39,13 +37,6 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
*/
@Override
public Page<ListDto> findByModels(ModelTrainMngDto.SearchReq req) {
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
Expression<UUID> beforeModelUuid =
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
.from(beforeModel)
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
@@ -87,8 +78,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelMasterEntity.packingStrtDttm,
modelMasterEntity.packingEndDttm,
modelConfigEntity.memo,
memberEntity.name,
beforeModelUuid))
memberEntity.name))
.from(modelMasterEntity)
.innerJoin(modelConfigEntity)
.on(modelMasterEntity.id.eq(modelConfigEntity.model.id))
@@ -201,11 +191,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
return queryFactory
.select(modelMasterEntity.id.count())
.from(modelMasterEntity)
.where(
modelMasterEntity
.step1State
.eq(TrainStatusType.IN_PROGRESS.getId())
.or(modelMasterEntity.step2State.eq(TrainStatusType.IN_PROGRESS.getId())))
.where(modelMasterEntity.step1State.eq(TrainStatusType.IN_PROGRESS.getId()))
.fetchOne();
}
}

View File

@@ -65,48 +65,16 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
@Override
public void insertModelMetricsTest(List<Object[]> batchArgs) {
// AS-IS
// String sql =
// """
// insert into tb_model_metrics_test
// (model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
// detection_count, gt_count
// )
// values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
// """;
//
// jdbcTemplate.batchUpdate(sql, batchArgs);
// TO-BE: modelId, model(best_fscore_10) 같은 데이터가 있으면 update, 없으면 insert
String updateSql =
String sql =
"""
UPDATE tb_model_metrics_test
SET tp=?, fp=?, fn=?, precisions=?, recall=?, f1_score=?, accuracy=?, iou=?,
detection_count=?, gt_count=?
WHERE model_id=? AND model=?
""";
insert into tb_model_metrics_test
(model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
detection_count, gt_count
)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""";
String insertSql =
"""
INSERT INTO tb_model_metrics_test
(model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
detection_count, gt_count)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""";
// row 단위 처리 (batch 안에서 upsert)
for (Object[] row : batchArgs) {
// row 순서: (model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
// detection_count, gt_count)
int updated =
jdbcTemplate.update(
updateSql, row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9], row[10],
row[11], row[0], row[1]);
if (updated == 0) {
jdbcTemplate.update(insertSql, row);
}
}
jdbcTemplate.batchUpdate(sql, batchArgs);
}
@Override
@@ -132,8 +100,7 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
modelMetricsTestEntity.model.eq(modelMetricsTrainEntity.model),
modelMasterEntity.bestEpoch.eq(modelMetricsTrainEntity.epoch))
.where(modelMetricsTestEntity.model.id.eq(modelId))
.orderBy(modelMetricsTestEntity.createdDttm.desc())
.fetchFirst();
.fetchOne();
}
@Override

View File

@@ -1,7 +1,6 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import java.util.List;
import java.util.Optional;
public interface ModelTrainJobRepositoryCustom {
@@ -10,8 +9,4 @@ public interface ModelTrainJobRepositoryCustom {
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch);
List<ModelTrainJobEntity> findRunningJobs();
}

View File

@@ -1,15 +1,9 @@
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity;
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
import com.kamco.cd.training.common.enums.JobStatusType;
import com.kamco.cd.training.common.enums.JobType;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory;
import jakarta.persistence.EntityManager;
import java.util.List;
import java.util.Optional;
import org.springframework.stereotype.Repository;
@@ -25,7 +19,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
/** modelId의 attempt_no 최대값. (없으면 0) */
@Override
public int findMaxAttemptNo(Long modelId) {
QModelTrainJobEntity j = modelTrainJobEntity;
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
Integer max =
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
@@ -39,7 +33,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
*/
@Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
QModelTrainJobEntity j = modelTrainJobEntity;
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
@@ -49,7 +43,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
@Override
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
QModelTrainJobEntity j = modelTrainJobEntity;
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory
@@ -60,40 +54,4 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
return Optional.ofNullable(job);
}
/**
 * Records a test-training run for a model with the next attempt number.
 *
 * <p>The attempt number is the current max for the model plus one; {@code coalesce(0)} makes
 * the very first attempt number 1.
 *
 * @param modelId model master id
 * @param jobId id of the train job this run belongs to
 * @param epoch epoch the run was executed for
 */
@Override
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
  Integer maxAttemptNo =
      queryFactory
          .select(modelTestTrainingRunEntity.attemptNo.max().coalesce(0))
          .from(modelTestTrainingRunEntity)
          .where(modelTestTrainingRunEntity.modelId.eq(modelId))
          .fetchOne();
  // Defensive null check; with coalesce(0) the query normally returns 0 when no rows exist.
  int nextAttemptNo = (maxAttemptNo == null ? 1 : maxAttemptNo + 1);
  queryFactory
      .insert(modelTestTrainingRunEntity)
      .columns(
          modelTestTrainingRunEntity.modelId,
          modelTestTrainingRunEntity.attemptNo,
          modelTestTrainingRunEntity.jobId,
          modelTestTrainingRunEntity.epoch)
      .values(modelId, nextAttemptNo, jobId, epoch)
      .execute();
}
/**
 * Returns TRAIN-type jobs whose status is RUNNING, newest (highest id) first.
 *
 * @return running train jobs; empty list when none
 */
@Override
public List<ModelTrainJobEntity> findRunningJobs() {
  return queryFactory
      .select(modelTrainJobEntity)
      .from(modelTrainJobEntity)
      .where(
          modelTrainJobEntity
              .statusCd
              .eq(JobStatusType.RUNNING.getId())
              .and(modelTrainJobEntity.jobType.eq(JobType.TRAIN.getId())))
      .orderBy(modelTrainJobEntity.id.desc())
      .fetch();
}
}

View File

@@ -185,7 +185,7 @@ public class TrainApiController {
})
@PostMapping("/create-tmp/{uuid}")
public ApiResponseDto<UUID> createTmpFile(
@Parameter(description = "model uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {

View File

@@ -1,24 +0,0 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
/**
 * Flat projection of model/dataset link information used by the training repositories.
 *
 * <p>NOTE: field order defines the Lombok {@code @AllArgsConstructor} signature, which is
 * matched positionally by QueryDSL {@code Projections.constructor(...)} — do not reorder
 * fields.
 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class ModelTrainLinkDto {
  private Long modelId; // model master id
  private String trainType; // training type code
  private String modelNo; // model type code (G1/G2/G3)
  private Long datasetId; // mapped dataset id
  private String targetClassCd; // target class code of the dataset object
  private String comparePath; // path of the "before" (comparison) image
  private String targetPath; // path of the "after" (target) image
  private String labelPath; // path of the label file
  private String geoJsonPath; // path of the geojson file
  private String datasetUid; // dataset uid string
}

View File

@@ -6,17 +6,13 @@ import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
/** 학습실행 파일 하드링크 */
@Service
@Log4j2
@RequiredArgsConstructor
@@ -46,10 +42,6 @@ public class DataSetCountersService {
// tmp
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
// 차이나는거
diffMergedRequestsVsTmp(uids, tmpPath);
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
allLogs
.append(counters2.prints(basic.getRequestPath(), "TMP"))
@@ -171,58 +163,4 @@ public class DataSetCountersService {
test + test2);
}
}
/**
 * Collects the root-relative paths (normalized to '/' separators) of every {@code .tif} file
 * under {@code root}; returns an empty set when {@code root} is not a directory.
 */
private Set<String> listTifRelative(Path root) throws IOException {
  if (!Files.isDirectory(root)) return Set.of();
  // Files.walk must be closed — the try-with-resources releases the directory handles.
  try (var stream = Files.walk(root)) {
    return stream
        .filter(Files::isRegularFile)
        .filter(p -> p.getFileName().toString().toLowerCase().endsWith(".tif"))
        .map(p -> root.relativize(p).toString().replace("\\", "/"))
        .collect(Collectors.toSet());
  }
}
/**
 * Collects the bare file names of every {@code .tif} file under {@code root}; returns an empty
 * set when {@code root} is not a directory. Duplicate names in different folders collapse to
 * one entry.
 */
private Set<String> listTifFileNameOnly(Path root) throws IOException {
  if (!Files.isDirectory(root)) return Set.of();
  try (var stream = Files.walk(root)) {
    return stream
        .filter(Files::isRegularFile)
        .filter(p -> p.getFileName().toString().toLowerCase().endsWith(".tif"))
        .map(p -> p.getFileName().toString()) // file name only
        .collect(Collectors.toSet());
  }
}
/**
 * Logs a filename-based diff between the merged request datasets
 * ({@code requestDir/<uid>} for every uid) and the built tmp tree.
 *
 * <p>Reports totals plus up to 50 missing (in requests but not tmp) and 50 extra (in tmp but
 * not requests) file names. Diagnostic only — no files are modified.
 */
public void diffMergedRequestsVsTmp(List<String> uids, Path tmpRoot) throws IOException {
  // 1) Union of tif *file names* across all requested uids.
  Set<String> reqAll = new HashSet<>();
  for (String uid : uids) {
    Path reqRoot = Path.of(requestDir, uid);
    // The merged tmp tree usually has a different folder layout, so comparing bare file
    // names is more useful here than comparing relative paths.
    reqAll.addAll(listTifFileNameOnly(reqRoot));
  }
  // 2) Set of tif file names in the tmp tree.
  Set<String> tmpAll = listTifFileNameOnly(tmpRoot);
  Set<String> missing = new HashSet<>(reqAll);
  missing.removeAll(tmpAll);
  Set<String> extra = new HashSet<>(tmpAll);
  extra.removeAll(reqAll);
  log.info("==== MERGED DIFF (filename-based) ====");
  log.info("request(all uids) tif = {}", reqAll.size());
  log.info("tmp tif = {}", tmpAll.size());
  log.info("missing = {}", missing.size());
  log.info("extra = {}", extra.size());
  missing.stream().sorted().limit(50).forEach(f -> log.warn("[MISSING] {}", f));
  extra.stream().sorted().limit(50).forEach(f -> log.warn("[EXTRA] {}", f));
}
}

View File

@@ -7,9 +7,6 @@ import com.kamco.cd.training.train.dto.TrainRunResult;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
@@ -52,14 +49,7 @@ public class DockerTrainService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
/**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param req
* @param containerName
* @return
* @throws Exception
*/
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerRunCommand(containerName, req);
@@ -274,7 +264,8 @@ public class DockerTrainService {
addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize());
addArg(c, "--batch-size", req.getBatchSize());
addArg(c, "--gpu-ids", req.getGpuIds()); // null
addArg(c, "--gpu-ids", req.getGpuIds());
// addArg(c, "--gpus", req.getGpus());
addArg(c, "--lr", req.getLearningRate());
addArg(c, "--backbone", req.getBackbone());
addArg(c, "--epochs", req.getEpochs());
@@ -348,15 +339,7 @@ public class DockerTrainService {
}
}
/**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param containerName
* @param req
* @return
* @throws Exception
*/
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerEvalCommand(containerName, req);
@@ -429,9 +412,7 @@ public class DockerTrainService {
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
Path epochPath = Paths.get(responseDir, req.getOutputFolder());
// 결과 폴더에 파라미터로 받은 베스트 epoch이 best_changed_fscore_epoch_ 로 시작하는 파일이 있는지 확인 후 pth 파일명 반환
String modelFile = findCheckpoint(epochPath, epoch);
String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth";
List<String> c = new ArrayList<>();
@@ -462,25 +443,4 @@ public class DockerTrainService {
return c;
}
/**
 * Resolves the checkpoint file name for an epoch inside {@code dir}.
 *
 * <p>Preference order: {@code best_changed_fscore_epoch_<epoch>.pth} first, then the plain
 * {@code epoch_<epoch>.pth}.
 *
 * @param dir directory containing the checkpoint files
 * @param epoch epoch whose checkpoint is wanted
 * @return the file name (not the full path) of the checkpoint to use
 * @throws IllegalStateException when neither candidate file exists
 */
public String findCheckpoint(Path dir, int epoch) {
  String bestFileName = String.format("best_changed_fscore_epoch_%d.pth", epoch);
  String normalFileName = String.format("epoch_%d.pth", epoch);
  Path bestPath = dir.resolve(bestFileName);
  Path normalPath = dir.resolve(normalFileName);
  // 1. Prefer the "best" checkpoint when it exists.
  if (Files.isRegularFile(bestPath)) {
    return bestFileName;
  }
  // 2. Otherwise fall back to the plain epoch checkpoint.
  if (Files.isRegularFile(normalPath)) {
    return normalFileName;
  }
  throw new IllegalStateException("Checkpoint 파일이 없습니다. epoch=" + epoch);
}
}

View File

@@ -1,449 +0,0 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
/**
* 서버 재기동 시 "RUNNING 상태로 남아있는 학습 Job"을 복구(정리)하기 위한 서비스.
*
* <p>상황 예시: - 서버가 강제 재기동/장애로 내려감 - DB 상에서는 job_state가 RUNNING(진행중)으로 남아있음 - 실제 docker 컨테이너는: 1) 아직
* 살아있거나(running=true) 2) 종료되었거나(exited) 3) --rm 옵션으로 인해 컨테이너가 이미 삭제되어 존재하지 않을 수 있음
*
* <p>이 클래스는 ApplicationReadyEvent(스프링 부팅 완료) 시점에 실행되어, DB의 RUNNING 잡들을 조회한 뒤 컨테이너 상태를 점검하고,
* SUCCESS/FAILED 처리를 수행합니다.
*/
@Profile("!local")
@Component
@RequiredArgsConstructor
@Log4j2
public class JobRecoveryOnStartupService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
/**
* Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response
*
* <p>컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다.
*/
@Value("${train.docker.responseDir}")
private String responseDir;
/**
 * Runs the recovery pass intended for just after Spring boot completes (all beans
 * created/initialized).
 *
 * <p>{@code @Transactional}: the whole recover() call runs in one transaction, so a failure
 * while processing one job can roll back the entire batch. If per-job commits are needed,
 * consider a REQUIRES_NEW transaction per job (recommended).
 */
// @EventListener(ApplicationReadyEvent.class)
@Transactional
public void recover() {
  // 1) Load all jobs still marked RUNNING in the DB.
  List<ModelTrainJobDto> runningJobs = modelTrainJobCoreService.findRunningJobs();
  // Nothing to do when no job is marked as running.
  if (runningJobs == null || runningJobs.isEmpty()) {
    return;
  }
  // 2) For each job, check the docker container state and act accordingly.
  for (ModelTrainJobDto job : runningJobs) {
    String containerName = job.getContainerName();
    try {
      // 2-1) Query the container state via docker inspect.
      DockerInspectState state = inspectContainer(containerName);
      // 3) Container does not exist.
      //    - A container started with docker run --rm may be deleted right after a clean
      //      exit, so "container missing" does not automatically mean failure.
      if (!state.exists()) {
        log.warn(
            "[RECOVERY] container missing. try file-based reconcile. container={}",
            containerName);
        // 3-1) With no container left, infer completion from the artifacts in responseDir.
        OutputResult out = probeOutputs(job);
        // 3-2) Artifacts look complete -> mark SUCCESS.
        if (out.completed()) {
          log.info("[RECOVERY] outputs look completed. mark SUCCESS. jobId={}", job.getId());
          modelTrainJobCoreService.markSuccess(job.getId(), 0);
          markStepSuccessByJobType(job);
        } else {
          // 3-3) Artifacts incomplete -> mark FAILED (a grace period is also an option,
          //      depending on operational policy).
          log.warn(
              "[RECOVERY] outputs incomplete. mark FAILED. jobId={} reason={}",
              job.getId(),
              out.reason());
          modelTrainJobCoreService.markFailed(
              job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
          markStepErrorByJobType(job, out.reason());
        }
        continue;
      }
      // 4) Container exists and is still running (only the server restarted).
      //    The code stops the container (20s docker grace, 30s local wait) and marks the
      //    job PAUSED on success, FAILED when the stop itself fails.
      if (state.running()) {
        log.info("[RECOVERY] container still running. container={}", containerName);
        try {
          ProcessBuilder pb = new ProcessBuilder("docker", "stop", "-t", "20", containerName);
          pb.redirectErrorStream(true);
          Process p = pb.start();
          boolean finished = p.waitFor(30, TimeUnit.SECONDS);
          if (!finished) {
            p.destroyForcibly();
            throw new IOException("docker stop timeout");
          }
          int code = p.exitValue();
          if (code != 0) {
            throw new IOException("docker stop failed. exit=" + code);
          }
          log.info(
              "[RECOVERY] container stopped (will be auto removed by --rm). container={}",
              containerName);
          // Mark the job PAUSED.
          // NOTE(review): the reason string "AUTO_STOP_FAILED_ON_RESTART" is reused here on
          // the *successful* stop path — confirm whether a distinct reason was intended.
          modelTrainJobCoreService.markPaused(job.getId(), -1, "AUTO_STOP_FAILED_ON_RESTART");
        } catch (Exception e) {
          log.error("[RECOVERY] docker stop failed. container={}", containerName, e);
          modelTrainJobCoreService.markFailed(job.getId(), -1, "AUTO_STOP_FAILED_ON_RESTART");
        }
        continue;
      }
      // 5) Container exists but is not running (exited / dead etc.).
      Integer exitCode = state.exitCode();
      String status = state.status();
      // 5-1) exitCode == 0 -> treat as a clean exit -> SUCCESS.
      if (exitCode != null && exitCode == 0) {
        log.info("[RECOVERY] container exited(0). mark SUCCESS. container={}", containerName);
        modelTrainJobCoreService.markSuccess(job.getId(), 0);
        markStepSuccessByJobType(job);
      } else {
        // 5-2) Non-zero or unknown exit code -> FAILED.
        log.warn(
            "[RECOVERY] container exited non-zero. mark FAILED. container={} status={} exitCode={}",
            containerName,
            status,
            exitCode);
        modelTrainJobCoreService.markFailed(
            job.getId(), exitCode, "SERVER_RESTART_CONTAINER_EXIT_NONZERO");
        markStepErrorByJobType(job, "exit=" + exitCode + " status=" + status);
      }
    } catch (Exception e) {
      // 6) docker inspect itself failed.
      //    - Daemon/permission/transient problems are possible; depending on policy a grace
      //      period instead of immediate failure could be considered.
      log.error("[RECOVERY] container inspect failed. container={}", containerName, e);
      modelTrainJobCoreService.markFailed(
          job.getId(), -1, "SERVER_RESTART_CONTAINER_INSPECT_ERROR");
      markStepErrorByJobType(job, "inspect-error");
    }
  }
}
/**
 * Updates the "success step" in the training-management table according to the job type.
 *
 * <p>jobType == "EVAL" (read from paramsJson) -> step 2 (evaluation) success; anything else
 * -> step 1 (training) success.
 */
private void markStepSuccessByJobType(ModelTrainJobDto job) {
  Map<String, Object> params = job.getParamsJson();
  boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
  if (isEval) {
    modelTrainMngCoreService.markStep2Success(job.getModelId());
  } else {
    modelTrainMngCoreService.markStep1Success(job.getModelId());
  }
}
/**
 * Updates the "error step" in the training-management table according to the job type.
 *
 * <p>jobType == "EVAL" (read from paramsJson) -> step 2 (evaluation) error; anything else ->
 * the overall error state via markError.
 */
private void markStepErrorByJobType(ModelTrainJobDto job, String msg) {
  Map<String, Object> params = job.getParamsJson();
  boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
  if (isEval) {
    modelTrainMngCoreService.markStep2Error(job.getModelId(), msg);
  } else {
    modelTrainMngCoreService.markError(job.getModelId(), msg);
  }
}
/**
 * Queries a container's state via {@code docker inspect}.
 *
 * <p>Template used: {@code {{.State.Status}} {{.State.Running}} {{.State.ExitCode}}},
 * producing a single line such as {@code "running true 0"} or {@code "exited false 137"}.
 *
 * <p>A 5-second wait guard prevents hanging on a wedged docker daemon. A non-zero inspect
 * exit code or empty output is treated as "container missing".
 *
 * @param containerName name of the container to inspect
 * @return the parsed state; {@link DockerInspectState#missing()} when the container does not
 *     exist (or inspect failed, which is folded into "missing")
 * @throws IOException if docker inspect cannot be started or times out
 * @throws InterruptedException if the current thread is interrupted while waiting
 */
private DockerInspectState inspectContainer(String containerName)
    throws IOException, InterruptedException {
  ProcessBuilder pb =
      new ProcessBuilder(
          "docker",
          "inspect",
          "-f",
          "{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}",
          containerName);
  // Merge stderr into stdout so error text arrives on the same stream.
  pb.redirectErrorStream(true);
  Process p = pb.start();
  // The inspect template emits exactly one line, so readLine() is sufficient.
  // FIX: specify the charset explicitly — the no-arg InputStreamReader uses the platform
  // default encoding (pre-JDK18), which is not portable across hosts.
  String output;
  try (BufferedReader br =
      new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
    output = br.readLine();
  }
  // Guard against waiting forever: force-kill after 5 seconds.
  boolean finished = p.waitFor(5, TimeUnit.SECONDS);
  if (!finished) {
    p.destroyForcibly();
    throw new IOException("docker inspect timeout");
  }
  // Exit code of the docker inspect process itself.
  int code = p.exitValue();
  // Failure (code != 0) or no output => treat the container as missing.
  if (code != 0 || output == null || output.isBlank()) {
    return DockerInspectState.missing();
  }
  // Split "status running exitCode".
  String[] parts = output.trim().split("\\s+");
  // status: running / exited / dead etc.
  String status = parts.length > 0 ? parts[0] : "unknown";
  // running: true / false
  boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]);
  // exitCode: parsed as int; left null when unparsable.
  Integer exitCode = null;
  if (parts.length > 2) {
    try {
      exitCode = Integer.parseInt(parts[2]);
    } catch (Exception ignored) {
      // keep exitCode null when docker emits something unparsable
    }
  }
  return new DockerInspectState(true, running, exitCode, status);
}
/**
 * Result record for a docker inspect call.
 *
 * <p>exists: true when docker inspect succeeded (container present); false when the container
 * is missing (inspect failures are also folded into "missing").
 */
private record DockerInspectState(
    boolean exists, boolean running, Integer exitCode, String status) {
  // Sentinel for "container not found / inspect failed".
  static DockerInspectState missing() {
    return new DockerInspectState(false, false, null, "missing");
  }
}
// ============================================================================================
// 컨테이너가 "없을 때" 파일 기반으로 완료/미완료를 판정하는 로직
// ============================================================================================
/**
 * When the container is gone (only responseDir artifacts remain), decides completion from the
 * files on disk.
 *
 * <p>Actual rule implemented below: 1) total_epoch must be present in paramsJson, 2) val.csv
 * must exist, and 3) its non-header line count must EQUAL total_epoch ("equal means done"
 * policy).
 *
 * <p>NOTE(review): the original doc also described *.pth checkpoint-count rules, but this
 * implementation checks only val.csv — confirm whether the pth checks were intentionally
 * dropped.
 */
private OutputResult probeOutputs(ModelTrainJobDto job) {
  try {
    Path outDir = resolveOutputDir(job);
    if (outDir == null || !Files.isDirectory(outDir)) {
      return new OutputResult(false, "output-dir-missing");
    }
    Integer totalEpoch = extractTotalEpoch(job).orElse(null);
    if (totalEpoch == null || totalEpoch <= 0) {
      return new OutputResult(false, "total-epoch-missing");
    }
    Path valCsv = outDir.resolve("val.csv");
    if (!Files.exists(valCsv)) {
      return new OutputResult(false, "val.csv-missing");
    }
    long lines = countNonHeaderLines(valCsv);
    // "Equal means done" policy: every epoch must have a val.csv row.
    if (lines == totalEpoch) {
      return new OutputResult(true, "ok");
    }
    return new OutputResult(
        false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch);
  } catch (Exception e) {
    log.error("[RECOVERY] probeOutputs error. jobId={}", job.getId(), e);
    return new OutputResult(false, "probe-error");
  }
}
/**
 * Locates the job's artifact directory under responseDir.
 *
 * <p>Current rule: {@code {responseDir}/{model uuid}/metrics}. This is the main customization
 * point if the artifact layout changes.
 *
 * @return the metrics directory, or null when it does not exist
 */
private Path resolveOutputDir(ModelTrainJobDto job) {
  ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId());
  Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics");
  return Files.isDirectory(base) ? base : null;
}
/**
 * Extracts the total epoch count from paramsJson.
 *
 * <p>Key candidates: "total_epoch" (snake_case), then "totalEpoch" (camelCase). Example:
 * paramsJson = {"jobType":"TRAIN","total_epoch":50,...}
 *
 * @return the parsed value, or empty when absent or unparsable
 */
private Optional<Integer> extractTotalEpoch(ModelTrainJobDto job) {
  Map<String, Object> params = job.getParamsJson();
  if (params == null) return Optional.empty();
  Object v = params.get("total_epoch");
  if (v == null) v = params.get("totalEpoch");
  if (v == null) return Optional.empty();
  try {
    // String.valueOf also handles numeric JSON values deserialized as Integer/Long.
    return Optional.of(Integer.parseInt(String.valueOf(v)));
  } catch (Exception ignore) {
    return Optional.empty();
  }
}
/**
 * Counts the lines of a CSV file excluding the header (first line) and blank lines.
 *
 * <p>Assumptions: the first line of val.csv is the header, subsequent lines are per-epoch
 * records; the file is UTF-8.
 */
private long countNonHeaderLines(Path csv) throws IOException {
  // Files.lines must be closed — try-with-resources releases the file handle.
  try (Stream<String> lines = Files.lines(csv, StandardCharsets.UTF_8)) {
    return lines.skip(1).filter(s -> s != null && !s.isBlank()).count();
  }
}
/**
 * Counts regular files in {@code dir} matching the glob pattern (e.g. "*.pth", "best*.pth").
 * Non-recursive: only direct children are considered.
 */
private long countFilesByGlob(Path dir, String glob) throws IOException {
  try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
    long cnt = 0;
    for (Path p : ds) {
      if (Files.isRegularFile(p)) cnt++;
    }
    return cnt;
  }
}
/**
 * Returns whether at least one entry in {@code dir} matches the glob pattern. Non-recursive;
 * note that unlike countFilesByGlob this does not filter to regular files.
 */
private boolean existsByGlob(Path dir, String glob) throws IOException {
  try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
    return ds.iterator().hasNext();
  }
}
// ============================================================================================
// probeOutputs() 결과 객체
// ============================================================================================
/**
 * Verdict for the file-based completion check (container-missing case).
 *
 * <p>completed: true when artifacts look complete (safe to mark SUCCESS); false when
 * insufficient or unclear (FAILED or grace-period decision).
 *
 * <p>reason: failure/incompleteness cause, used for logs and DB messages.
 */
private static final class OutputResult {
  private final boolean completed; // verdict
  private final String reason; // human-readable cause, e.g. "val.csv-missing"
  private OutputResult(boolean completed, String reason) {
    this.completed = completed;
    this.reason = reason;
  }
  boolean completed() {
    return completed;
  }
  String reason() {
    return reason;
  }
}
}

View File

@@ -48,8 +48,20 @@ public class ModelTestMetricsJobService {
@Value("${file.pt-path}")
private String ptPathDir;
/** 결과 csv 파일 정보 등록 */
public void findTestValidMetricCsvFiles() {
/**
* 실행중인 profile
*
* @return
*/
private boolean isLocalProfile() {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
public void findTestValidMetricCsvFiles() throws IOException {
// if (isLocalProfile()) {
// return;
// }
List<ResponsePathDto> modelIds =
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();

View File

@@ -36,8 +36,20 @@ public class ModelTrainMetricsJobService {
@Value("${train.docker.responseDir}")
private String responseDir;
/** 결과 csv 파일 정보 등록 */
/**
* 실행중인 profile
*
* @return
*/
private boolean isLocalProfile() {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
public void findTrainValidMetricCsvFiles() {
// if (isLocalProfile()) {
// return;
// }
List<ResponsePathDto> modelIds =
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();

View File

@@ -23,14 +23,6 @@ public class TestJobService {
private final ApplicationEventPublisher eventPublisher;
private final DataSetCountersService dataSetCounters;
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @param uuid
* @param epoch
* @return
*/
@Transactional
public Long enqueue(Long modelId, UUID uuid, int epoch) {
@@ -59,18 +51,13 @@ public class TestJobService {
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, params, ZonedDateTime.now());
// test training run 테이블에 적재하기
modelTrainJobCoreService.insertModelTestTrainingRun(modelId, jobId, epoch);
// step2 시작으로 마킹
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
/**
* 취소
*
* @param modelId
*/
@Transactional
public void cancel(Long modelId) {

View File

@@ -1,6 +1,5 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import java.io.IOException;
import java.nio.file.*;
import java.util.List;
@@ -20,95 +19,6 @@ public class TmpDatasetService {
@Value("${train.docker.basePath}")
private String trainBaseDir;
/**
* train, val, test 폴더별로 link
*
* @param uid 임시폴더 uuid
* @param type train, val, test
* @param links tif pull path
* @return
* @throws IOException
*/
public void buildTmpDatasetHardlink(String uid, String type, List<ModelTrainLinkDto> links)
throws IOException {
if (links == null || links.isEmpty()) {
throw new IOException("links is empty");
}
Path tmp = Path.of(trainBaseDir, "tmp", uid);
long hardlinksMade = 0;
for (ModelTrainLinkDto dto : links) {
if (type == null) {
log.warn("SKIP - trainType null: {}", dto);
continue;
}
// type별 디렉토리 생성
Files.createDirectories(tmp.resolve(type).resolve("input1"));
Files.createDirectories(tmp.resolve(type).resolve("input2"));
Files.createDirectories(tmp.resolve(type).resolve("label"));
Files.createDirectories(tmp.resolve(type).resolve("label-json"));
// comparePath → input1
hardlinksMade += link(tmp, type, "input1", dto.getComparePath());
// targetPath → input2
hardlinksMade += link(tmp, type, "input2", dto.getTargetPath());
// labelPath → label
hardlinksMade += link(tmp, type, "label", dto.getLabelPath());
// geoJsonPath -> label-json
hardlinksMade += link(tmp, type, "label-json", dto.getGeoJsonPath());
}
if (hardlinksMade == 0) {
throw new IOException("No hardlinks created.");
}
log.info("tmp dataset created: {}, hardlinksMade={}", tmp, hardlinksMade);
}
private long link(Path tmp, String type, String part, String fullPath) throws IOException {
if (fullPath == null || fullPath.isBlank()) return 0;
Path src = Path.of(fullPath);
if (!Files.isRegularFile(src)) {
log.warn("SKIP (not file): {}", src);
return 0;
}
String fileName = src.getFileName().toString();
Path dst = tmp.resolve(type).resolve(part).resolve(fileName);
// 충돌 시 덮어쓰기
if (Files.exists(dst)) {
Files.delete(dst);
}
Files.createLink(dst, src);
return 1;
}
private String safe(String s) {
return (s == null || s.isBlank()) ? null : s.trim();
}
/**
* request 전체 폴더 link
*
* @param uid
* @param datasetUids
* @return
* @throws IOException
*/
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
log.info("========== buildTmpDatasetHardlink START ==========");

View File

@@ -7,7 +7,6 @@ import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.io.IOException;
import java.nio.file.Files;
@@ -47,12 +46,7 @@ public class TrainJobService {
return modelTrainMngCoreService.findModelIdByUuid(uuid);
}
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @return
*/
/** 실행 예약 (QUEUE 등록) */
@Transactional
public Long enqueue(Long modelId) {
@@ -144,13 +138,6 @@ public class TrainJobService {
modelTrainMngCoreService.markStopped(modelId);
}
/**
* 학습 이어하기
*
* @param modelId 모델 id
* @param mode NONE 새로 시작, REQUIRE 이어하기
* @return
*/
private Long createNextAttempt(Long modelId, ResumeMode mode) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
@@ -201,12 +188,6 @@ public class TrainJobService {
REQUIRE // 이어하기
}
/**
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
*
* @param paramsJson
* @return
*/
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
if (paramsJson == null) return null;
@@ -248,43 +229,21 @@ public class TrainJobService {
}
}
/**
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
*
* @param modelUuid
* @return
*/
@Transactional
public UUID createTmpFile(UUID modelUuid) {
UUID tmpUuid = UUID.randomUUID();
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
// model id 가져오기
// MODELID 가져오기
Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
// model 에 연결된 dataset id 가져오기
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
try {
// 데이터셋 심볼링크 생성
// String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
// train path
List<ModelTrainLinkDto> trainList = modelTrainMngCoreService.findDatasetTrainPath(modelId);
// validation path
List<ModelTrainLinkDto> valList = modelTrainMngCoreService.findDatasetValPath(modelId);
// test path
List<ModelTrainLinkDto> testList = modelTrainMngCoreService.findDatasetTestPath(modelId);
// train 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "train", trainList);
// val 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "val", valList);
// test 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "test", testList);
String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(raw);
updateReq.setRequestPath(pathUid);
// 학습모델을 수정한다.
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
@@ -298,8 +257,14 @@ public class TrainJobService {
e);
// 런타임 예외로 래핑하되, 메시지에 핵심 정보 포함
throw new CustomApiException(
"INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "임시 데이터셋 생성에 실패했습니다.");
throw new IllegalStateException(
"tmp dataset build failed: modelUuid="
+ modelUuid
+ ", modelId="
+ modelId
+ ", tmpRaw="
+ raw,
e);
}
return modelUuid;
}

View File

@@ -17,7 +17,6 @@ import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
/** job 실행 */
@Log4j2
@Component
@RequiredArgsConstructor
@@ -30,12 +29,10 @@ public class TrainJobWorker {
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
@Async("trainJobExecutor")
@Async
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
public void handle(ModelTrainJobQueuedEvent event) {
log.info("[JOB] thread={}, jobId={}", Thread.currentThread().getName(), event.getJobId());
Long jobId = event.getJobId();
ModelTrainJobDto job =
@@ -57,8 +54,6 @@ public class TrainJobWorker {
String containerName =
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
String type = isEval ? "TEST" : "TRAIN";
Integer totalEpoch = null;
if (params.containsKey("totalEpoch")) {
if (params.get("totalEpoch") != null) {
@@ -66,15 +61,12 @@ public class TrainJobWorker {
}
}
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
// 실행 시작 처리
modelTrainJobCoreService.markRunning(
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
log.info("[JOB] markRunning done jobId={}", jobId);
try {
TrainRunResult result;
if (isEval) {
// step2 진행중 처리
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
String uuid = String.valueOf(params.get("uuid"));
int epoch = (int) params.get("epoch");
@@ -87,15 +79,12 @@ public class TrainJobWorker {
evalReq.setTimeoutSeconds(null);
evalReq.setDatasetFolder(datasetFolder);
evalReq.setOutputFolder(outputFolder);
log.info("[JOB] selected test epoch={}", epoch);
// 도커 실행 후 로그 수집
result = dockerTrainService.runEvalSync(containerName, evalReq);
result = dockerTrainService.runEvalSync(evalReq, containerName);
} else {
// step1 진행중 처리
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
TrainRunRequest trainReq = toTrainRunRequest(params);
// 도커 실행 후 로그 수집
result = dockerTrainService.runTrainSync(trainReq, containerName);
}
@@ -109,31 +98,24 @@ public class TrainJobWorker {
}
if (result.getExitCode() == 0) {
// 성공 처리
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
if (isEval) {
// step2 완료처리
modelTrainMngCoreService.markStep2Success(modelId);
// 결과 csv 파일 정보 등록
modelTestMetricsJobService.findTestValidMetricCsvFiles();
} else {
modelTrainMngCoreService.markStep1Success(modelId);
// 결과 csv 파일 정보 등록
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
}
} else {
String failMsg = result.getStatus() + "\n" + result.getLogs();
// 실패 처리
modelTrainJobCoreService.markFailed(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
if (isEval) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
} else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
}
}
@@ -142,10 +124,8 @@ public class TrainJobWorker {
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
if ("EVAL".equals(params.get("jobType"))) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
} else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, e.getMessage());
}
}