76 Commits

Author SHA1 Message Date
0c34ea7dcb hyperparam_with_modeltype 2026-02-12 18:48:14 +09:00
3547c28361 Merge pull request 'feat/training_260202' (#55) from feat/training_260202 into develop
Reviewed-on: #55
2026-02-12 16:56:23 +09:00
6c70bfed18 Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202
# Conflicts:
#	src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java
2026-02-12 16:55:52 +09:00
95a75e63f4 Add temporary folder creation API 2026-02-12 16:55:10 +09:00
2a1dbee290 Merge pull request 'Count API to check whether any model training step 1 is running' (#54) from feat/training_260202 into develop
Reviewed-on: #54
2026-02-12 16:51:09 +09:00
384a321bf3 Count API to check whether any model training step 1 is running 2026-02-12 16:50:40 +09:00
f4e97d389b Merge pull request 'Fix file check API' (#53) from feat/training_260202 into develop
Reviewed-on: #53
2026-02-12 16:42:20 +09:00
590810ff0a Fix file check API 2026-02-12 16:41:40 +09:00
a01c872982 Merge pull request 'feat/training_260202' (#52) from feat/training_260202 into develop
Reviewed-on: #52
2026-02-12 16:15:11 +09:00
905a245070 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 16:14:45 +09:00
860ce35a8f Add docker mount path 2026-02-12 16:14:19 +09:00
7f3f5dca40 Merge pull request 'feat/training_260202' (#51) from feat/training_260202 into develop
Reviewed-on: #51
2026-02-12 16:13:19 +09:00
4a0a4e35ed Fix training execution 2026-02-12 16:12:58 +09:00
ae055dca1e Fix model registration 2026-02-12 16:01:14 +09:00
26e8e1492f Merge pull request 'feat/training_260202' (#50) from feat/training_260202 into develop
Reviewed-on: #50
2026-02-12 15:52:09 +09:00
8fa722011c Fix model registration 2026-02-12 15:51:54 +09:00
17d47d6200 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 15:47:10 +09:00
e178f58fe2 Add chunk save log 2026-02-12 15:47:06 +09:00
cd0cf5726d Merge pull request 'feat/training_260202' (#49) from feat/training_260202 into develop
Reviewed-on: #49
2026-02-12 15:44:11 +09:00
8e4bea53da Fix model registration 2026-02-12 15:43:52 +09:00
7a22d8ba73 Change containerName generation 2026-02-12 15:39:12 +09:00
2df4a7a80b Read the CSV file path and fix it; save train with epoch + 1 2026-02-12 15:24:30 +09:00
b451f697bc Add request and response paths to the model master table 2026-02-12 14:59:35 +09:00
7e9c867f34 Merge pull request 'Update step1State to READY when registering a model' (#48) from feat/training_260202 into develop
Reviewed-on: #48
2026-02-12 14:35:52 +09:00
130e85f8a1 Update step1State to READY when registering a model 2026-02-12 14:35:17 +09:00
9e713cb49d Merge pull request 'Rework upload logic again' (#47) from feat/training_260202 into develop
Reviewed-on: #47
2026-02-12 14:21:57 +09:00
51dfa97900 Rework upload logic again 2026-02-12 14:21:08 +09:00
87c6b599b4 Merge pull request 'feat/training_260202' (#46) from feat/training_260202 into develop
Reviewed-on: #46
2026-02-12 12:10:04 +09:00
f50855a822 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 12:08:04 +09:00
8d416317a8 Best epoch API; update best epoch when running step 2 2026-02-12 12:07:44 +09:00
22aa071476 Merge pull request 'feat/training_260202' (#45) from feat/training_260202 into develop
Reviewed-on: #45
2026-02-12 12:06:04 +09:00
a83bd09f8f Change containerName generation 2026-02-12 12:05:30 +09:00
96035f864a Change containerName generation 2026-02-12 11:42:38 +09:00
fd7dfd7e7f Change containerName generation 2026-02-12 11:10:28 +09:00
190b93bee8 Fix execution error 2026-02-12 10:58:51 +09:00
c5f19cc961 Fix execution error 2026-02-12 10:58:32 +09:00
c56c0ca605 Fix execution error 2026-02-12 10:58:26 +09:00
c6e721aa37 Fix execution error 2026-02-12 10:58:12 +09:00
6572e17f00 Fix execution error 2026-02-12 10:51:15 +09:00
be6365807c Merge pull request 'Fix execution error' (#43) from feat/training_260202 into develop
Reviewed-on: #43
2026-02-12 10:20:05 +09:00
d2fff7dfde Fix execution error 2026-02-12 10:19:44 +09:00
f66bc22c95 Merge pull request 'Fix execution error' (#42) from feat/training_260202 into develop
Reviewed-on: #42
2026-02-12 10:14:54 +09:00
3367d0e7be Fix execution error 2026-02-12 10:14:32 +09:00
352ec6ccb0 Merge pull request 'feat/training_260202' (#41) from feat/training_260202 into develop
Reviewed-on: #41
2026-02-12 09:53:02 +09:00
6a989255a3 Dataset list by model: add dataTypeName for G2, G3 2026-02-12 09:52:24 +09:00
878b21573f Add test execution 2026-02-11 22:00:35 +09:00
0602db1436 Merge pull request 'Add test execution' (#40) from feat/training_260202 into develop
Reviewed-on: #40
2026-02-11 21:58:58 +09:00
2f8bd1f98c Add test execution 2026-02-11 21:58:25 +09:00
75231ccbba Merge pull request 'Add inference execution' (#39) from feat/training_260202 into develop
Reviewed-on: #39
2026-02-11 20:22:01 +09:00
1249a80da5 Add inference execution 2026-02-11 20:21:25 +09:00
00c78eb42f Merge pull request 'Add performance info graph data API' (#38) from feat/training_260202 into develop
Reviewed-on: #38
2026-02-11 19:52:23 +09:00
35767adba1 Add performance info graph data API 2026-02-11 19:52:00 +09:00
47a2a159ef Merge pull request 'Add test metrics schedule' (#37) from feat/training_260202 into develop
Reviewed-on: #37
2026-02-11 19:10:37 +09:00
95548223cd Add test metrics schedule 2026-02-11 19:09:58 +09:00
2debdc5312 Merge pull request 'feat/training_260202' (#36) from feat/training_260202 into develop
Reviewed-on: #36
2026-02-11 18:51:01 +09:00
207cc47f1b Schedule comment 2026-02-11 18:50:43 +09:00
b6338bce8e Change table structure 2026-02-11 18:49:59 +09:00
2cfa2adcf5 Add tb_model_master column 2026-02-11 17:21:48 +09:00
d7e19abfc9 Fix uploadRate logic 2026-02-11 17:06:02 +09:00
c843703ee7 Merge pull request 'Add file fetch by calling 86' (#35) from feat/training_260202 into develop
Reviewed-on: #35
2026-02-11 16:53:25 +09:00
133ea6b1ba Add file fetch by calling 86 2026-02-11 16:49:48 +09:00
0df977ae81 Merge pull request 'Change upload logic to run on 86' (#34) from feat/training_260202 into develop
Reviewed-on: #34
2026-02-11 16:33:03 +09:00
3e39006822 Change upload logic to run on 86 2026-02-11 16:32:40 +09:00
3ec1a71406 Merge pull request 'Fix upload logic' (#33) from feat/training_260202 into develop
Reviewed-on: #33
2026-02-11 15:53:21 +09:00
16009f1623 Fix upload logic 2026-02-11 15:52:57 +09:00
41911014c9 Merge pull request 'Fix upload logic' (#32) from feat/training_260202 into develop
Reviewed-on: #32
2026-02-11 15:44:54 +09:00
8ea32ce675 Fix upload logic 2026-02-11 15:44:18 +09:00
a4ac80c787 Merge pull request 'Fix upload path' (#31) from feat/training_260202 into develop
Reviewed-on: #31
2026-02-11 15:11:02 +09:00
3a5d136d34 Fix upload path 2026-02-11 15:10:37 +09:00
2f63b9ddcd Merge pull request 'feat/training_260202' (#30) from feat/training_260202 into develop
Reviewed-on: #30
2026-02-11 14:08:58 +09:00
92de48b55e Fix transfer learning detail logic 2026-02-11 14:08:21 +09:00
224ddae68b Fix transfer learning detail 2026-02-11 14:05:15 +09:00
885b72a0c6 Merge pull request 'Fix dataset list query by model' (#29) from feat/training_260202 into develop
Reviewed-on: #29
2026-02-11 12:29:08 +09:00
9ac00d37c5 Fix dataset list query by model 2026-02-11 12:28:38 +09:00
fbb5a34867 Merge pull request 'Revert upload path' (#28) from feat/training_260202 into develop
Reviewed-on: #28
2026-02-11 12:12:43 +09:00
e25fc01b25 Revert upload path 2026-02-11 12:12:08 +09:00
73 changed files with 3822 additions and 180 deletions

@@ -3,6 +3,7 @@ plugins {
id 'org.springframework.boot' version '3.5.7'
id 'io.spring.dependency-management' version '1.1.7'
id 'com.diffplug.spotless' version '6.25.0'
id 'idea'
}
group = 'com.kamco.cd'
@@ -21,11 +22,23 @@ configurations {
}
}
// Define the directory for QueryDSL-generated sources
def generatedSourcesDir = file("$buildDir/generated/sources/annotationProcessor/java/main")
repositories {
mavenCentral()
maven { url "https://repo.osgeo.org/repository/release/" }
}
// Configure Gradle to include the generated sources in the compile path
sourceSets {
main {
java {
srcDirs += generatedSourcesDir
}
}
}
dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
implementation 'org.springframework.boot:spring-boot-starter-web'
@@ -83,6 +96,23 @@ dependencies {
implementation 'io.hypersistence:hypersistence-utils-hibernate-63:3.7.0'
implementation 'org.reflections:reflections:0.10.2'
implementation 'com.jcraft:jsch:0.1.55'
implementation 'org.apache.commons:commons-csv:1.10.0'
}
// Configure IntelliJ to recognize the generated sources
idea {
module {
// Recognize as a source directory
sourceDirs += generatedSourcesDir
// Mark as Generated Sources Root (handled specially by IntelliJ)
generatedSourceDirs += generatedSourcesDir
// Download sources and Javadoc
downloadJavadoc = true
downloadSources = true
}
}
configurations.configureEach {
@@ -93,6 +123,21 @@ tasks.named('test') {
useJUnitPlatform()
}
// Ensure the generated sources directory exists before compilation
tasks.named('compileJava') {
doFirst {
generatedSourcesDir.mkdirs()
}
}
// Task to clean up generated sources
tasks.register('cleanGeneratedSources', Delete) {
delete generatedSourcesDir
}
tasks.named('clean') {
dependsOn 'cleanGeneratedSources'
}
bootJar {
archiveFileName = 'ROOT.jar'

@@ -14,6 +14,7 @@ services:
- /mnt/nfs_share/images:/app/original-images
- /mnt/nfs_share/model_output:/app/model-outputs
- /mnt/nfs_share/train_dataset:/app/train-dataset
- /home/kcomu/data:/home/kcomu/data
networks:
- kamco-cds
restart: unless-stopped

@@ -2,8 +2,10 @@ package com.kamco.cd.training;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
@EnableAsync
@SpringBootApplication
@EnableScheduling
public class KamcoTrainingApplication {

@@ -1,5 +1,6 @@
package com.kamco.cd.training.common.dto;
import com.kamco.cd.training.common.enums.ModelType;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
@@ -14,6 +15,10 @@ public class HyperParam {
// -------------------------
// Important
// -------------------------
@Schema(description = "모델", example = "G1")
private ModelType model; // model type (G1, G2, G3)
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone

@@ -2,6 +2,7 @@ package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import java.util.Arrays;
import lombok.AllArgsConstructor;
import lombok.Getter;
@@ -15,6 +16,10 @@ public enum ModelType implements EnumType {
private String desc;
public static ModelType getValueData(String modelNo) {
return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1);
}
@Override
public String getId() {
return name();
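
The lookup added in this hunk resolves a model number to an enum constant and falls back to G1 when nothing matches. A minimal, self-contained sketch of that behavior follows. The nested `ModelType` here is a hypothetical stand-in for the project's enum (only the constant names G1, G2, G3 and the `getId()`/`getValueData` shapes are taken from the diff), so treat it as an illustration rather than the actual class.

```java
import java.util.Arrays;

public class ModelTypeLookup {

  // Hypothetical stand-in for the project's ModelType enum.
  enum ModelType {
    G1,
    G2,
    G3;

    String getId() {
      return name();
    }

    // Same shape as the diff: an unknown or null id falls back to G1.
    static ModelType getValueData(String modelNo) {
      return Arrays.stream(values())
          .filter(m -> m.getId().equals(modelNo))
          .findFirst()
          .orElse(G1);
    }
  }

  public static void main(String[] args) {
    System.out.println(ModelType.getValueData("G2")); // matching id
    System.out.println(ModelType.getValueData("unknown")); // no match, falls back to G1
    System.out.println(ModelType.getValueData(null)); // null is handled by equals, falls back to G1
  }
}
```

One design point worth noting: a silent fallback to G1 hides typos in the model number, so callers that need to detect bad input may prefer an `Optional` return or an exception instead.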

@@ -2,6 +2,10 @@ package com.kamco.cd.training.common.utils;
import static java.lang.String.CASE_INSENSITIVE_ORDER;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import io.swagger.v3.oas.annotations.media.Schema;
import java.io.BufferedReader;
import java.io.File;
@@ -23,6 +27,7 @@ import java.util.Arrays;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@@ -501,11 +506,15 @@ public class FIleChecker {
try {
File dir = new File(targetPath);
log.info("targetPath={}", targetPath);
log.info("absolute targetPath={}", dir.getAbsolutePath());
if (!dir.exists()) {
dir.mkdirs();
}
File dest = new File(dir, String.valueOf(chunkIndex));
log.info("real save path = {}", dest.getAbsolutePath());
log.info("chunkIndex={}, uploadSize={}", chunkIndex, mfile.getSize());
log.info("savedSize={}", dest.length());
@@ -517,6 +526,9 @@ public class FIleChecker {
log.info("after delete={}", dest.length());
mfile.transferTo(dest);
log.info("after transfer size={}", dest.length());
log.info("after transfer exists={}", dest.exists());
return true;
} catch (IOException e) {
log.error("chunk save error", e);
@@ -702,12 +714,17 @@ public class FIleChecker {
}
public static void unzip(String fileName, String destDirectory) throws IOException {
File destDir = new File(destDirectory);
if (!destDir.exists()) {
destDir.mkdirs(); // create the target folder if it does not exist
}
String zipFilePath = destDirectory + File.separator + fileName;
String zipFilePath = destDirectory + "/" + fileName;
// Create a folder named after the zip (extension removed)
String folderName =
fileName.endsWith(".zip") ? fileName.substring(0, fileName.length() - 4) : fileName;
File destDir = new File(destDirectory, folderName);
if (!destDir.exists()) {
destDir.mkdirs();
}
try (ZipInputStream zis = new ZipInputStream(new FileInputStream(zipFilePath))) {
ZipEntry zipEntry = zis.getNextEntry();
@@ -755,4 +772,138 @@ public class FIleChecker {
return destFile;
}
public static void uploadTo86(Path localFile) {
String host = "192.168.2.86";
int port = 22;
String username = "kcomu";
String password = "Kamco2025!";
String remoteDir = "/home/kcomu/data/request";
Session session = null;
ChannelSftp channel = null;
try {
JSch jsch = new JSch();
session = jsch.getSession(username, host, port);
session.setPassword(password);
Properties config = new Properties();
config.put("StrictHostKeyChecking", "no");
session.setConfig(config);
session.connect(10_000);
channel = (ChannelSftp) session.openChannel("sftp");
channel.connect(10_000);
// Change to the destination directory
channel.cd(remoteDir);
// Upload
channel.put(localFile.toString(), localFile.getFileName().toString());
} catch (Exception e) {
throw new RuntimeException("SFTP upload failed", e);
} finally {
if (channel != null) channel.disconnect();
if (session != null) session.disconnect();
}
}
public static void unzipOn86Server(String zipPath, String targetDir) {
String host = "192.168.2.86";
String user = "kcomu";
String password = "Kamco2025!";
Session session = null;
ChannelExec channel = null;
try {
JSch jsch = new JSch();
session = jsch.getSession(user, host, 22);
session.setPassword(password);
Properties config = new Properties();
config.put("StrictHostKeyChecking", "no");
session.setConfig(config);
session.connect(10_000);
String command = "unzip -o " + zipPath + " -d " + targetDir;
channel = (ChannelExec) session.openChannel("exec");
channel.setCommand(command);
channel.setErrStream(System.err);
InputStream in = channel.getInputStream();
channel.connect();
// Read the output (optional)
try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
while (br.readLine() != null) {
// Log here if needed
}
}
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
if (channel != null) channel.disconnect();
if (session != null) session.disconnect();
}
}
public static List<String> execCommandAndReadLines(String command) {
List<String> result = new ArrayList<>();
String host = "192.168.2.86";
String user = "kcomu";
String password = "Kamco2025!";
Session session = null;
ChannelExec channel = null;
try {
JSch jsch = new JSch();
session = jsch.getSession(user, host, 22);
session.setPassword(password);
Properties config = new Properties();
config.put("StrictHostKeyChecking", "no");
session.setConfig(config);
session.connect(10_000);
channel = (ChannelExec) session.openChannel("exec");
channel.setCommand(command);
channel.setInputStream(null);
InputStream in = channel.getInputStream();
channel.connect();
try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
String line;
while ((line = br.readLine()) != null) {
result.add(line);
}
}
return result;
} catch (Exception e) {
throw new RuntimeException("remote command failed : " + command, e);
} finally {
if (channel != null) channel.disconnect();
if (session != null) session.disconnect();
}
}
}
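
In this diff, unzipOn86Server builds its remote command by plain string concatenation, while DatasetService wraps paths with a single-quote `escape` helper before passing them to `find` and `cat`. The sketch below applies the same quoting to the `unzip` command so that paths containing spaces or shell metacharacters stay a single argument. The class and method names (`RemoteCommands`, `shellQuote`, `unzipCommand`) are hypothetical and not part of the PR; only the quoting rule mirrors the `escape` method shown in the diff.

```java
public class RemoteCommands {

  // Wrap a value in single quotes for a POSIX shell. An embedded single quote is
  // emitted as '"'"' (close the quote, add a double-quoted ', reopen the quote).
  static String shellQuote(String value) {
    return "'" + value.replace("'", "'\"'\"'") + "'";
  }

  // Build the unzip command with both paths quoted.
  static String unzipCommand(String zipPath, String targetDir) {
    return "unzip -o " + shellQuote(zipPath) + " -d " + shellQuote(targetDir);
  }

  public static void main(String[] args) {
    // prints: unzip -o '/home/kcomu/data/request/my set.zip' -d '/home/kcomu/data/request/my set'
    System.out.println(
        unzipCommand("/home/kcomu/data/request/my set.zip", "/home/kcomu/data/request/my set"));
  }
}
```

The three SSH helpers also repeat the host, user, and password as literals. Loading them from external configuration would keep the secret out of version control, but how this project wires configuration is not visible in the diff.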

@@ -14,15 +14,11 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.Resource;
import org.springframework.core.io.UrlResource;
import org.springframework.data.domain.Page;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
@@ -230,10 +226,15 @@ public class DatasetApiController {
throws Exception {
String path = datasetService.getFilePathByUUIDPathType(uuid, pathType);
Path filePath = Paths.get(path);
return datasetService.getFilePathByFile(path);
}
Resource resource = new UrlResource(filePath.toUri());
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")
@GetMapping("/files-to86")
public ResponseEntity<Resource> getFileTo86(
@RequestParam UUID uuid, @RequestParam String pathType) throws Exception {
return ResponseEntity.ok().contentType(MediaType.APPLICATION_OCTET_STREAM).body(resource);
String path = datasetService.getFilePathByUUIDPathType(uuid, pathType);
return datasetService.getFilePathByFile(path);
}
}

@@ -1,8 +1,10 @@
package com.kamco.cd.training.dataset.dto;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.kamco.cd.training.common.enums.LearnDataRegister;
import com.kamco.cd.training.common.enums.LearnDataType;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.utils.enums.Enums;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import io.swagger.v3.oas.annotations.media.Schema;
@@ -228,6 +230,7 @@ public class DatasetDto {
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class SelectDataSet {
private String modelNo; // model type: G1, G2, G3
private Long datasetId;
private UUID uuid;
private String dataType;
@@ -236,12 +239,16 @@ public class DatasetDto {
private Integer compareYyyy;
private Integer targetYyyy;
private String memo;
private Long classCount;
private Integer buildingCount;
private Integer containerCount;
@JsonIgnore private Long classCount;
private Integer buildingCnt;
private Integer containerCnt;
private String dataTypeName;
private Long wasteCnt;
private Long landCoverCnt;
public SelectDataSet(
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
@@ -254,15 +261,22 @@ public class DatasetDto {
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.classCount = classCount;
if (ModelType.G2.getId().equals(modelNo)) {
this.wasteCnt = classCount;
} else if (ModelType.G3.getId().equals(modelNo)) {
this.landCoverCnt = classCount;
}
}
public SelectDataSet(
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
@@ -271,8 +285,8 @@ public class DatasetDto {
Integer compareYyyy,
Integer targetYyyy,
String memo,
Integer buildingCount,
Integer containerCount) {
Integer buildingCnt,
Integer containerCnt) {
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
@@ -282,8 +296,8 @@ public class DatasetDto {
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.buildingCount = buildingCount;
this.containerCount = containerCount;
this.buildingCnt = buildingCnt;
this.containerCnt = containerCnt;
}
public String getDataTypeName(String groupTitleCd) {

@@ -21,6 +21,7 @@ import com.kamco.cd.training.dataset.dto.DatasetObjDto.SearchReq;
import com.kamco.cd.training.postgres.core.DatasetCoreService;
import jakarta.validation.Valid;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
@@ -33,8 +34,13 @@ import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -158,6 +164,44 @@ public class DatasetService {
}
}
@Deprecated
@Transactional
public ResponseObj insertDatasetTo86(@Valid AddReq addReq) {
Long datasetUid = null; // master id, obtained during registration
// Unzip
FIleChecker.unzipOn86Server(
addReq.getFilePath() + addReq.getFileName(),
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
// Read the extracted folder and save the data
List<Map<String, Object>> list =
getUnzipDatasetFilesTo86(
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "train");
int idx = 0;
for (Map<String, Object> map : list) {
datasetUid =
this.insertTrainTestData(map, addReq, idx, datasetUid, "train"); // insert train data
idx++;
}
List<Map<String, Object>> testList =
getUnzipDatasetFilesTo86(
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "test");
int testIdx = 0;
for (Map<String, Object> test : testList) {
datasetUid =
this.insertTrainTestData(test, addReq, testIdx, datasetUid, "test"); // insert test data
testIdx++;
}
datasetCoreService.updateDatasetUploadStatus(datasetUid);
return new ResponseObj(ApiResponseCode.OK, "업로드 성공하였습니다.");
}
@Transactional
public ResponseObj insertDataset(@Valid AddReq addReq) {
@@ -179,6 +223,17 @@ public class DatasetService {
idx++;
}
List<Map<String, Object>> valList =
getUnzipDatasetFiles(
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "val");
int valIdx = 0;
for (Map<String, Object> valid : valList) {
datasetUid =
this.insertTrainTestData(valid, addReq, valIdx, datasetUid, "val"); // insert val data
valIdx++;
}
List<Map<String, Object>> testList =
getUnzipDatasetFiles(
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "test");
@@ -285,6 +340,8 @@ public class DatasetService {
if (subDir.equals("train")) {
datasetCoreService.insertDatasetObj(objRegDto);
} else if (subDir.equals("val")) {
datasetCoreService.insertDatasetValObj(objRegDto);
} else {
datasetCoreService.insertDatasetTestObj(objRegDto);
}
@@ -356,4 +413,112 @@ public class DatasetService {
public String getFilePathByUUIDPathType(UUID uuid, String pathType) {
return datasetCoreService.getFilePathByUUIDPathType(uuid, pathType);
}
@Deprecated
private List<Map<String, Object>> getUnzipDatasetFilesTo86(String unzipRootPath, String subDir) {
// String root = Paths.get(unzipRootPath)
// .resolve(subDir)
// .toString();
//
String root = normalizeLinuxPath(unzipRootPath + "/" + subDir);
Map<String, Map<String, Object>> grouped = new HashMap<>();
for (String dirName : LABEL_DIRS) {
String remoteDir = root + "/" + dirName;
// 1. List the files in this directory on the 86 server
List<String> files = listFilesOn86Server(remoteDir);
if (files.isEmpty()) {
throw new IllegalStateException("폴더가 존재하지 않거나 파일이 없습니다 : " + remoteDir);
}
for (String fullPath : files) {
String fileName = Paths.get(fullPath).getFileName().toString();
String baseName = getBaseName(fileName);
Map<String, Object> data = grouped.computeIfAbsent(baseName, k -> new HashMap<>());
data.put("baseName", baseName);
if ("label-json".equals(dirName)) {
// 2. The JSON content must also be read from the 86 server
String json = readRemoteFileAsString(fullPath);
data.put("label-json", parseJson(json));
data.put("geojson_path", fullPath);
} else {
data.put(dirName, fullPath);
}
}
}
return new ArrayList<>(grouped.values());
}
private List<String> listFilesOn86Server(String remoteDir) {
String command = "find " + escape(remoteDir) + " -maxdepth 1 -type f";
return FIleChecker.execCommandAndReadLines(command);
}
private String readRemoteFileAsString(String remoteFilePath) {
String command = "cat " + escape(remoteFilePath);
List<String> lines = FIleChecker.execCommandAndReadLines(command);
return String.join("\n", lines);
}
private JsonNode parseJson(String json) {
try {
ObjectMapper mapper = new ObjectMapper();
return mapper.readTree(json);
} catch (IOException e) {
throw new RuntimeException("JSON 파싱 실패", e);
}
}
private String escape(String path) {
return "'" + path.replace("'", "'\"'\"'") + "'";
}
private static String normalizeLinuxPath(String path) {
return path.replace("\\", "/");
}
public ResponseEntity<Resource> getFilePathByFile(String remoteFilePath) {
try {
Path path = Paths.get(remoteFilePath);
InputStream inputStream = Files.newInputStream(path);
InputStreamResource resource =
new InputStreamResource(inputStream) {
@Override
public long contentLength() {
return -1; // -1 when the length is unknown
}
};
String fileName = Paths.get(remoteFilePath.replace("\\", "/")).getFileName().toString();
return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.body(resource);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
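
getUnzipDatasetFilesTo86 merges the file listings of several label directories into one record per base name using `computeIfAbsent`. The sketch below reproduces that grouping step on in-memory listings so it can run without the 86 server. Several parts are assumptions: `getBaseName` is not shown in the diff, so `baseName` here simply drops the last extension; the directory name `images` and all sample paths are made up for illustration; and the special handling of `label-json` (parsing the JSON and storing `geojson_path`) is left out.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class DatasetFileGrouping {

  // Assumption: stands in for getBaseName, which the diff does not show.
  static String baseName(String fileName) {
    int dot = fileName.lastIndexOf('.');
    return dot > 0 ? fileName.substring(0, dot) : fileName;
  }

  // Take the last path segment, accepting both / and \ separators.
  static String fileNameOf(String fullPath) {
    String normalized = fullPath.replace("\\", "/");
    return normalized.substring(normalized.lastIndexOf('/') + 1);
  }

  // Merge per-directory file listings into one record per base name.
  static List<Map<String, Object>> group(Map<String, List<String>> filesByDir) {
    Map<String, Map<String, Object>> grouped = new LinkedHashMap<>();
    for (Map.Entry<String, List<String>> entry : filesByDir.entrySet()) {
      for (String fullPath : entry.getValue()) {
        String base = baseName(fileNameOf(fullPath));
        Map<String, Object> data = grouped.computeIfAbsent(base, k -> new LinkedHashMap<>());
        data.put("baseName", base);
        data.put(entry.getKey(), fullPath); // directory name -> path of the matching file
      }
    }
    return new ArrayList<>(grouped.values());
  }

  public static void main(String[] args) {
    Map<String, List<String>> listing = new LinkedHashMap<>();
    listing.put("images", List.of("/data/train/images/a.tif", "/data/train/images/b.tif"));
    listing.put(
        "label-json",
        List.of("/data/train/label-json/a.geojson", "/data/train/label-json/b.geojson"));
    for (Map<String, Object> record : group(listing)) {
      System.out.println(record);
    }
  }
}
```

The sketch uses `LinkedHashMap` so records come out in first-seen order. The diff uses `HashMap`, which does not guarantee any ordering of the returned list.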

@@ -1,6 +1,7 @@
package com.kamco.cd.training.hyperparam;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
@@ -65,7 +66,7 @@ public class HyperParamApiController {
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "422", description = "HPs_0001 수정 불가", content = @Content),
@ApiResponse(responseCode = "422", description = "default는 삭제불가", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PutMapping("/{uuid}")
@@ -87,8 +88,9 @@ public class HyperParamApiController {
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/list")
@GetMapping("/{model}/list")
public ApiResponseDto<Page<List>> getHyperParam(
@PathVariable ModelType model,
@Parameter(
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
example = "CREATE_DATE")
@@ -98,7 +100,7 @@ public class HyperParamApiController {
LocalDate startDate,
@Parameter(description = "종료일", example = "2026-02-28") @RequestParam(required = false)
LocalDate endDate,
@Parameter(description = "버전명", example = "HPs_0001") @RequestParam(required = false)
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
String hyperVer,
@Parameter(
description = "정렬",
@@ -124,7 +126,7 @@ public class HyperParamApiController {
searchReq.setSort(sort);
searchReq.setPage(page);
searchReq.setSize(size);
Page<List> list = hyperParamService.getHyperParamList(searchReq);
Page<List> list = hyperParamService.getHyperParamList(model, searchReq);
return ApiResponseDto.ok(list);
}
@@ -133,7 +135,7 @@ public class HyperParamApiController {
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "422", description = "HPs_0001 삭제 불가", content = @Content),
@ApiResponse(responseCode = "422", description = "default 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
})
@DeleteMapping("/{uuid}")
@@ -179,8 +181,11 @@ public class HyperParamApiController {
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/init")
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam() {
return ApiResponseDto.ok(hyperParamService.getInitHyperParam());
@GetMapping("/init/{model}")
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(
@PathVariable ModelType model
) {
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
}
}

@@ -1,5 +1,6 @@
package com.kamco.cd.training.hyperparam.dto;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
@@ -24,6 +25,7 @@ public class HyperParamDto {
@AllArgsConstructor
public static class Basic {
private ModelType model; // 20250212: added model type
private UUID uuid;
private String hyperVer;
@JsonFormatDttm private ZonedDateTime createdDttm;
@@ -98,6 +100,8 @@ public class HyperParamDto {
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
private Boolean isDefault;
}
@Getter

@@ -1,8 +1,10 @@
package com.kamco.cd.training.hyperparam.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@@ -20,11 +22,12 @@ public class HyperParamService {
/**
* Retrieve the hyperparameter list
*
* @param model
* @param req
* @return the list
*/
public Page<List> getHyperParamList(HyperParamDto.SearchReq req) {
return hyperParamCoreService.findByHyperVerList(req);
public Page<List> getHyperParamList(ModelType model, SearchReq req) {
return hyperParamCoreService.findByHyperVerList(model, req);
}
/**
@@ -59,8 +62,8 @@ public class HyperParamService {
}
/** Retrieve the optimized hyperparameter settings */
public HyperParamDto.Basic getInitHyperParam() {
return hyperParamCoreService.getInitHyperParam();
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
return hyperParamCoreService.getInitHyperParam(model);
}
/**

@@ -3,6 +3,10 @@ package com.kamco.cd.training.model;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.service.ModelTrainDetailService;
@@ -132,4 +136,90 @@ public class ModelTrainDetailApiController {
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
}
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/metrics/train/{uuid}")
public ApiResponseDto<List<ModelTrainMetrics>> getModelTrainMetricResult(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainMetricResult(uuid));
}
@Operation(
summary = "모델관리 > 모델 상세 > 성능 정보 (Validation)",
description = "모델 상세 > 성능 정보 (Validation) API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/metrics/validation/{uuid}")
public ApiResponseDto<List<ModelValidationMetrics>> getModelValidationMetricResult(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelValidationMetricResult(uuid));
}
@Operation(summary = "Model management > Model detail > Performance metrics (Test)", description = "Model detail > Performance metrics (Test) API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Query succeeded",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelTestMetrics.class))),
@ApiResponse(responseCode = "404", description = "Model not found", content = @Content),
@ApiResponse(responseCode = "500", description = "Server error", content = @Content)
})
@GetMapping("/metrics/test/{uuid}")
public ApiResponseDto<List<ModelTestMetrics>> getModelTestMetricResult(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTestMetricResult(uuid));
}
@Operation(summary = "Model management > Model detail > Performance metrics (Best Epoch)", description = "Model detail > Performance metrics (Best Epoch) API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Query succeeded",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelBestEpoch.class))),
@ApiResponse(responseCode = "404", description = "Model not found", content = @Content),
@ApiResponse(responseCode = "500", description = "Server error", content = @Content)
})
@GetMapping("/best-epoch/{uuid}")
public ApiResponseDto<ModelBestEpoch> getModelTrainBestEpoch(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainBestEpoch(uuid));
}
}


@@ -74,13 +74,12 @@ public class ModelTrainMngApiController {
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "409", description = "HPs_0001 삭제 불가", content = @Content)
@ApiResponse(responseCode = "409", description = "G1_000001 삭제 불가", content = @Content)
})
@DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteModelTrain(
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
@PathVariable
UUID uuid) {
@PathVariable UUID uuid) {
modelTrainMngService.deleteModelTrain(uuid);
return ApiResponseDto.ok(null);
}
@@ -92,9 +91,8 @@ public class ModelTrainMngApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping
public ApiResponseDto<String> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
modelTrainMngService.createModelTrain(req);
return ApiResponseDto.ok("ok");
public ApiResponseDto<UUID> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
return ApiResponseDto.ok(modelTrainMngService.createModelTrain(req));
}
@Operation(summary = "모델학습 config 정보 조회", description = "모델학습 config 정보 조회 API")
@@ -150,4 +148,22 @@ public class ModelTrainMngApiController {
req.setDataType(selectType);
return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req));
}
@Operation(summary = "Count of model trainings with step 1 in progress", description = "Returns the number of model trainings whose step 1 is currently in progress")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Query succeeded",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Long.class))),
@ApiResponse(responseCode = "400", description = "Invalid search condition", content = @Content),
@ApiResponse(responseCode = "500", description = "Server error", content = @Content)
})
@GetMapping("/ing-training-cnt")
public ApiResponseDto<Long> findModelStep1InProgressCnt() {
return ApiResponseDto.ok(modelTrainMngService.findModelStep1InProgressCnt());
}
}


@@ -93,6 +93,29 @@ public class ModelTrainDetailDto {
private Integer batchSize;
}
@Schema(name = "Model training management transfer hyperparameters", description = "Hyperparameters of the transfer-learning model and of its base model")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class TransferHyperSummary {
private UUID uuid;
private Long hyperParamId;
private String hyperVer;
private String backbone;
private String inputSize;
private String cropSize;
private Integer batchSize;
private UUID beforeUuid;
private Long beforeHyperParamId;
private String beforeHyperVer;
private String beforeBackbone;
private String beforeInputSize;
private String beforeCropSize;
private Integer beforeBatchSize;
}
@Schema(name = "선택한 데이터셋 목록", description = "선택한 데이터셋 목록")
@Getter
@Setter
@@ -154,7 +177,72 @@ public class ModelTrainDetailDto {
@AllArgsConstructor
public static class TransferDetailDto {
private ModelConfigDto.Basic etcConfig;
private HyperSummary modelTrainHyper;
private TransferHyperSummary modelTrainHyper;
private List<SelectDataSet> modelTrainDataset;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelTrainMetrics {
private Integer epoch;
private Long iteration;
private Double loss;
private Double lr;
private Float durationTime;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelValidationMetrics {
private Integer epoch;
private Float aAcc;
private Float mFscore;
private Float mPrecision;
private Float mRecall;
private Float mIou;
private Float mAcc;
private Float changedFscore;
private Float changedPrecision;
private Float changedRecall;
private Float unchangedFscore;
private Float unchangedPrecision;
private Float unchangedRecall;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelTestMetrics {
private String model;
private Long tp;
private Long fp;
private Long fn;
private Float precision;
private Float recall;
private Float f1Score;
private Float accuracy;
private Float iou;
private Long detectionCount;
private Long gtCount;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelBestEpoch {
private Integer epoch;
private Double loss;
private Float f1Score;
private Float precision;
private Float recall;
private Float iou;
private Float accuracy;
}
}
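
Review note: `ModelTestMetrics` stores `tp`, `fp`, and `fn` next to `precision`, `recall`, `f1Score`, and `iou`, but the diff does not show how the derived values are produced (they appear to be loaded from result files by a separate job). The sketch below assumes the standard confusion-count definitions; `ConfusionMetricsSketch` and its method names are hypothetical and not part of this change. `accuracy` is left out because the DTO carries no true-negative count.

```java
// Hypothetical helper, not part of the diff: standard metric definitions
// derived from true-positive, false-positive, and false-negative counts.
final class ConfusionMetricsSketch {

    // precision = TP / (TP + FP); 0 when nothing was detected.
    static double precision(long tp, long fp) {
        long denom = tp + fp;
        return denom == 0 ? 0.0 : (double) tp / denom;
    }

    // recall = TP / (TP + FN); 0 when there is no ground truth.
    static double recall(long tp, long fn) {
        long denom = tp + fn;
        return denom == 0 ? 0.0 : (double) tp / denom;
    }

    // F1 = 2TP / (2TP + FP + FN), the harmonic mean of precision and recall.
    static double f1(long tp, long fp, long fn) {
        long denom = 2 * tp + fp + fn;
        return denom == 0 ? 0.0 : (2.0 * tp) / denom;
    }

    // IoU = TP / (TP + FP + FN).
    static double iou(long tp, long fp, long fn) {
        long denom = tp + fp + fn;
        return denom == 0 ? 0.0 : (double) tp / denom;
    }

    public static void main(String[] args) {
        long tp = 80, fp = 20, fn = 20;
        System.out.printf(
            "precision=%.3f recall=%.3f f1=%.3f iou=%.3f%n",
            precision(tp, fp), recall(tp, fn), f1(tp, fp, fn), iou(tp, fp, fn));
    }
}
```

A check like this could serve as a sanity test on ingested rows, for example flagging a record whose stored `f1Score` disagrees with the value recomputed from its counts.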


@@ -40,6 +40,7 @@ public class ModelTrainMngDto {
private String statusCd;
private String trainType;
private String modelNo;
private Long currentAttemptId;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
@@ -137,6 +138,9 @@ public class ModelTrainMngDto {
@Schema(description = "Training type: GENERAL (standard) or TRANSFER (transfer learning)", example = "GENERAL")
private String trainType;
@Schema(description = "ID of the base model selected for transfer learning")
private Long beforeModelId;
@NotNull
@Schema(
description = "하이퍼 파라미터 선택 타입 OPTIMIZED(최적화 파라미터),EXISTING(기존 파라미터),NEW(신규 파라미터)",
@@ -151,6 +155,17 @@ public class ModelTrainMngDto {
ModelConfig modelConfig;
}
@Schema(name = "updateReq", description = "Model training management update parameters")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class UpdateReq {
private String requestPath;
private String responsePath;
}
@Getter
@Setter
public static class TrainingDataset {


@@ -6,7 +6,12 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
@@ -55,6 +60,12 @@ public class ModelTrainDetailService {
return modelTrainDetailCoreService.findByModelByUUID(uuid);
}
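/**
* Returns the transfer-learning detail for a model.
*
* @param uuid model UUID
* @return model config, transfer hyperparameter summary, and selected datasets
*/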
public TransferDetailDto getTransferDetail(UUID uuid) {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
@@ -62,7 +73,7 @@ public class ModelTrainDetailService {
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
// Load the hyperparameter summary
HyperSummary hyperSummary = modelTrainDetailCoreService.getByModelHyperParamSummary(uuid);
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();
@@ -74,6 +85,7 @@ public class ModelTrainDetailService {
datasetIds.add(mappingDataset.getDatasetId());
}
datasetReq.setIds(datasetIds);
datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals("G1")) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
@@ -88,4 +100,20 @@ public class ModelTrainDetailService {
return transferDetailDto;
}
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
return modelTrainDetailCoreService.getModelTrainMetricResult(uuid);
}
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
return modelTrainDetailCoreService.getModelValidationMetricResult(uuid);
}
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
return modelTrainDetailCoreService.getModelTestMetricResult(uuid);
}
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
return modelTrainDetailCoreService.getModelTrainBestEpoch(uuid);
}
}


@@ -2,6 +2,8 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
@@ -10,11 +12,14 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -26,6 +31,7 @@ public class ModelTrainMngService {
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final HyperParamCoreService hyperParamCoreService;
private final TmpDatasetService tmpDatasetService;
/**
* 모델학습 조회
@@ -54,10 +60,17 @@ public class ModelTrainMngService {
* @return
*/
@Transactional
public void createModelTrain(ModelTrainMngDto.AddReq req) {
public UUID createModelTrain(ModelTrainMngDto.AddReq req) {
HyperParam hyperParam = req.getHyperParam();
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
// Transfer learning requires a base model to be selected
if (TrainType.TRANSFER.getId().equals(req.getTrainType()) && req.getBeforeModelId() == null) {
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "Please select a model.");
}
// Save new hyperparameters
if (HyperParamSelectType.NEW.getId().equals(req.getHyperParamType())) {
// Register the hyperparameters
@@ -66,7 +79,10 @@ public class ModelTrainMngService {
}
// Save the model training record
Long modelId = modelTrainMngCoreService.saveModel(req);
ModelTrainMngDto.Basic modelDto = modelTrainMngCoreService.saveModel(req);
Long modelId = modelDto.getId();
UUID modelUuid = modelDto.getUuid();
// Save the model training datasets
modelTrainMngCoreService.saveModelDataset(modelId, req);
@@ -77,6 +93,24 @@ public class ModelTrainMngService {
// Save the model config
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
// Name the temporary dataset directory with a random, dash-free, uppercase id
String tmpDirName = UUID.randomUUID().toString().toUpperCase().replace("-", "");
List<String> uids =
modelTrainMngCoreService.findDatasetUid(req.getTrainingDataset().getDatasetList());
try {
// Create the dataset symlinks and record the resulting path on the model
Path path = tmpDatasetService.buildTmpDatasetSymlink(tmpDirName, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(path.toString());
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
// Unchecked, so the surrounding @Transactional rolls back the inserts above
throw new java.io.UncheckedIOException("Failed to build tmp dataset symlinks", e);
}
return modelUuid;
}
/**
@@ -102,4 +136,8 @@ public class ModelTrainMngService {
return modelTrainMngCoreService.getDatasetSelectG2G3List(req);
}
}
public Long findModelStep1InProgressCnt() {
return modelTrainMngCoreService.findModelStep1InProgressCnt();
}
}


@@ -0,0 +1,66 @@
package com.kamco.cd.training.model.service;
import java.io.IOException;
import java.nio.file.*;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Slf4j
@Service
@RequiredArgsConstructor
public class TmpDatasetService {
@Value("${train.docker.requestDir}")
private String requestDir;
@Transactional(readOnly = true)
public Path buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
// Base directory is configured per environment via train.docker.requestDir
Path base = Paths.get(requestDir);
Path tmp = base.resolve("tmp").resolve(uid);
// Equivalent to: mkdir -p "$TMP"/{train,val}/{input1,input2,label}
for (String type : List.of("train", "val")) {
for (String part : List.of("input1", "input2", "label")) {
Files.createDirectories(tmp.resolve(type).resolve(part));
}
}
for (String id : datasetUids) {
Path srcRoot = base.resolve(id);
for (String type : List.of("train", "val")) {
for (String part : List.of("input1", "input2", "label")) {
Path srcDir = srcRoot.resolve(type).resolve(part);
// Like zsh NULL_GLOB: skip silently when the source folder does not exist
if (!Files.isDirectory(srcDir)) continue;
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
for (Path f : stream) {
if (!Files.isRegularFile(f)) continue;
String dstName = id + "__" + f.getFileName();
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
// Skip existing entries; NOFOLLOW_LINKS also catches dangling symlinks
if (Files.exists(dst, LinkOption.NOFOLLOW_LINKS)) continue;
// Same as: ln -s "$f" "$dst"
Files.createSymbolicLink(dst, f.toAbsolutePath());
}
}
}
}
}
log.info("tmp dataset created: {}", tmp);
return tmp;
}
}
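
Review note: the linking logic in `TmpDatasetService` can be exercised outside Spring. The sketch below is a standalone approximation under stated assumptions: `TmpDatasetSymlinkSketch`, `link`, and `demo` are hypothetical names, the base directory is a throwaway temp directory instead of `train.docker.requestDir`, and the host file system must allow symbolic links. It keeps the `<datasetUid>__<fileName>` naming from the diff and checks for existing entries without following links, so a dangling link does not trigger `FileAlreadyExistsException`.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.List;

// Hypothetical standalone version of the symlink builder, for local testing only.
final class TmpDatasetSymlinkSketch {

    // Links every regular file under base/<uid>/<type>/<part> into base/tmp/<runId>/<type>/<part>.
    static Path link(Path base, String runId, List<String> datasetUids) throws IOException {
        Path tmp = base.resolve("tmp").resolve(runId);
        for (String type : List.of("train", "val")) {
            for (String part : List.of("input1", "input2", "label")) {
                Path dstDir = Files.createDirectories(tmp.resolve(type).resolve(part));
                for (String uid : datasetUids) {
                    Path srcDir = base.resolve(uid).resolve(type).resolve(part);
                    if (!Files.isDirectory(srcDir)) continue; // missing folders are skipped
                    try (DirectoryStream<Path> files = Files.newDirectoryStream(srcDir)) {
                        for (Path f : files) {
                            if (!Files.isRegularFile(f)) continue;
                            Path dst = dstDir.resolve(uid + "__" + f.getFileName());
                            // Do not follow links, so dangling links also count as existing.
                            if (Files.exists(dst, LinkOption.NOFOLLOW_LINKS)) continue;
                            Files.createSymbolicLink(dst, f.toAbsolutePath());
                        }
                    }
                }
            }
        }
        return tmp;
    }

    // Builds a tiny dataset in a temp directory, links it twice, and verifies the result.
    static boolean demo() {
        try {
            Path base = Files.createTempDirectory("request-dir");
            Path src = Files.createDirectories(
                base.resolve("DS1").resolve("train").resolve("input1"));
            Files.writeString(src.resolve("a.tif"), "x");
            link(base, "RUN1", List.of("DS1"));
            Path tmp = link(base, "RUN1", List.of("DS1")); // second run skips existing links
            Path created = tmp.resolve("train").resolve("input1").resolve("DS1__a.tif");
            return Files.isSymbolicLink(created) && Files.readString(created).equals("x");
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println("symlink created: " + demo());
    }
}
```

Running the link step twice in `demo` mirrors a retry of the same request: existing links are left alone rather than recreated.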


@@ -242,4 +242,8 @@ public class DatasetCoreService
entity.setStatus(LearnDataRegister.COMPLETED.getId());
}
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
datasetObjRepository.insertDatasetValObj(objRegDto);
}
}


@@ -1,10 +1,12 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.Basic;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
import java.time.ZonedDateTime;
@@ -17,6 +19,7 @@ import org.springframework.stereotype.Service;
@Service
@RequiredArgsConstructor
public class HyperParamCoreService {
private final HyperParamRepository hyperParamRepository;
private final UserUtil userUtil;
@@ -27,7 +30,7 @@ public class HyperParamCoreService {
* @return 등록된 버전명
*/
public Basic createHyperParam(HyperParam createReq) {
String firstVersion = getFirstHyperParamVersion();
String firstVersion = getFirstHyperParamVersion(createReq.getModel());
ModelHyperParamEntity entity = new ModelHyperParamEntity();
entity.setHyperVer(firstVersion);
@@ -47,17 +50,17 @@ public class HyperParamCoreService {
/**
* 하이퍼파라미터 수정
*
* @param uuid uuid
* @param uuid uuid
* @param createReq 등록 요청
* @return ver
*/
public String updateHyperParam(UUID uuid, HyperParam createReq) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getHyperVer().equals("HPs_0001")) {
if (Boolean.TRUE.equals(entity.getIsDefault())) {
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
}
applyHyperParam(entity, createReq);
@@ -69,11 +72,112 @@ public class HyperParamCoreService {
return entity.getHyperVer();
}
/**
* Soft-deletes a hyperparameter set.
*
* @param uuid hyperparameter UUID
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// if (entity.getHyperVer().equals("HPs_0001")) {
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
// }
// Default hyperparameters cannot be deleted
if (Boolean.TRUE.equals(entity.getIsDefault())) {
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
}
entity.setDelYn(true);
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
}
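/**
* Returns the optimized (default) hyperparameters for a model type.
*
* @param model model type
* @return default hyperparameters for the model type
*/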
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
ModelHyperParamEntity entity =
hyperParamRepository
.getHyperparamByType(model)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* Returns the details of a hyperparameter set.
*
* @param uuid hyperparameter UUID
* @return hyperparameter details
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* Returns a page of hyperparameter versions.
*
* @param model model type
* @param req search request
* @return page of hyperparameter list items
*/
public Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req) {
return hyperParamRepository.findByHyperVerList(model, req);
}
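/**
* Returns the next hyperparameter version for a model type.
*
* @param model model type
* @return next version, or {@code <model>_000001} when none exists yet
*/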
public String getFirstHyperParamVersion(ModelType model) {
return hyperParamRepository
.findHyperParamVerByModelType(model)
.map(ModelHyperParamEntity::getHyperVer)
.map(ver -> increase(ver, model))
.orElse(model.name() + "_000001");
}
/**
* Increments a hyperparameter version.
*
* @param hyperVer current version
* @param modelType model type
* @return incremented version
*/
private String increase(String hyperVer, ModelType modelType) {
String prefix = modelType.name() + "_";
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
return prefix + String.format("%06d", num + 1);
}
private void applyHyperParam(ModelHyperParamEntity entity, HyperParam src) {
ModelType model = src.getModel();
// Hardcoded per model type where the values differ (250212 bbn)
if (model == ModelType.G3) {
entity.setCropSize("512,512");
} else {
entity.setCropSize("256,256");
}
// entity.setCropSize(src.getCropSize());
// Important
entity.setModelType(model); // model type added (20250212)
entity.setBackbone(src.getBackbone());
entity.setInputSize(src.getInputSize());
entity.setCropSize(src.getCropSize());
entity.setBatchSize(src.getBatchSize());
// Data
@@ -111,78 +215,4 @@ public class HyperParamCoreService {
entity.setMemo(src.getMemo());
}
/**
* 하이퍼파라미터 삭제
*
* @param uuid
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getHyperVer().equals("HPs_0001")) {
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
}
entity.setDelYn(true);
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
}
/**
* 하이퍼파라미터 최적화 설정값 조회
*
* @return
*/
public HyperParamDto.Basic getInitHyperParam() {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByHyperVer("HPs_0001")
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 상세 조회
*
* @return
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 목록 조회
*
* @param req
* @return
*/
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
return hyperParamRepository.findByHyperVerList(req);
}
/**
* 하이퍼파라미터 버전 조회
*
* @return ver
*/
public String getFirstHyperParamVersion() {
return hyperParamRepository
.findHyperParamVer()
.map(ModelHyperParamEntity::getHyperVer)
.map(this::increase)
.orElse("HPs_0001");
}
private String increase(String hyperVer) {
String prefix = "HPs_";
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
return prefix + String.format("%04d", num + 1);
}
}
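
Review note: the version scheme changes from `HPs_0001` to `<modelType>_000001`, and the new `increase` strips a prefix of `modelType.name() + "_"` before parsing. The sketch below restates that logic in isolation, taking the prefix as a plain string instead of the `ModelType` enum (an assumption made so the example is self-contained; `HyperVersionSketch` and `next` are hypothetical names). It also rejects a version with a different prefix explicitly. With the code in the diff, a legacy value such as `HPs_0001`, if any such row were ever returned for a model type, would fail inside `Integer.parseInt` instead.

```java
// Hypothetical, self-contained restatement of the per-model version increment.
final class HyperVersionSketch {

    // Returns the first version when there is no current one, otherwise current + 1.
    static String next(String currentVer, String modelPrefix) {
        String prefix = modelPrefix + "_";
        if (currentVer == null) {
            return prefix + "000001";
        }
        if (!currentVer.startsWith(prefix)) {
            // Fail with a clear message rather than a NumberFormatException.
            throw new IllegalArgumentException(
                "Version " + currentVer + " does not start with " + prefix);
        }
        int num = Integer.parseInt(currentVer.substring(prefix.length()));
        return prefix + String.format("%06d", num + 1);
    }

    public static void main(String[] args) {
        System.out.println(next(null, "G1"));        // G1_000001
        System.out.println(next("G1_000009", "G1")); // G1_000010
    }
}
```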


@@ -0,0 +1,29 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
public class ModelTestMetricsJobCoreService {
private final ModelTestMetricsJobRepository modelTestMetricsJobRepository;
@Transactional
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
modelTestMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo);
}
// Start of test-metrics logic
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
}
public void insertModelMetricsTest(List<Object[]> batchArgs) {
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
}
}


@@ -7,6 +7,11 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
@@ -54,6 +59,10 @@ public class ModelTrainDetailCoreService {
return modelDetailRepository.getByModelHyperParamSummary(uuid);
}
public TransferHyperSummary getTransferHyperSummary(UUID uuid) {
return modelDetailRepository.getByModelTransferHyperParamSummary(uuid);
}
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
return modelDetailRepository.getByModelMappingDataset(uuid);
}
@@ -72,4 +81,20 @@ public class ModelTrainDetailCoreService {
public ModelConfigDto.Basic findModelConfig(Long modelId) {
return modelConfigRepository.findModelConfigByModelId(modelId).orElse(null);
}
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
return modelDetailRepository.getModelTrainMetricResult(uuid);
}
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
return modelDetailRepository.getModelValidationMetricResult(uuid);
}
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
return modelDetailRepository.getModelTestMetricResult(uuid);
}
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
return modelDetailRepository.getModelTrainBestEpoch(uuid);
}
}


@@ -0,0 +1,107 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class ModelTrainJobCoreService {
private final ModelTrainJobRepository modelTrainJobRepository;
public int findMaxAttemptNo(Long modelId) {
return modelTrainJobRepository.findMaxAttemptNo(modelId);
}
public Optional<ModelTrainJobDto> findLatestByModelId(Long modelId) {
return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto);
}
public Optional<ModelTrainJobDto> findById(Long jobId) {
return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto);
}
/** Creates a job in the QUEUED state. */
@Transactional
public Long createQueuedJob(
Long modelId, int attemptNo, Map<String, Object> paramsJson, ZonedDateTime queuedDttm) {
ModelTrainJobEntity job = new ModelTrainJobEntity();
job.setModelId(modelId);
job.setAttemptNo(attemptNo);
job.setStatusCd("QUEUED");
job.setParamsJson(paramsJson);
job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now());
modelTrainJobRepository.save(job);
return job.getId();
}
/** Marks the job as started (RUNNING). */
@Transactional
public void markRunning(
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("RUNNING");
job.setContainerName(containerName);
job.setLogPath(logPath);
job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy);
if (totalEpoch != null) {
job.setTotalEpoch(totalEpoch);
}
}
/** Marks the job as succeeded. */
@Transactional
public void markSuccess(Long jobId, int exitCode) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("SUCCESS");
job.setExitCode(exitCode);
job.setFinishedDttm(ZonedDateTime.now());
}
/** Marks the job as failed. */
@Transactional
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("FAILED");
job.setExitCode(exitCode);
job.setErrorMessage(errorMessage);
job.setFinishedDttm(ZonedDateTime.now());
}
/** Marks the job as canceled (stored with status STOPPED). */
@Transactional
public void markCanceled(Long jobId) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now());
}
}
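
Review note: `ModelTrainJobCoreService` writes the status strings `QUEUED`, `RUNNING`, `SUCCESS`, `FAILED`, and `STOPPED` directly, and none of the `mark*` methods checks the job's current status first. If guarding transitions is wanted, one possible shape is sketched below. The allowed transitions are an assumption made for illustration (for example, that a queued job can be stopped but cannot fail without running); `TrainJobStateSketch` is a hypothetical name and nothing here is taken from the actual job runner.

```java
import java.util.EnumSet;
import java.util.Set;

// Hypothetical guard for job status changes; the transition table is an assumption.
final class TrainJobStateSketch {

    enum Status {
        QUEUED, RUNNING, SUCCESS, FAILED, STOPPED;

        // Statuses that may follow this one.
        Set<Status> next() {
            switch (this) {
                case QUEUED:
                    return EnumSet.of(RUNNING, STOPPED);
                case RUNNING:
                    return EnumSet.of(SUCCESS, FAILED, STOPPED);
                default:
                    return EnumSet.noneOf(Status.class); // terminal states
            }
        }
    }

    // Returns the new status, or throws if the change is not allowed.
    static Status transition(Status from, Status to) {
        if (!from.next().contains(to)) {
            throw new IllegalStateException("Illegal transition: " + from + " -> " + to);
        }
        return to;
    }

    public static void main(String[] args) {
        Status s = Status.QUEUED;
        s = transition(s, Status.RUNNING);
        s = transition(s, Status.SUCCESS);
        System.out.println(s); // SUCCESS
    }
}
```

A guard of this kind would turn a late `markRunning` on an already finished job into an explicit error instead of silently overwriting the final status.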


@@ -0,0 +1,32 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
public class ModelTrainMetricsJobCoreService {
private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository;
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
}
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
modelTrainMetricsJobRepository.insertModelMetricsTrain(batchArgs);
}
@Transactional
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
modelTrainMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo);
}
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs);
}
}


@@ -23,17 +23,22 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
public class ModelTrainMngCoreService {
private final ModelMngRepository modelMngRepository;
private final ModelDatasetRepository modelDatasetRepository;
private final ModelDatasetMappRepository modelDatasetMapRepository;
@@ -60,9 +65,9 @@ public class ModelTrainMngCoreService {
*/
public void deleteModel(UUID uuid) {
ModelMasterEntity entity =
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
entity.setDelYn(true);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
@@ -74,17 +79,19 @@ public class ModelTrainMngCoreService {
* @param addReq
* @return
*/
public Long saveModel(ModelTrainMngDto.AddReq addReq) {
public ModelTrainMngDto.Basic saveModel(ModelTrainMngDto.AddReq addReq) {
ModelMasterEntity entity = new ModelMasterEntity();
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
// Optimized parameters use HPs_0001
// Optimized parameters use the default hyperparameters of the model type
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
} else {
hyperParamEntity =
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
}
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
@@ -92,12 +99,13 @@ public class ModelTrainMngCoreService {
}
String modelVer =
String.join(
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
String.join(
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
entity.setModelVer(modelVer);
entity.setHyperParamId(hyperParamEntity.getId());
entity.setModelNo(addReq.getModelNo());
entity.setTrainType(addReq.getTrainType()); // GENERAL or TRANSFER
entity.setBeforeModelId(addReq.getBeforeModelId());
if (addReq.getIsStart()) {
entity.setModelStep((short) 1);
@@ -107,18 +115,24 @@ public class ModelTrainMngCoreService {
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
} else {
entity.setStatusCd(TrainStatusType.READY.getId());
entity.setStep1State(TrainStatusType.READY.getId());
}
entity.setCreatedUid(userUtil.getId());
ModelMasterEntity resultEntity = modelMngRepository.save(entity);
return resultEntity.getId();
ModelTrainMngDto.Basic result = new ModelTrainMngDto.Basic();
result.setId(resultEntity.getId());
result.setUuid(resultEntity.getUuid());
return result;
}
/**
* Save the data set
*
* @param modelId id of the saved model training
* @param addReq request parameters
* @param addReq request parameters
*/
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
TrainingDataset dataset = addReq.getTrainingDataset();
@@ -143,10 +157,30 @@ public class ModelTrainMngCoreService {
modelDatasetRepository.save(datasetEntity);
}
/**
* Update the training model
*
* @param modelId
* @param req
*/
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
entity.setRequestPath(req.getRequestPath());
}
if (req.getResponsePath() != null && !req.getResponsePath().isEmpty()) {
entity.setResponsePath(req.getResponsePath());
}
}
/**
* Save the model-dataset mapping table
*
* @param modelId model training id
* @param modelId model training id
* @param datasetList selected data sets
*/
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
@@ -163,7 +197,7 @@ public class ModelTrainMngCoreService {
* Save the model training config
*
* @param modelId model training id
* @param req request parameters
* @param req request parameters
* @return
*/
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
@@ -183,7 +217,7 @@ public class ModelTrainMngCoreService {
/**
* Create dataset mappings
*
* @param modelUid model UID
* @param modelUid model UID
* @param datasetIds list of dataset IDs
*/
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
@@ -205,13 +239,27 @@ public class ModelTrainMngCoreService {
public ModelMasterEntity findByUuid(UUID uuid) {
try {
return modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} catch (IllegalArgumentException e) {
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
}
}
/**
* Look up the model id by uuid
*
* @param uuid
* @return
*/
public Long findModelIdByUuid(UUID uuid) {
ModelMasterEntity entity =
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.getId();
}
/**
* Look up config information by model training id
*
@@ -221,8 +269,8 @@ public class ModelTrainMngCoreService {
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository
.findModelConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.findModelConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
/**
@@ -244,4 +292,243 @@ public class ModelTrainMngCoreService {
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
return datasetRepository.getDatasetSelectG2G3List(req);
}
/**
* Look up a managed model
*
* @param id
* @return
*/
public ModelTrainMngDto.Basic findModelById(Long id) {
ModelMasterEntity entity =
modelMngRepository
.findById(id)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
return entity.toDto();
}
/**
* Switch the master record to IN_PROGRESS and link the currently running jobId. The UI, cancellation, and status lookup all operate on currentAttemptId.
*/
@Transactional
public void markInProgress(Long modelId, Long jobId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
master.setCurrentAttemptId(jobId);
// If needed, the start time can also be recorded here
}
/**
* Reset the last error message so that a restart or new run carries no trace of the previous error
*/
@Transactional
public void clearLastError(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setLastError(null);
}
/**
* Mark as stopped (optional); implement alongside cancel if cancel needs it
*/
@Transactional
public void markStopped(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.STOPPED.getId());
}
/**
* Mark as completed (optional); called by the Worker on success
*/
@Transactional
public void markCompleted(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/**
* Mark a step 1 error (optional); called by the Worker on failure
*/
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep1State(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
master.setUpdatedUid(userUtil.getId());
master.setUpdatedDttm(ZonedDateTime.now());
}
/**
* Mark a step 2 error (optional); called by the Worker on failure
*/
@Transactional
public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep2State(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
master.setUpdatedUid(userUtil.getId());
master.setUpdatedDttm(ZonedDateTime.now());
}
@Transactional
public void markSuccess(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
// Mark the model status as completed
master.setStatusCd(TrainStatusType.COMPLETED.getId());
// (Optional) clear the last error message
master.setLastError(null);
}
/**
* Look up the parameters required to run training
*
* @param modelId
* @return
*/
public TrainRunRequest findTrainRunRequest(Long modelId) {
return modelMngRepository.findTrainRunRequest(modelId);
}
/**
* Mark step1 as in progress
*
* @param modelId
* @param jobId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep1StrtDttm(ZonedDateTime.now());
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
/**
* Mark step2 as in progress
*
* @param modelId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep2StrtDttm(ZonedDateTime.now());
entity.setStep2State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
/**
* Mark step1 as completed
*
* @param modelId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep1State(TrainStatusType.COMPLETED.getId());
entity.setStep1EndDttm(ZonedDateTime.now());
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
/**
* Mark step2 as completed
*
* @param modelId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep2State(TrainStatusType.COMPLETED.getId());
entity.setStep2EndDttm(ZonedDateTime.now());
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setBestEpoch(epoch);
}
/**
* Look up dataset uids
*
* @param datasetIds
* @return
*/
public List<String> findDatasetUid(List<Long> datasetIds) {
return datasetRepository.findDatasetUid(datasetIds);
}
public List<Long> findModelDatasetMapp(Long modelId) {
List<Long> datasetUids = new ArrayList<>();
List<ModelDatasetMappEntity> entities = modelDatasetMapRepository.findByModelUid(modelId);
for (ModelDatasetMappEntity entity : entities) {
datasetUids.add(entity.getDatasetUid());
}
return datasetUids;
}
public Long findModelStep1InProgressCnt() {
return modelMngRepository.findModelStep1InProgressCnt();
}
}
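
The mark* methods above move a model through READY, IN_PROGRESS, and COMPLETED or ERROR for each step, stamping start and end times and linking the running job id. The following is a minimal in-memory sketch of that bookkeeping for step 1 only. The class `StepStateSketch`, its nested `Model`, and the status strings are hypothetical placeholders: the real codes come from `TrainStatusType.getId()`, and the real service persists through JPA rather than plain fields.

```java
import java.time.ZonedDateTime;

// Hypothetical, JPA-free stand-in for the step bookkeeping in ModelTrainMngCoreService.
final class StepStateSketch {
  // Placeholder status codes; the real values come from TrainStatusType.getId().
  static final String READY = "READY";
  static final String IN_PROGRESS = "IN_PROGRESS";
  static final String COMPLETED = "COMPLETED";
  static final String ERROR = "ERROR";

  // Minimal mutable model holding only the fields the mark* methods touch.
  static final class Model {
    String statusCd = READY;
    String step1State = READY;
    ZonedDateTime step1StrtDttm;
    ZonedDateTime step1EndDttm;
    Long currentAttemptId;
    String lastError;
  }

  // Mirrors markStep1InProgress: set both status fields, stamp the start, link the job.
  static void markStep1InProgress(Model m, long jobId) {
    m.statusCd = IN_PROGRESS;
    m.step1State = IN_PROGRESS;
    m.step1StrtDttm = ZonedDateTime.now();
    m.currentAttemptId = jobId;
  }

  // Mirrors markStep1Success: set both status fields and stamp the end time.
  static void markStep1Success(Model m) {
    m.statusCd = COMPLETED;
    m.step1State = COMPLETED;
    m.step1EndDttm = ZonedDateTime.now();
  }

  // Mirrors markError: record the failure on the model and on step 1.
  static void markError(Model m, String message) {
    m.statusCd = ERROR;
    m.step1State = ERROR;
    m.lastError = message;
  }

  public static void main(String[] args) {
    Model m = new Model();
    markStep1InProgress(m, 42L);
    markStep1Success(m);
    System.out.println(m.statusCd + " / " + m.step1State);
  }
}
```

Keeping each transition in one small method, as the service does, means the overall status and the per-step status are always written together.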


@@ -0,0 +1,117 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.dataset.dto.DatasetObjDto.Basic;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.ZonedDateTime;
import java.util.UUID;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
import org.locationtech.jts.geom.Geometry;
@Getter
@Setter
@Entity
@Table(name = "tb_dataset_val_obj")
public class DatasetValObjEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "obj_id", nullable = false)
private Long objId;
@NotNull
@Column(name = "dataset_uid", nullable = false)
private Long datasetUid;
@Column(name = "target_yyyy")
private Integer targetYyyy;
@Size(max = 255)
@Column(name = "target_class_cd")
private String targetClassCd;
@Column(name = "compare_yyyy")
private Integer compareYyyy;
@Size(max = 255)
@Column(name = "compare_class_cd")
private String compareClassCd;
@Size(max = 255)
@Column(name = "target_path")
private String targetPath;
@Size(max = 255)
@Column(name = "compare_path")
private String comparePath;
@Size(max = 255)
@Column(name = "label_path")
private String labelPath;
@Size(max = 255)
@Column(name = "geojson_path")
private String geojsonPath;
@Size(max = 255)
@Column(name = "map_sheet_num")
private String mapSheetNum;
@ColumnDefault("now()")
@Column(name = "created_dttm")
private ZonedDateTime createdDttm;
@Column(name = "created_uid")
private Long createdUid;
@ColumnDefault("false")
@Column(name = "deleted")
private Boolean deleted;
@Column(name = "uuid")
private UUID uuid;
@Size(max = 32)
@Column(name = "uid")
private String uid;
@JdbcTypeCode(SqlTypes.JSON)
@Column(name = "geo_jsonb", columnDefinition = "jsonb")
private String geoJsonb;
@Column(name = "file_name")
private String fileName;
@Column(name = "geom", columnDefinition = "geometry")
private Geometry geom;
public Basic toDto() {
return new Basic(
this.objId,
this.datasetUid,
this.targetYyyy,
this.targetClassCd,
this.compareYyyy,
this.compareClassCd,
this.targetPath,
this.comparePath,
this.labelPath,
this.geojsonPath,
this.mapSheetNum,
this.createdDttm,
this.createdUid,
this.deleted,
this.uuid,
this.geoJsonb);
}
}


@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import jakarta.persistence.*;
import jakarta.validation.constraints.NotNull;
@@ -311,8 +312,17 @@ public class ModelHyperParamEntity {
@Column(name = "m3_use_cnt")
private Long m3UseCnt = 0L;
@Column(name = "model_type")
@Enumerated(EnumType.STRING)
private ModelType modelType;
@Column(name = "default_param")
private Boolean isDefault = false;
public HyperParamDto.Basic toDto() {
return new HyperParamDto.Basic(
this.modelType,
this.uuid,
this.hyperVer,
this.createdDttm,
@@ -385,6 +395,8 @@ public class ModelHyperParamEntity {
// -------------------------
this.gpuCnt,
this.gpuIds,
this.masterPort);
this.masterPort,
this.isDefault);
}
}


@@ -88,6 +88,30 @@ public class ModelMasterEntity {
@Column(name = "train_type")
private String trainType;
@Column(name = "before_model_id")
private Long beforeModelId;
@Column(name = "step1_metric_save_yn")
private Boolean step1MetricSaveYn;
@Column(name = "step2_metric_save_yn")
private Boolean step2MetricSaveYn;
@Column(name = "current_attempt_id")
private Long currentAttemptId;
@Column(name = "last_error")
private String lastError;
@Column(name = "best_epoch")
private Integer bestEpoch;
@Column(name = "request_path")
private String requestPath;
@Column(name = "response_path")
private String responsePath;
public ModelTrainMngDto.Basic toDto() {
return new ModelTrainMngDto.Basic(
this.id,
@@ -102,6 +126,7 @@ public class ModelMasterEntity {
this.step2State,
this.statusCd,
this.trainType,
this.modelNo);
this.modelNo,
this.currentAttemptId);
}
}


@@ -19,8 +19,8 @@ import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_matrics_test")
public class ModelMatricsTestEntity {
@Table(name = "tb_model_metrics_test")
public class ModelMetricsTestEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@@ -45,9 +45,6 @@ public class ModelMatricsTestEntity {
@Column(name = "fn")
private Long fn;
@Column(name = "tn")
private Long tn;
@Column(name = "precisions")
private Float precisions;
@@ -63,8 +60,11 @@ public class ModelMatricsTestEntity {
@Column(name = "iou")
private Float iou;
@Column(name = "processed_images")
private Long processedImages;
@Column(name = "detection_count")
private Long detectionCount;
@Column(name = "gt_count")
private Long gtCount;
@ColumnDefault("now()")
@Column(name = "created_dttm")


@@ -18,8 +18,8 @@ import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_matrics_train")
public class ModelMatricsTrainEntity {
@Table(name = "tb_model_metrics_train")
public class ModelMetricsTrainEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)


@@ -18,8 +18,8 @@ import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_matrics_validation")
public class ModelMatricsValidationEntity {
@Table(name = "tb_model_metrics_validation")
public class ModelMetricsValidationEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)


@@ -0,0 +1,103 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.ZonedDateTime;
import java.util.Map;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
@Getter
@Setter
@Entity
@Table(name = "tb_model_train_job")
public class ModelTrainJobEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "id", nullable = false)
private Long id;
@NotNull
@Column(name = "model_id", nullable = false)
private Long modelId;
@NotNull
@Column(name = "attempt_no", nullable = false)
private Integer attemptNo;
@Size(max = 30)
@NotNull
@Column(name = "status_cd", nullable = false, length = 30)
private String statusCd;
@NotNull
@Column(name = "params_json", nullable = false)
@JdbcTypeCode(SqlTypes.JSON)
private Map<String, Object> paramsJson;
@Size(max = 200)
@Column(name = "container_name", length = 200)
private String containerName;
@Size(max = 500)
@Column(name = "log_path", length = 500)
private String logPath;
@Column(name = "exit_code")
private Integer exitCode;
@Size(max = 2000)
@Column(name = "error_message", length = 2000)
private String errorMessage;
@ColumnDefault("now()")
@Column(name = "queued_dttm")
private ZonedDateTime queuedDttm;
@Column(name = "started_dttm")
private ZonedDateTime startedDttm;
@Column(name = "finished_dttm")
private ZonedDateTime finishedDttm;
@Column(name = "locked_dttm")
private ZonedDateTime lockedDttm;
@Size(max = 100)
@Column(name = "locked_by", length = 100)
private String lockedBy;
@Column(name = "total_epoch")
private Integer totalEpoch;
@Column(name = "current_epoch")
private Integer currentEpoch;
public ModelTrainJobDto toDto() {
return new ModelTrainJobDto(
this.id,
this.modelId,
this.attemptNo,
this.statusCd,
this.exitCode,
this.errorMessage,
this.containerName,
this.paramsJson,
this.queuedDttm,
this.startedDttm,
this.finishedDttm,
this.totalEpoch,
this.currentEpoch);
}
}


@@ -22,4 +22,6 @@ public interface DatasetObjRepositoryCustom {
String getFilePathByUUIDPathType(UUID uuid, String pathType);
void insertDatasetTestObj(DatasetObjRegDto objRegDto);
void insertDatasetValObj(DatasetObjRegDto objRegDto);
}


@@ -97,6 +97,49 @@ public class DatasetObjRepositoryImpl implements DatasetObjRepositoryCustom {
}
}
@Override
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
ObjectMapper objectMapper = new ObjectMapper();
String json;
String geometryJson;
try {
json = objectMapper.writeValueAsString(objRegDto.getGeojson());
geometryJson =
objectMapper.writeValueAsString(
objRegDto.getGeojson().path("features").get(0).path("geometry"));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
try {
em.createNativeQuery(
"""
insert into tb_dataset_val_obj
(dataset_uid, target_yyyy, target_class_cd,
compare_yyyy, compare_class_cd,
target_path, compare_path, label_path, geo_jsonb, map_sheet_num, file_name, geom, geojson_path)
values
(?, ?, ?, ?, ?, ?, ?, ?, cast(? as jsonb), ?, ?, ST_SetSRID(ST_GeomFromGeoJSON(?), 5186), ?)
""")
.setParameter(1, objRegDto.getDatasetUid())
.setParameter(2, objRegDto.getTargetYyyy())
.setParameter(3, objRegDto.getTargetClassCd())
.setParameter(4, objRegDto.getCompareYyyy())
.setParameter(5, objRegDto.getCompareClassCd())
.setParameter(6, objRegDto.getTargetPath())
.setParameter(7, objRegDto.getComparePath())
.setParameter(8, objRegDto.getLabelPath())
.setParameter(9, json)
.setParameter(10, objRegDto.getMapSheetNum())
.setParameter(11, objRegDto.getFileName())
.setParameter(12, geometryJson)
.setParameter(13, objRegDto.getGeojsonPath())
.executeUpdate();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Page<DatasetObjEntity> searchDatasetObjectList(SearchReq searchReq) {
Pageable pageable = searchReq.toPageable();


@@ -22,4 +22,6 @@ public interface DatasetRepositoryCustom {
Long getDatasetMaxStage(int compareYyyy, int targetYyyy);
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
List<String> findDatasetUid(List<Long> datasetIds);
}


@@ -12,6 +12,7 @@ import com.kamco.cd.training.postgres.entity.QDatasetEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.CaseBuilder;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.NumberExpression;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
@@ -103,6 +104,7 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.select(
Projections.constructor(
SelectDataSet.class,
Expressions.constant(req.getModelNo()),
dataset.id,
dataset.uuid,
dataset.dataType,
@@ -174,6 +176,7 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.select(
Projections.constructor(
SelectDataSet.class,
Expressions.constant(req.getModelNo()),
dataset.id,
dataset.uuid,
dataset.dataType,
@@ -239,4 +242,9 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.where(dataset.uid.eq(mngRegDto.getUid()))
.fetchOne();
}
@Override
public List<String> findDatasetUid(List<Long> datasetIds) {
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
}
}


@@ -1,6 +1,8 @@
package com.kamco.cd.training.postgres.repository.hyperparam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.Optional;
import java.util.UUID;
@@ -13,11 +15,22 @@ public interface HyperParamRepositoryCustom {
*
* @return
*/
@Deprecated
Optional<ModelHyperParamEntity> findHyperParamVer();
/**
* Look up the latest version for each model type
*
* @param modelType model type
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType);
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req);
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
}


@@ -2,8 +2,10 @@ package com.kamco.cd.training.postgres.repository.hyperparam;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
@@ -41,6 +43,23 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
.fetchOne());
}
@Override
public Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType) {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(
modelHyperParamEntity
.delYn
.isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType)))
.orderBy(modelHyperParamEntity.hyperVer.desc())
.limit(1)
.fetchOne());
}
@Override
public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) {
@@ -68,10 +87,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
}
@Override
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
public Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req) {
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
builder.and(modelHyperParamEntity.modelType.eq(model));
builder.and(modelHyperParamEntity.delYn.isFalse());
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
@@ -161,4 +181,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
return new PageImpl<>(content, pageable, totalCount);
}
@Override
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType)))
.fetchOne());
}
}
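
getHyperparamByType relies on `fetchOne()`, which in QueryDSL throws `NonUniqueResultException` when more than one row matches. The sketch below illustrates, in plain Java and without QueryDSL, one way to pick a single default row per model type. It assumes that `default_param` marks at most one non-deleted row per type, which the diff does not confirm. The class `DefaultHyperParamSketch`, the nested `HyperParam`, and the model type strings "G1" and "G2" are hypothetical.

```java
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

// Hypothetical in-memory stand-in for the hyperparameter table; not the project's repository.
final class DefaultHyperParamSketch {

  // Only the columns relevant to the lookup are modelled here.
  static final class HyperParam {
    final String hyperVer;
    final String modelType;
    final boolean deleted;
    final boolean isDefault;

    HyperParam(String hyperVer, String modelType, boolean deleted, boolean isDefault) {
      this.hyperVer = hyperVer;
      this.modelType = modelType;
      this.deleted = deleted;
      this.isDefault = isDefault;
    }
  }

  // Returns the single non-deleted default row for the model type, or empty if there is none.
  // Fails loudly when the assumed "one default per type" rule is violated.
  static Optional<HyperParam> defaultFor(List<HyperParam> rows, String modelType) {
    List<HyperParam> matches =
        rows.stream()
            .filter(r -> !r.deleted && r.isDefault && r.modelType.equals(modelType))
            .collect(Collectors.toList());
    if (matches.size() > 1) {
      throw new IllegalStateException("More than one default for model type " + modelType);
    }
    return matches.stream().findFirst();
  }

  public static void main(String[] args) {
    List<HyperParam> rows =
        List.of(
            new HyperParam("HPs_0001", "G1", false, true),
            new HyperParam("HPs_0002", "G1", false, false),
            new HyperParam("HPs_0003", "G2", false, true));
    System.out.println(defaultFor(rows, "G1").map(r -> r.hyperVer).orElse("none"));
  }
}
```

In the repository itself, the equivalent would be an extra predicate on the default flag in the `where` clause. Whether that matches the intended data model is a question for the author.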


@@ -6,4 +6,5 @@ import org.springframework.stereotype.Repository;
@Repository
public interface ModelDatasetMappRepository
extends JpaRepository<ModelDatasetMappEntity, ModelDatasetMappEntity.ModelDatasetMappId> {}
extends JpaRepository<ModelDatasetMappEntity, ModelDatasetMappEntity.ModelDatasetMappId>,
ModelDatasetMappRepositoryCustom {}


@@ -0,0 +1,8 @@
package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import java.util.List;
public interface ModelDatasetMappRepositoryCustom {
List<ModelDatasetMappEntity> findByModelUid(Long modelId);
}


@@ -0,0 +1,25 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositoryCustom {
private final JPAQueryFactory queryFactory;
@Override
public List<ModelDatasetMappEntity> findByModelUid(Long modelId) {
// The query only runs once fetch() is called; return its result instead of an empty list
return queryFactory
.select(modelDatasetMappEntity)
.from(modelDatasetMappEntity)
.where(modelDatasetMappEntity.modelUid.eq(modelId))
.fetch();
}
}
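
A QueryDSL query is only executed when a terminal call such as `fetch()` is invoked; building the query alone returns nothing. The in-memory sketch below shows the result that findByModelUid is meant to produce, followed by the uid extraction done in findModelDatasetMapp. The class `DatasetMappingSketch` and its nested `Mapping` are hypothetical stand-ins for the entity and repository, using plain Java streams instead of QueryDSL.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical in-memory stand-in for the model-dataset mapping table.
final class DatasetMappingSketch {

  // One mapping row: which dataset belongs to which model.
  static final class Mapping {
    final Long modelUid;
    final Long datasetUid;

    Mapping(Long modelUid, Long datasetUid) {
      this.modelUid = modelUid;
      this.datasetUid = datasetUid;
    }
  }

  // Equivalent of "where modelUid = :modelId" followed by collecting the dataset uids.
  static List<Long> datasetUidsFor(List<Mapping> rows, Long modelId) {
    return rows.stream()
        .filter(r -> r.modelUid.equals(modelId))
        .map(r -> r.datasetUid)
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<Mapping> rows =
        List.of(new Mapping(1L, 10L), new Mapping(2L, 20L), new Mapping(1L, 11L));
    System.out.println(datasetUidsFor(rows, 1L));
  }
}
```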


@@ -3,6 +3,11 @@ package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import java.util.List;
import java.util.Optional;
@@ -16,7 +21,17 @@ public interface ModelDetailRepositoryCustom {
HyperSummary getByModelHyperParamSummary(UUID uuid);
TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid);
List<MappingDataset> getByModelMappingDataset(UUID uuid);
ModelMasterEntity findByModelByUUID(UUID uuid);
List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid);
List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid);
List<ModelTestMetrics> getModelTestMetricResult(UUID uuid);
ModelBestEpoch getModelTrainBestEpoch(UUID uuid);
}


@@ -5,11 +5,21 @@ import static com.kamco.cd.training.postgres.entity.QModelDatasetEntity.modelDat
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.JPAExpressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -17,8 +27,10 @@ import java.util.List;
import java.util.Optional;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Repository;
@Slf4j
@Repository
@RequiredArgsConstructor
public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
@@ -82,6 +94,41 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
.fetchOne();
}
@Override
public TransferHyperSummary getByModelTransferHyperParamSummary(UUID uuid) {
QModelMasterEntity subMaster = new QModelMasterEntity("subMaster");
QModelHyperParamEntity subHyper = new QModelHyperParamEntity("subHyper");
return queryFactory
.select(
Projections.constructor(
TransferHyperSummary.class,
modelHyperParamEntity.uuid,
modelHyperParamEntity.id,
modelHyperParamEntity.hyperVer,
modelHyperParamEntity.backbone,
modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize,
modelHyperParamEntity.batchSize,
subHyper.uuid,
subHyper.id,
subHyper.hyperVer,
subHyper.backbone,
subHyper.inputSize,
subHyper.cropSize,
subHyper.batchSize))
.from(modelMasterEntity)
.innerJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
.leftJoin(subMaster)
.on(subMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(subHyper)
.on(subHyper.id.eq(subMaster.hyperParamId))
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
}
@Override
public List<MappingDataset> getByModelMappingDataset(UUID uuid) {
return queryFactory
@@ -116,4 +163,110 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
}
@Override
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
if (modelMasterEntity == null) {
return List.of();
}
return queryFactory
.select(
Projections.constructor(
ModelTrainMetrics.class,
modelMetricsTrainEntity.epoch,
modelMetricsTrainEntity.iteration,
modelMetricsTrainEntity.loss,
modelMetricsTrainEntity.lr,
modelMetricsTrainEntity.durationTime))
.from(modelMetricsTrainEntity)
.where(modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId()))
.fetch();
}
@Override
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
if (modelMasterEntity == null) {
return List.of();
}
return queryFactory
.select(
Projections.constructor(
ModelValidationMetrics.class,
modelMetricsValidationEntity.epoch,
modelMetricsValidationEntity.aAcc,
modelMetricsValidationEntity.mFscore,
modelMetricsValidationEntity.mPrecision,
modelMetricsValidationEntity.mRecall,
modelMetricsValidationEntity.mIou,
modelMetricsValidationEntity.mAcc,
modelMetricsValidationEntity.changedFscore,
modelMetricsValidationEntity.changedPrecision,
modelMetricsValidationEntity.changedRecall,
modelMetricsValidationEntity.unchangedFscore,
modelMetricsValidationEntity.unchangedPrecision,
modelMetricsValidationEntity.unchangedRecall))
.from(modelMetricsValidationEntity)
.where(modelMetricsValidationEntity.model.id.eq(modelMasterEntity.getId()))
.fetch();
}
@Override
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
if (modelMasterEntity == null) {
return List.of();
}
return queryFactory
.select(
Projections.constructor(
ModelTestMetrics.class,
modelMetricsTestEntity.model1,
modelMetricsTestEntity.tp,
modelMetricsTestEntity.fp,
modelMetricsTestEntity.fn,
modelMetricsTestEntity.precisions,
modelMetricsTestEntity.recall,
modelMetricsTestEntity.f1Score,
modelMetricsTestEntity.accuracy,
modelMetricsTestEntity.iou,
modelMetricsTestEntity.detectionCount,
modelMetricsTestEntity.gtCount))
.from(modelMetricsTestEntity)
.where(modelMetricsTestEntity.model.id.eq(modelMasterEntity.getId()))
.fetch();
}
@Override
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
if (modelMasterEntity == null) {
return null;
}
return queryFactory
.select(
Projections.constructor(
ModelBestEpoch.class,
modelMetricsTrainEntity.epoch,
modelMetricsTrainEntity.loss,
modelMetricsValidationEntity.mFscore,
modelMetricsValidationEntity.mPrecision,
modelMetricsValidationEntity.mRecall,
modelMetricsValidationEntity.mIou,
modelMetricsValidationEntity.mAcc))
.from(modelMetricsTrainEntity)
.leftJoin(modelMetricsValidationEntity)
.on(
modelMetricsTrainEntity.model.eq(modelMetricsValidationEntity.model),
modelMetricsTrainEntity.epoch.eq(modelMetricsValidationEntity.epoch))
.where(
modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId()),
modelMetricsTrainEntity.epoch.eq(modelMasterEntity.getBestEpoch()))
.fetchOne();
}
}

@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.util.Optional;
import java.util.UUID;
import org.springframework.data.domain.Page;
@@ -19,4 +20,8 @@ public interface ModelMngRepositoryCustom {
Optional<ModelMasterEntity> findByUuid(UUID uuid);
Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn);
TrainRunRequest findTrainRunRequest(Long modelId);
Long findModelStep1InProgressCnt();
}

@@ -1,10 +1,16 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import java.util.Optional;
@@ -40,6 +46,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
builder.and(modelMasterEntity.modelNo.eq(req.getModelNo()));
}
builder.and(modelMasterEntity.delYn.isFalse());
List<ModelMasterEntity> content =
queryFactory
.selectFrom(modelMasterEntity)
@@ -82,4 +90,71 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
public Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn) {
return Optional.empty();
}
@Override
public TrainRunRequest findTrainRunRequest(Long modelId) {
return queryFactory
.select(
Projections.constructor(
TrainRunRequest.class,
modelMasterEntity.requestPath, // datasetFolder
modelMasterEntity.uuid, // outputFolder
modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize,
modelHyperParamEntity.batchSize,
modelHyperParamEntity.gpuIds,
modelHyperParamEntity.gpuCnt,
modelHyperParamEntity.learningRate,
modelHyperParamEntity.backbone,
modelConfigEntity.epochCount,
modelHyperParamEntity.trainNumWorkers,
modelHyperParamEntity.valNumWorkers,
modelHyperParamEntity.testNumWorkers,
modelHyperParamEntity.trainShuffle,
modelHyperParamEntity.trainPersistent,
modelHyperParamEntity.valPersistent,
modelHyperParamEntity.dropPathRate,
modelHyperParamEntity.frozenStages,
modelHyperParamEntity.neckPolicy,
modelHyperParamEntity.classWeight,
modelHyperParamEntity.decoderChannels,
modelHyperParamEntity.weightDecay,
modelHyperParamEntity.layerDecayRate,
modelHyperParamEntity.ignoreIndex,
modelHyperParamEntity.ddpFindUnusedParams,
modelHyperParamEntity.numLayers,
modelHyperParamEntity.metrics,
modelHyperParamEntity.saveBest,
modelHyperParamEntity.saveBestRule,
modelHyperParamEntity.valInterval,
modelHyperParamEntity.logInterval,
modelHyperParamEntity.visInterval,
modelHyperParamEntity.rotProb,
modelHyperParamEntity.rotDegree,
modelHyperParamEntity.flipProb,
modelHyperParamEntity.exchangeProb,
modelHyperParamEntity.brightnessDelta,
modelHyperParamEntity.contrastRange,
modelHyperParamEntity.saturationRange,
modelHyperParamEntity.hueDelta,
Expressions.nullExpression(Integer.class),
Expressions.nullExpression(String.class),
modelHyperParamEntity.uuid))
.from(modelMasterEntity)
.leftJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
.leftJoin(modelConfigEntity)
.on(modelConfigEntity.model.id.eq(modelMasterEntity.id))
.where(modelMasterEntity.id.eq(modelId))
.fetchOne();
}
@Override
public Long findModelStep1InProgressCnt() {
return queryFactory
.select(modelMasterEntity.id.count())
.from(modelMasterEntity)
.where(modelMasterEntity.step1State.eq(TrainStatusType.IN_PROGRESS.getId()))
.fetchOne();
}
}

@@ -0,0 +1,9 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface ModelTestMetricsJobRepository
extends JpaRepository<ModelMetricsTestEntity, Long>, ModelTestMetricsJobRepositoryCustom {}

@@ -0,0 +1,13 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List;
public interface ModelTestMetricsJobRepositoryCustom {
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
void insertModelMetricsTest(List<Object[]> batchArgs);
}

@@ -0,0 +1,70 @@
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.jdbc.core.JdbcTemplate;
public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
implements ModelTestMetricsJobRepositoryCustom {
private final JPAQueryFactory queryFactory;
private final JdbcTemplate jdbcTemplate;
public ModelTestMetricsJobRepositoryImpl(
JPAQueryFactory queryFactory, JdbcTemplate jdbcTemplate) {
super(ModelMetricsTestEntity.class);
this.queryFactory = queryFactory;
this.jdbcTemplate = jdbcTemplate;
}
@Override
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
queryFactory
.update(modelMasterEntity)
.set(
"step1".equals(stepNo)
? modelMasterEntity.step1MetricSaveYn
: modelMasterEntity.step2MetricSaveYn,
true)
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
@Override
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
.from(modelMasterEntity)
.where(
modelMasterEntity.step2EndDttm.isNotNull(),
modelMasterEntity.step2State.eq(TrainStatusType.COMPLETED.getId()),
modelMasterEntity
.step2MetricSaveYn
.isNull()
.or(modelMasterEntity.step2MetricSaveYn.isFalse()))
.fetch();
}
@Override
public void insertModelMetricsTest(List<Object[]> batchArgs) {
String sql =
"""
insert into tb_model_metrics_test
(model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
detection_count, gt_count
)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""";
jdbcTemplate.batchUpdate(sql, batchArgs);
}
}
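
The `insertModelMetricsTest` method above hands `batchArgs` straight to `jdbcTemplate.batchUpdate`, so each `Object[]` presumably has to carry exactly twelve values in the column order of the insert statement. A minimal sketch of building such a row follows; the class name `TestMetricBatchArgs`, the method `row`, and the Java types chosen for each column are assumptions for illustration and are not part of this change set.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper; the column types below are assumed, not taken from the entity. */
final class TestMetricBatchArgs {
  private TestMetricBatchArgs() {}

  /** Returns one row in the column order of the tb_model_metrics_test insert statement. */
  static Object[] row(
      long modelId, String model, long tp, long fp, long fn,
      double precisions, double recall, double f1Score, double accuracy, double iou,
      long detectionCount, long gtCount) {
    return new Object[] {
      modelId, model, tp, fp, fn,
      precisions, recall, f1Score, accuracy, iou,
      detectionCount, gtCount
    };
  }

  /** Collects rows into the List<Object[]> shape that batchUpdate expects. */
  static List<Object[]> of(Object[]... rows) {
    List<Object[]> batchArgs = new ArrayList<>(rows.length);
    for (Object[] r : rows) {
      batchArgs.add(r);
    }
    return batchArgs;
  }
}
```

Keeping the row construction in one place makes a mismatch between the number of `?` placeholders and the number of values easier to spot in review.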

@@ -0,0 +1,7 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import org.springframework.data.jpa.repository.JpaRepository;
public interface ModelTrainJobRepository
extends JpaRepository<ModelTrainJobEntity, Long>, ModelTrainJobRepositoryCustom {}

@@ -0,0 +1,10 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import java.util.Optional;
public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
}

@@ -0,0 +1,43 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory;
import jakarta.persistence.EntityManager;
import java.util.Optional;
import org.springframework.stereotype.Repository;
@Repository
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
private final JPAQueryFactory queryFactory;
public ModelTrainJobRepositoryImpl(EntityManager em) {
this.queryFactory = new JPAQueryFactory(em);
}
/** Maximum attempt_no for the given modelId (0 if there is none). */
@Override
public int findMaxAttemptNo(Long modelId) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
Integer max =
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
return max != null ? max : 0;
}
/**
* Latest single job for the given modelId (usually ordered by id desc, queuedDttm desc, etc.).
* Ordering by attemptNo would also work, but id desc is the simplest here.
*/
@Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
return Optional.ofNullable(job);
}
}

@@ -0,0 +1,9 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface ModelTrainMetricsJobRepository
extends JpaRepository<ModelMetricsTrainEntity, Long>, ModelTrainMetricsJobRepositoryCustom {}

@@ -0,0 +1,15 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List;
public interface ModelTrainMetricsJobRepositoryCustom {
List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
void insertModelMetricsTrain(List<Object[]> batchArgs);
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
void insertModelMetricsValidation(List<Object[]> batchArgs);
}

@@ -0,0 +1,82 @@
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.jdbc.core.JdbcTemplate;
public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySupport
implements ModelTrainMetricsJobRepositoryCustom {
private final JPAQueryFactory queryFactory;
private final JdbcTemplate jdbcTemplate;
public ModelTrainMetricsJobRepositoryImpl(
JPAQueryFactory queryFactory, JdbcTemplate jdbcTemplate) {
super(ModelMetricsTrainEntity.class);
this.queryFactory = queryFactory;
this.jdbcTemplate = jdbcTemplate;
}
@Override
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
.from(modelMasterEntity)
.where(
modelMasterEntity.step1EndDttm.isNotNull(),
modelMasterEntity.step1State.eq(TrainStatusType.COMPLETED.getId()),
modelMasterEntity
.step1MetricSaveYn
.isNull()
.or(modelMasterEntity.step1MetricSaveYn.isFalse()))
.fetch();
}
@Override
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
String sql =
"""
insert into tb_model_metrics_train
(model_id, epoch, iteration, loss, lr, duration_time)
values (?, ?, ?, ?, ?, ?)
""";
jdbcTemplate.batchUpdate(sql, batchArgs);
}
@Override
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
queryFactory
.update(modelMasterEntity)
.set(
"step1".equals(stepNo)
? modelMasterEntity.step1MetricSaveYn
: modelMasterEntity.step2MetricSaveYn,
true)
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
@Override
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
String sql =
"""
insert into tb_model_metrics_validation
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
unchanged_fscore, unchanged_precision, unchanged_recall
)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""";
jdbcTemplate.batchUpdate(sql, batchArgs);
}
}

@@ -0,0 +1,189 @@
package com.kamco.cd.training.train;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.train.service.TestJobService;
import com.kamco.cd.training.train.service.TrainJobService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
@RequiredArgsConstructor
@RestController
@RequestMapping("/api/train")
public class TrainApiController {
private final TrainJobService trainJobService;
private final TestJobService testJobService;
@Operation(summary = "학습 실행", description = "학습 실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "실행 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/run/{uuid}")
public ApiResponseDto<String> run(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.enqueue(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 재실행", description = "학습 재실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "재실행 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/restart/{uuid}")
public ApiResponseDto<String> restart(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
Long jobId = trainJobService.restart(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 이어하기", description = "학습 이어하기 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "이어하기 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/resume/{uuid}")
public ApiResponseDto<String> resume(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
Long jobId = trainJobService.resume(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 취소", description = "학습 취소 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "취소 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/cancel/{uuid}")
public ApiResponseDto<String> cancel(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.cancel(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "test 실행", description = "test 실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "test 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/test/run/{epoch}/{uuid}")
public ApiResponseDto<String> runTest(
@Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
testJobService.enqueue(modelId, uuid, epoch);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "test 학습 취소", description = "학습 취소 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "취소 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/test/cancel/{uuid}")
public ApiResponseDto<String> cancelTest(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
testJobService.cancel(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "데이터셋 tmp 파일생성", description = "데이터셋 tmp 파일생성 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "데이터셋 tmp 파일생성 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/create-tmp/{uuid}")
public ApiResponseDto<UUID> createTmpFile(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
}
}

@@ -0,0 +1,16 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
public class EvalRunRequest {
private String uuid;
private int epoch; // best_changed_fscore_epoch_1.pth
private Integer timeoutSeconds;
}

@@ -0,0 +1,25 @@
package com.kamco.cd.training.train.dto;
import java.time.ZonedDateTime;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public class ModelTrainJobDto {
private Long id;
private Long modelId;
private Integer attemptNo;
private String statusCd;
private Integer exitCode;
private String errorMessage;
private String containerName;
private Map<String, Object> paramsJson;
private ZonedDateTime queuedDttm;
private ZonedDateTime startedDttm;
private ZonedDateTime finishedDttm;
private Integer totalEpoch;
private Integer currentEpoch;
}

@@ -0,0 +1,15 @@
package com.kamco.cd.training.train.dto;
/** Event object signalling that a training run has been queued. */
public class ModelTrainJobQueuedEvent {
private final Long jobId;
public ModelTrainJobQueuedEvent(Long jobId) {
this.jobId = jobId;
}
public Long getJobId() {
return jobId;
}
}

@@ -0,0 +1,21 @@
package com.kamco.cd.training.train.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
public class ModelTrainMetricsDto {
@Schema(name = "ResponsePathDto", description = "AI 결과 저장된 path 경로 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ResponsePathDto {
private Long modelId;
private String responsePath;
}
}

@@ -0,0 +1,94 @@
package com.kamco.cd.training.train.dto;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class TrainRunRequest {
// ========================
// Basic
// ========================
private String datasetFolder;
private UUID outputFolder;
private String inputSize;
private String cropSize;
private Integer batchSize;
private String gpuIds;
private Integer gpus;
private Double learningRate;
private String backbone;
private Integer epochs;
// ========================
// Data
// ========================
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// ========================
// Model Architecture
// ========================
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String classWeight;
private String decoderChannels;
// ========================
// Loss & Optimization
// ========================
private Double weightDecay;
private Double layerDecayRate;
private Integer ignoreIndex;
private Boolean ddpFindUnusedParams;
private Integer numLayers;
// ========================
// Evaluation
// ========================
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// ========================
// Augmentation
// ========================
private Double rotProb;
private String rotDegree;
private Double flipProb;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// ========================
// Execution timeout
// ========================
private Integer timeoutSeconds;
private String resumeFrom;
private UUID uuid;
public String getOutputFolder() {
return String.valueOf(this.outputFolder);
}
public String getUuid() {
return String.valueOf(this.uuid);
}
}
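
Every field of `TrainRunRequest` ends up as a `--key=value` argument for the training wrapper, and `DockerTrainService.addArg` drops any value that is null or blank. The sketch below isolates that rule so it can be checked on its own; the class `CliArgSketch` and the map-based input are hypothetical stand-ins for the DTO getters.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-alone version of the null/blank-skipping rule used by addArg. */
final class CliArgSketch {
  private CliArgSketch() {}

  /** Converts option name/value pairs to "key=value" arguments, skipping null or blank values. */
  static List<String> toArgs(Map<String, Object> options) {
    List<String> args = new ArrayList<>();
    for (Map.Entry<String, Object> e : options.entrySet()) {
      Object value = e.getValue();
      if (value == null) {
        continue; // unset option: let the wrapper fall back to its own default
      }
      String s = String.valueOf(value).trim();
      if (s.isEmpty()) {
        continue; // blank strings are treated the same as unset
      }
      args.add(e.getKey() + "=" + s);
    }
    return args;
  }
}
```

One detail worth keeping in mind with this approach: `String.valueOf` on a `Double` smaller than 0.001 yields scientific notation (for example `1.0E-4`), so the receiving argument parser has to accept that form for options such as `--lr`.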

@@ -0,0 +1,20 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
/** Result object returned from a training run. */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class TrainRunResult {
private String jobId;
private String containerName;
private int exitCode;
private String status;
private String logs;
}

@@ -0,0 +1,406 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Log4j2
@Service
public class DockerTrainService {
// Docker image to run
@Value("${train.docker.image}")
private String image;
// Host directory that holds the training request data
@Value("${train.docker.requestDir}")
private String requestDir;
// Host directory where training results are stored
@Value("${train.docker.responseDir}")
private String responseDir;
// Container name prefix
@Value("${train.docker.containerPrefix}")
private String containerPrefix;
// Shared memory size (needed for large training runs)
@Value("${train.docker.shmSize:16g}")
private String shmSize;
// Whether to use the host IPC namespace
@Value("${train.docker.ipcHost:true}")
private boolean ipcHost;
/**
* Runs the Docker training container synchronously: executes docker run on the calling thread,
* waits for the container to exit, then returns the result with the collected stdout/stderr logs.
*/
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerRunCommand(containerName, req);
log.info("=== Docker Train Command ===");
log.info("Container: {}", containerName);
log.info("Command: {}", String.join(" ", cmd));
log.info("================================");
ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectErrorStream(true);
Process p = pb.start();
// Read the log on a separate thread (so the main thread does not block on readLine)
StringBuilder logBuilder = new StringBuilder();
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
Thread logThread =
new Thread(
() -> {
try (BufferedReader br =
new BufferedReader(
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
// 1) Accumulate the log
synchronized (logBuilder) {
logBuilder.append(line).append('\n');
}
// 2) Detect the epoch and update the DB (the update itself is still a TODO)
Matcher m = epochPattern.matcher(line);
if (m.find()) {
int currentEpoch = Integer.parseInt(m.group(1));
int totalEpoch = Integer.parseInt(m.group(2));
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
// TODO: persist the currently running epoch if that turns out to be needed
// TODO: running a transactional DB operation here is reportedly a bad idea, though
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
// currentEpoch, totalEpoch);
}
}
} catch (Exception e) {
log.warn("logThread error: {}", e.toString());
}
},
"train-log-" + containerName);
// new Thread(
// () -> {
// try (BufferedReader br =
// new BufferedReader(
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
// String line;
// while ((line = br.readLine()) != null) {
// synchronized (log) {
// log.append(line).append('\n');
// }
// }
// } catch (Exception ignored) {
// }
// },
// "train-log-" + containerName);
logThread.setDaemon(true);
logThread.start();
int timeoutSeconds = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
if (!finished) {
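// The docker run process must be terminated as well so that readLine unblocks
p.destroy();
if (!p.waitFor(2, TimeUnit.SECONDS)) {
p.destroyForcibly();
}
// Use the same name that buildDockerRunCommand passes to --name
killContainer(containerName + "-" + req.getUuid().substring(0, 8));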
String logs;
synchronized (logBuilder) {
logs = logBuilder.toString();
}
return new TrainRunResult(
null, // jobId (null when not available)
containerName,
-1,
"TIMEOUT",
logs);
}
int exit = p.exitValue();
// Give the log thread a little time to finish (optional, but avoids losing trailing log lines)
logThread.join(500);
String logs;
synchronized (logBuilder) {
logs = logBuilder.toString();
}
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
}
/**
* Builds the docker run command for training.
*
* @param containerName
* @param req
* @return
*/
private List<String> buildDockerRunCommand(String containerName, TrainRunRequest req) {
List<String> c = new ArrayList<>();
c.add("docker");
c.add("run");
// Container name
c.add("--name");
c.add(containerName + "-" + req.getUuid().substring(0, 8));
// Remove the container automatically when it exits
c.add("--rm");
// Use all GPUs
c.add("--gpus");
c.add("all");
// Whether to use the host IPC namespace
if (ipcHost) {
c.add("--ipc=host");
}
// Shared memory size
c.add("--shm-size=" + shmSize);
// Memory-related ulimit settings
c.add("--ulimit");
c.add("memlock=-1");
c.add("--ulimit");
c.add("stack=67108864");
// Environment variables
c.add("-e");
c.add("OPENCV_LOG_LEVEL=ERROR");
c.add("-e");
c.add("NCCL_DEBUG=INFO");
c.add("-e");
c.add("NCCL_IB_DISABLE=1");
c.add("-e");
c.add("NCCL_P2P_DISABLE=0");
c.add("-e");
c.add("NCCL_SOCKET_IFNAME=eth0");
// Volume mounts for the request and result directories
c.add("-v");
c.add(requestDir + ":/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");
// Keep stdin open (use -i only, not -it)
c.add("-i");
// Image to use
c.add(image);
// ===== Command executed inside the container =====
c.add("python");
c.add("/workspace/change-detection-code/train_wrapper.py");
// ===== Basic parameters =====
addArg(c, "--dataset-folder", req.getDatasetFolder());
addArg(c, "--output-folder", req.getOutputFolder());
addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize());
addArg(c, "--batch-size", req.getBatchSize());
addArg(c, "--gpu-ids", req.getGpuIds());
// addArg(c, "--gpus", req.getGpus());
addArg(c, "--lr", req.getLearningRate());
addArg(c, "--backbone", req.getBackbone());
addArg(c, "--epochs", req.getEpochs());
// ===== Data =====
addArg(c, "--train-num-workers", req.getTrainNumWorkers());
addArg(c, "--val-num-workers", req.getValNumWorkers());
addArg(c, "--test-num-workers", req.getTestNumWorkers());
addArg(c, "--train-shuffle", req.getTrainShuffle());
addArg(c, "--train-persistent", req.getTrainPersistent());
addArg(c, "--val-persistent", req.getValPersistent());
// ===== Model Architecture =====
addArg(c, "--drop-path-rate", req.getDropPathRate());
addArg(c, "--frozen-stages", req.getFrozenStages());
addArg(c, "--neck-policy", req.getNeckPolicy());
addArg(c, "--class-weight", req.getClassWeight());
addArg(c, "--decoder-channels", req.getDecoderChannels());
// ===== Loss & Optimization =====
addArg(c, "--weight-decay", req.getWeightDecay());
addArg(c, "--layer-decay-rate", req.getLayerDecayRate());
addArg(c, "--ignore-index", req.getIgnoreIndex());
addArg(c, "--ddp-find-unused-params", req.getDdpFindUnusedParams());
addArg(c, "--num-layers", req.getNumLayers());
// ===== Evaluation =====
addArg(c, "--metrics", req.getMetrics());
addArg(c, "--save-best", req.getSaveBest());
addArg(c, "--save-best-rule", req.getSaveBestRule());
addArg(c, "--val-interval", req.getValInterval());
addArg(c, "--log-interval", req.getLogInterval());
addArg(c, "--vis-interval", req.getVisInterval());
// ===== Augmentation =====
addArg(c, "--rot-prob", req.getRotProb());
addArg(c, "--rot-degree", req.getRotDegree());
addArg(c, "--flip-prob", req.getFlipProb());
addArg(c, "--exchange-prob", req.getExchangeProb());
addArg(c, "--brightness-delta", req.getBrightnessDelta());
addArg(c, "--contrast-range", req.getContrastRange());
addArg(c, "--saturation-range", req.getSaturationRange());
addArg(c, "--hue-delta", req.getHueDelta());
addArg(c, "--resume-from", req.getResumeFrom());
return c;
}
/** Adds an argument as key=value; nothing is added when the value is null or blank. */
private void addArg(List<String> c, String key, Object value) {
if (value == null) return;
String s = String.valueOf(value).trim();
if (s.isEmpty()) return;
c.add(key + "=" + s);
}
/** Force-stops and removes the container. */
public void killContainer(String containerName) {
try {
new ProcessBuilder("docker", "rm", "-f", containerName)
.redirectErrorStream(true)
.start()
.waitFor(10, TimeUnit.SECONDS);
} catch (Exception ignored) {
}
}
public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerEvalCommand(containerName, req);
log.info("=== Docker Test Command ===");
log.info("Container: {}", containerName);
log.info("Command: {}", String.join(" ", cmd));
log.info("================================");
ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectErrorStream(true);
Process p = pb.start();
StringBuilder log = new StringBuilder();
Thread logThread =
new Thread(
() -> {
try (BufferedReader br =
new BufferedReader(
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
synchronized (log) {
log.append(line).append('\n');
}
}
} catch (Exception ignored) {
}
});
logThread.setDaemon(true);
logThread.start();
int timeout = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
boolean finished = p.waitFor(timeout, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
killContainer(containerName);
String logs;
synchronized (log) {
logs = log.toString();
}
return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs);
}
int exit = p.exitValue();
logThread.join(500);
String logs;
synchronized (log) {
logs = log.toString();
}
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
}
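/**
* Builds the docker run command for the test (evaluation) step.
*
* @param containerName name passed to docker via --name
* @param req evaluation request; uuid and epoch are required
* @return the docker command as an argument list
*/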
private List<String> buildDockerEvalCommand(String containerName, EvalRunRequest req) {
String uuid = req.getUuid();
Integer epoch = req.getEpoch();
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth";
List<String> c = new ArrayList<>();
c.add("docker");
c.add("run");
c.add("--name");
// Use the name as-is: docker rejects '=' in container names, and killContainer() removes by this exact name.
c.add(containerName);
c.add("--rm");
c.add("--gpus");
c.add("all");
if (ipcHost) c.add("--ipc=host");
c.add("--shm-size=" + shmSize);
c.add("-v");
c.add(requestDir + ":/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");
c.add(image);
c.add("python");
c.add("/workspace/change-detection-code/run_evaluation_pipeline.py");
c.add("--dataset_dir");
c.add("/data/" + uuid);
c.add("--model");
c.add("/checkpoints/" + uuid + "/" + modelFile);
return c;
}
}
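The `addArg` helper in this class appends `key=value` tokens and drops null or blank values, so unset optional hyperparameters never reach the command line. A minimal standalone sketch of that rule follows; the class name `ArgListSketch` and the sample values are illustrative and not part of this commit.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: mirrors the null/blank-skipping rule of addArg. */
final class ArgListSketch {

  /** Appends "key=value" unless the value is null or blank after trimming. */
  static void addArg(List<String> args, String key, Object value) {
    if (value == null) {
      return;
    }
    String s = String.valueOf(value).trim();
    if (s.isEmpty()) {
      return;
    }
    args.add(key + "=" + s);
  }

  public static void main(String[] unused) {
    List<String> args = new ArrayList<>();
    addArg(args, "--weight-decay", 0.05); // kept
    addArg(args, "--resume-from", null); // skipped: null
    addArg(args, "--metrics", "   "); // skipped: blank after trim
    addArg(args, "--val-interval", 1); // kept
    System.out.println(args);
  }
}
```

Because each token is a single `key=value` string, the receiving script has to accept that form; argparse-style parsers generally do, but that is an assumption about the training script, which is not shown here.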

@@ -0,0 +1,102 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Slf4j
@Service
@RequiredArgsConstructor
public class ModelTestMetricsJobService {
private final ModelTestMetricsJobCoreService modelTestMetricsJobCoreService;
@Value("${spring.profiles.active}")
private String profile;
/**
* Whether the application is running with the local profile.
*
* @return true if the active profile is "local"
*/
private boolean isLocalProfile() {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
public void findTestValidMetricCsvFiles() {
// if (isLocalProfile()) {
// return;
// }
List<ResponsePathDto> modelIds =
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
if (modelIds.isEmpty()) {
return;
}
for (ResponsePathDto modelInfo : modelIds) {
String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8)) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
String model = record.get("model");
long TP = Long.parseLong(record.get("TP"));
long FP = Long.parseLong(record.get("FP"));
long FN = Long.parseLong(record.get("FN"));
float precision = Float.parseFloat(record.get("precision"));
float recall = Float.parseFloat(record.get("recall"));
float f1_score = Float.parseFloat(record.get("f1_score"));
float accuracy = Float.parseFloat(record.get("accuracy"));
float iou = Float.parseFloat(record.get("iou"));
long detection_count = Long.parseLong(record.get("detection_count"));
long gt_count = Long.parseLong(record.get("gt_count"));
batchArgs.add(
new Object[] {
modelInfo.getModelId(),
model,
TP,
FP,
FN,
precision,
recall,
f1_score,
accuracy,
iou,
detection_count,
gt_count
});
}
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
}
}
}
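The job above reads `metrics/test.csv` by header name and converts each column to a typed value before the batch insert. The header-keyed lookup can be sketched with the standard library alone, as below. This is a simplified stand-in for Commons CSV that assumes no quoted or escaped fields; `MetricsCsvSketch` and `parse` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Illustrative only: header-keyed CSV rows, assuming plain comma-separated cells. */
final class MetricsCsvSketch {

  /** Treats the first line as the header and maps every later line to column -> value. */
  static List<Map<String, String>> parse(String csvText) {
    List<Map<String, String>> rows = new ArrayList<>();
    String[] lines = csvText.split("\\R");
    if (lines.length == 0 || lines[0].isBlank()) {
      return rows; // empty input: nothing to insert
    }
    String[] header = lines[0].split(",", -1);
    for (int n = 1; n < lines.length; n++) {
      String line = lines[n];
      if (line.isBlank()) {
        continue; // tolerate blank lines between records
      }
      String[] cells = line.split(",", -1);
      if (cells.length != header.length) {
        throw new IllegalArgumentException("Column count mismatch on line " + (n + 1));
      }
      Map<String, String> row = new LinkedHashMap<>();
      for (int i = 0; i < header.length; i++) {
        row.put(header[i].trim(), cells[i].trim());
      }
      rows.add(row);
    }
    return rows;
  }
}
```

A caller would then convert values the same way the job does, for example `Long.parseLong(row.get("TP"))` or `Float.parseFloat(row.get("f1_score"))`.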

@@ -0,0 +1,136 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Slf4j
@Service
@RequiredArgsConstructor
public class ModelTrainMetricsJobService {
private final ModelTrainMetricsJobCoreService modelTrainMetricsJobCoreService;
@Value("${spring.profiles.active}")
private String profile;
// Host directory where training results are stored
@Value("${train.docker.responseDir}")
private String responseDir;
/**
* Whether the application is running with the local profile.
*
* @return true if the active profile is "local"
*/
private boolean isLocalProfile() {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
public void findTrainValidMetricCsvFiles() {
// if (isLocalProfile()) {
// return;
// }
List<ResponsePathDto> modelIds =
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
if (modelIds.isEmpty()) {
return;
}
for (ResponsePathDto modelInfo : modelIds) {
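// Resolve train.csv the same way as val.csv below; the previous value contained a literal "{uuid}" placeholder.
String trainPath = modelInfo.getResponsePath() + "/metrics/train.csv";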
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8)) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO: apply -1 later, once AI development is complete
long iteration = Long.parseLong(record.get("Iteration"));
double Loss = Double.parseDouble(record.get("Loss"));
double LR = Double.parseDouble(record.get("LR"));
float time = Float.parseFloat(record.get("Time"));
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
}
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8)) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch"));
float aAcc = Float.parseFloat(record.get("aAcc"));
float mFscore = Float.parseFloat(record.get("mFscore"));
float mPrecision = Float.parseFloat(record.get("mPrecision"));
float mRecall = Float.parseFloat(record.get("mRecall"));
float mIoU = Float.parseFloat(record.get("mIoU"));
float mAcc = Float.parseFloat(record.get("mAcc"));
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
float changed_precision = Float.parseFloat(record.get("changed_precision"));
float changed_recall = Float.parseFloat(record.get("changed_recall"));
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
batchArgs.add(
new Object[] {
modelInfo.getModelId(),
epoch,
aAcc,
mFscore,
mPrecision,
mRecall,
mIoU,
mAcc,
changed_fscore,
changed_precision,
changed_recall,
unchanged_fscore,
unchanged_precision,
unchanged_recall
});
}
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step1");
}
}
}
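Building the metrics paths by string concatenation depends on whether the base directory ends with a separator; the configured `responseDir` in this diff ("/home/kcomu/data/response") does not. One way to avoid that, sketched below under the assumption that results live in `<responseDir>/<uuid>/metrics/`, is to resolve the path with `java.nio.file.Path`. The class and method names are hypothetical.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

/** Illustrative only: separator-safe construction of a metrics file path. */
final class MetricsPathSketch {

  /** Returns responseDir/uuid/metrics/fileName regardless of trailing separators. */
  static Path metricsFile(String responseDir, String uuid, String fileName) {
    if (uuid == null || uuid.isBlank()) {
      throw new IllegalArgumentException("uuid is required");
    }
    return Paths.get(responseDir).resolve(uuid).resolve("metrics").resolve(fileName);
  }
}
```

With this shape, `metricsFile("/home/kcomu/data/response", uuid, "train.csv")` and the same call with a trailing slash on the directory yield the same path.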

@@ -0,0 +1,79 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class TestJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher;
@Transactional
public Long enqueue(Long modelId, UUID uuid, int epoch) {
// Verify that the master record exists
modelTrainMngCoreService.findModelById(modelId);
// Update the best epoch
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
Map<String, Object> params = new java.util.LinkedHashMap<>();
params.put("jobType", "EVAL");
params.put("uuid", String.valueOf(uuid));
params.put("epoch", epoch);
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, params, ZonedDateTime.now());
// Mark step 2 as started
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
@Transactional
public void cancel(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
Long jobId = master.getCurrentAttemptId();
if (jobId == null) {
throw new IllegalStateException("No job is currently running.");
}
var job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalStateException("Job not found"));
String containerName = job.getContainerName();
// 1) Force-stop and remove the container (safe even if it is missing or already dead)
if (containerName != null && !containerName.isBlank()) {
dockerTrainService.killContainer(containerName);
}
// 2) Update the status (always performed)
modelTrainJobCoreService.markCanceled(jobId);
modelTrainMngCoreService.markStopped(modelId);
}
}

@@ -0,0 +1,239 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.service.TmpDatasetService;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class TrainJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher;
private final TmpDatasetService tmpDatasetService;
// Host directory where training results are stored
@Value("${train.docker.responseDir}")
private String responseDir;
public Long getModelIdByUuid(UUID uuid) {
return modelTrainMngCoreService.findModelIdByUuid(uuid);
}
/** Schedules a run (registers it in the queue). */
@Transactional
public Long enqueue(Long modelId) {
// Verify that the master record exists (throws if it does not)
modelTrainMngCoreService.findModelById(modelId);
// Load the parameters
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
if (trainRunRequest == null) {
throw new IllegalArgumentException("Model not found: " + modelId);
}
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
@SuppressWarnings("unchecked")
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
paramsMap.put("jobType", "TRAIN");
paramsMap.put("uuid", trainRunRequest.getUuid());
paramsMap.put("totalEpoch", trainRunRequest.getEpochs());
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, paramsMap, ZonedDateTime.now());
modelTrainMngCoreService.clearLastError(modelId);
modelTrainMngCoreService.markInProgress(modelId, jobId);
// Trigger the worker after commit (the listener must receive this in the AFTER_COMMIT phase)
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
/**
* Restart.
*
* <p>- Allowed only from the STOPPED / ERROR states (throws if IN_PROGRESS) - Reuses the previous params_json - Creates a new attempt
*/
@Transactional
public Long restart(Long modelId) {
return createNextAttempt(modelId, ResumeMode.NONE);
}
/**
* Resume from the last checkpoint.
*
* @param modelId ID of the model whose training is resumed
* @return ID of the newly queued job
*/
@Transactional
public Long resume(Long modelId) {
return createNextAttempt(modelId, ResumeMode.REQUIRE);
}
/**
* Cancel.
*
* <p>- Job status becomes CANCELED - Master status becomes STOPPED
*
* <p>Note: the actual docker stop is performed by the Worker/Runner (for operational stability).
*/
@Transactional
public void cancel(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
Long jobId = master.getCurrentAttemptId();
if (jobId == null) {
throw new IllegalStateException("No job is currently running.");
}
var job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalStateException("Job not found"));
String containerName = job.getContainerName();
// 1) Force-stop and remove the container (safe even if it is missing or already dead)
if (containerName != null && !containerName.isBlank()) {
dockerTrainService.killContainer(containerName);
}
// 2) Update the status (always performed)
modelTrainJobCoreService.markCanceled(jobId);
modelTrainMngCoreService.markStopped(modelId);
}
private Long createNextAttempt(Long modelId, ResumeMode mode) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
if (TrainStatusType.IN_PROGRESS.getId().equals(master.getStatusCd())) {
throw new IllegalStateException("Training is already in progress.");
}
var lastJob =
modelTrainJobCoreService
.findLatestByModelId(modelId)
.orElseThrow(() -> new IllegalStateException("No previous run history exists."));
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
// Reuse the previous params_json (for reproducibility)
Map<String, Object> params = lastJob.getParamsJson();
if (params == null || params.isEmpty()) {
throw new IllegalStateException("The previous run has no params_json.");
}
// Inject or remove the resume options depending on the mode
Map<String, Object> nextParams = new java.util.LinkedHashMap<>(params);
if (mode == ResumeMode.NONE) {
// Remove any resume-related keys (guarantees a completely fresh start)
nextParams.remove("resumeFrom");
nextParams.remove("resume");
} else if (mode == ResumeMode.REQUIRE) {
// Detect the checkpoint and set resumeFrom
String resumeFrom = findResumeFromOrNull(nextParams);
if (resumeFrom == null) {
throw new IllegalStateException("No checkpoint is available to resume from.");
}
nextParams.put("resumeFrom", resumeFrom);
nextParams.put("resume", true);
}
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, nextParams, ZonedDateTime.now());
modelTrainMngCoreService.clearLastError(modelId);
modelTrainMngCoreService.markInProgress(modelId, jobId);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
private enum ResumeMode {
NONE, // start fresh
REQUIRE // resume from checkpoint
}
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
if (paramsJson == null) return null;
Object out = paramsJson.get("outputFolder");
if (out == null) return null;
String outputFolder = String.valueOf(out).trim(); // uuid
if (outputFolder.isEmpty()) return null;
// Path on the host
Path outDir = Paths.get(responseDir, outputFolder);
Path last = outDir.resolve("last_checkpoint");
if (!Files.isRegularFile(last)) return null;
try {
String ckptFile = Files.readString(last).trim(); // epoch_10.pth
if (ckptFile.isEmpty()) return null;
Path ckptHost = outDir.resolve(ckptFile);
if (!Files.isRegularFile(ckptHost)) return null;
// Return the path as seen inside the container
return "/checkpoints/" + outputFolder + "/" + ckptFile;
} catch (Exception e) {
return null;
}
}
public UUID createTmpFile(UUID modelUuid) {
UUID tmpUuid = UUID.randomUUID();
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
try {
// Create the dataset symlinks
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(path.toString());
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
throw new RuntimeException(e);
}
return modelUuid;
}
}
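`findResumeFromOrNull` in this class follows a two-step lookup: read the checkpoint file name from a `last_checkpoint` marker on the host, confirm that file exists, then return the equivalent path under the container mount `/checkpoints`. The sketch below reproduces that flow in isolation. Returning `Optional` instead of null is a substitution for illustration, and `ResumeCheckpointSketch` is a hypothetical name.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;

/** Illustrative only: host-side checkpoint lookup mapped to a container path. */
final class ResumeCheckpointSketch {

  /** Returns the container path of the last checkpoint, or empty if none is usable. */
  static Optional<String> findResumeFrom(Path responseDir, String outputFolder) {
    if (outputFolder == null || outputFolder.isBlank()) {
      return Optional.empty();
    }
    String folder = outputFolder.trim();
    Path outDir = responseDir.resolve(folder);
    Path marker = outDir.resolve("last_checkpoint");
    if (!Files.isRegularFile(marker)) {
      return Optional.empty();
    }
    try {
      // The marker holds a bare file name such as epoch_10.pth
      String ckptFile = Files.readString(marker).trim();
      if (ckptFile.isEmpty() || !Files.isRegularFile(outDir.resolve(ckptFile))) {
        return Optional.empty();
      }
      // The host directory is mounted at /checkpoints inside the container
      return Optional.of("/checkpoints/" + folder + "/" + ckptFile);
    } catch (IOException e) {
      return Optional.empty();
    }
  }
}
```

Checking the checkpoint file on the host before queuing the job means a stale marker fails fast in the API call rather than inside the container.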

@@ -0,0 +1,126 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
@Component
@RequiredArgsConstructor
public class TrainJobWorker {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
@Async
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
public void handle(ModelTrainJobQueuedEvent event) {
Long jobId = event.getJobId();
ModelTrainJobDto job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
if (TrainStatusType.STOPPED.getId().equals(job.getStatusCd())) {
return;
}
Long modelId = job.getModelId();
Map<String, Object> params = job.getParamsJson();
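if (params == null || params.get("uuid") == null) {
// Without params or a uuid, neither the job type nor the container name can be derived.
modelTrainJobCoreService.markFailed(jobId, null, "params_json is missing or has no uuid");
modelTrainMngCoreService.markError(modelId, "params_json is missing or has no uuid");
return;
}
String jobType = String.valueOf(params.get("jobType"));
boolean isEval = "EVAL".equals(jobType);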
String containerName =
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
Integer totalEpoch = null;
if (params.containsKey("totalEpoch")) {
if (params.get("totalEpoch") != null) {
totalEpoch = Integer.parseInt(params.get("totalEpoch").toString());
}
}
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
try {
TrainRunResult result;
if (isEval) {
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
String uuid = String.valueOf(params.get("uuid"));
// A value read back from JSON may be an Integer or a Long, so go through Number.
int epoch = ((Number) params.get("epoch")).intValue();
EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null);
result = dockerTrainService.runEvalSync(evalReq, containerName);
} else {
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
TrainRunRequest trainReq = toTrainRunRequest(params);
result = dockerTrainService.runTrainSync(trainReq, containerName);
}
ModelTrainJobDto latest =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
if (TrainStatusType.STOPPED.getId().equals(latest.getStatusCd())) {
return;
}
if (result.getExitCode() == 0) {
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
if (isEval) {
modelTrainMngCoreService.markStep2Success(modelId);
} else {
modelTrainMngCoreService.markStep1Success(modelId);
}
} else {
modelTrainJobCoreService.markFailed(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
if (isEval) {
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
} else {
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
}
}
} catch (Exception e) {
modelTrainJobCoreService.markFailed(jobId, null, e.toString());
if (isEval) {
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
} else {
modelTrainMngCoreService.markError(modelId, e.getMessage());
}
}
}
private TrainRunRequest toTrainRunRequest(Map<String, Object> paramsJson) {
if (paramsJson == null || paramsJson.isEmpty()) {
return null;
}
return objectMapper.convertValue(paramsJson, TrainRunRequest.class);
}
}
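The worker reads `epoch` and `totalEpoch` out of a `Map<String, Object>` that has been through JSON, where a number may come back as `Integer`, `Long`, or even a string. A small helper that normalizes these cases can be sketched as follows; `JobParamsSketch` and `requireInt` are hypothetical names, not part of this commit.

```java
import java.util.Map;

/** Illustrative only: tolerant integer extraction from a JSON-derived params map. */
final class JobParamsSketch {

  /** Returns the value under key as an int, accepting any Number or a numeric String. */
  static int requireInt(Map<String, Object> params, String key) {
    Object raw = params == null ? null : params.get(key);
    if (raw instanceof Number) {
      // toIntExact throws ArithmeticException instead of silently truncating
      return Math.toIntExact(((Number) raw).longValue());
    }
    if (raw instanceof String) {
      // NumberFormatException is a subclass of IllegalArgumentException
      return Integer.parseInt(((String) raw).trim());
    }
    throw new IllegalArgumentException("Missing or non-numeric parameter: " + key);
  }
}
```

Compared with a direct `(int)` cast, this fails with a message that names the parameter, and it does not depend on which numeric type the JSON layer produced.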

@@ -199,9 +199,6 @@ public class UploadDto {
private String fileName;
public double getUploadRate() {
if (this.chunkTotalIndex == 0) {
return 0.0;
}
return (double) (this.chunkIndex + 1) / (this.chunkTotalIndex + 1) * 100.0;
}
}
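The hunk header above (9 lines before, 6 after) suggests that the three-line `chunkTotalIndex == 0` guard is what this change removes. Assuming both indexes are zero-based, which the `+ 1` terms imply, the remaining formula already handles a single-chunk upload, as the worked sketch below shows. `UploadRateSketch` is a hypothetical standalone copy of the formula.

```java
/** Illustrative only: the upload-rate formula with zero-based chunk indexes. */
final class UploadRateSketch {

  /** Percentage of chunks received so far: (chunkIndex + 1) / (chunkTotalIndex + 1) * 100. */
  static double uploadRate(int chunkIndex, int chunkTotalIndex) {
    return (double) (chunkIndex + 1) / (chunkTotalIndex + 1) * 100.0;
  }

  public static void main(String[] unused) {
    System.out.println(uploadRate(0, 0)); // single chunk: 1 / 1
    System.out.println(uploadRate(1, 3)); // second of four chunks: 2 / 4
  }
}
```

Under the removed guard, the single-chunk case would have reported 0.0 even after the only chunk arrived.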

@@ -67,8 +67,8 @@ public class UploadService {
String status = "UPLOADING";
if (uploadDivi.equals("dataset")) {
tmpDataSetDir = datasetTmpDir + uuid + "/";
fianlDir = datasetDir + uuid + "/";
tmpDataSetDir = datasetTmpDir;
fianlDir = datasetDir;
}
// File values for the return payload
@@ -233,6 +233,9 @@ public class UploadService {
try {
FIleChecker.deleteFolder(tmpDir);
// Move from server 108 to server 86
// log.info("################# server move 108 -> 86");
// FIleChecker.uploadTo86(outputPath);
} catch (Exception e) {
log.warn("tmpDir delete failed (merge already succeeded): tmpDir={}", tmpDir, e);
}

@@ -55,6 +55,14 @@ file:
sync-tmp-dir: ${file.sync-root-dir}tmp/
sync-file-extention: tfw,tif
# dataset-dir: /kamco-nfs/dataset/upload/
dataset-dir: /home/kcomu/data/upload/
dataset-dir: /home/kcomu/data/request/
dataset-tmp-dir: ${file.dataset-dir}tmp/
train:
docker:
image: "kamco-cd-train:love_latest"
requestDir: "/home/kcomu/data/request"
responseDir: "/home/kcomu/data/response"
containerPrefix: "kamco-cd-train"
shmSize: "16g"
ipcHost: true

@@ -4,7 +4,7 @@ spring:
on-profile: prod
jpa:
show-sql: false
show-sql: true
hibernate:
ddl-auto: validate
properties:
@@ -12,34 +12,45 @@ spring:
default_batch_fetch_size: 100 # ✅ Performance: prevents N+1 queries
order_updates: true # ✅ Performance: orders updates to avoid deadlocks
use_sql_comments: true # ⚠️ Optional: adds comments to SQL (for debugging)
format_sql: true # ⚠️ Optional: SQL formatting (readability)
datasource:
url: jdbc:postgresql://10.100.0.10:25432/temp
username: temp
password: temp123!
url: jdbc:postgresql://127.0.0.1:15432/kamco_training_db
# url: jdbc:postgresql://localhost:15432/kamco_training_db
username: kamco_training_user
password: kamco_training_user_2025_!@#
hikari:
minimum-idle: 10
maximum-pool-size: 20
connection-timeout: 60000 # 60-second connection timeout
idle-timeout: 300000 # 5-minute idle timeout
max-lifetime: 1800000 # 30-minute maximum lifetime
leak-detection-threshold: 60000 # connection leak detection
transaction:
default-timeout: 300 # 5-minute transaction timeout
jwt:
secret: "kamco_token_prod_dfc6446d-68fc-4eba-a2ff-c80a14a0bf3a"
secret: "kamco_token_dev_dfc6446d-68fc-4eba-a2ff-c80a14a0bf3a"
access-token-validity-in-ms: 86400000 # 1 day
refresh-token-validity-in-ms: 604800000 # 7 days
token:
refresh-cookie-name: kamco # cookie name for development
refresh-cookie-secure: true # set to false for local HTTP testing
refresh-cookie-secure: false # set to false for local HTTP testing
springdoc:
swagger-ui:
persist-authorization: true # keeps the token across Swagger page refreshes (stored in localStorage)
member:
init_password: kamco1234!
swagger:
local-port: 9080
file:
sync-root-dir: /app/original-images/
sync-tmp-dir: ${file.sync-root-dir}tmp/
sync-file-extention: tfw,tif
dataset-dir: /kamco-nfs/dataset/upload/
dataset-tmp-dir: ${file.dataset-dir}tmp/
train:
docker:
image: "kamco-cd-train:latest"
requestDir: "/home/kcomu/data/request"
responseDir: "/home/kcomu/data/response"
containerPrefix: "kamco-cd-train"
shmSize: "16g"
ipcHost: true

@@ -9667,7 +9667,7 @@ INSERT INTO public.tb_audit_log VALUES (1813, 3, 'CREATE', 'SUCCESS', 'SYSTEM',
INSERT INTO public.tb_audit_log VALUES (1814, 3, 'READ', 'SUCCESS', 'SYSTEM', '127.0.0.1', NULL, '2025-12-26 00:44:49.832926+00', NULL, '', '/api/models/train');
INSERT INTO public.tb_audit_log VALUES (1815, NULL, 'CREATE', 'FAILED', 'SYSTEM', '127.0.0.1', NULL, '2025-12-26 00:53:40.899603+00', 467, '{
"username": "1234567",
"password":"****"
"password":"****"
}', '/api/auth/signin');
INSERT INTO public.tb_audit_log VALUES (1816, NULL, 'CREATE', 'SUCCESS', 'SYSTEM', '127.0.0.1', NULL, '2025-12-26 00:55:27.731595+00', NULL, '{
"username": "1234567",
@@ -30396,6 +30396,8 @@ ALTER TABLE ONLY public.tb_menu
ADD CONSTRAINT fksw914diut87r7lfykekc7xm2a FOREIGN KEY (parent_menu_uid) REFERENCES public.tb_menu(menu_uid);
-- Completed on 2025-12-26 16:11:11 KST
--
@@ -30404,3 +30406,5 @@ ALTER TABLE ONLY public.tb_menu
\unrestrict IYrUYfSgA4Fo2gubHcb84jDXfbBZEIiOZnyLtZgnMi641GaRQa5QDogarpTr7IG