Compare commits
65 Commits
41911014c9
...
feat/dean/
| Author | SHA1 | Date | |
|---|---|---|---|
| 0c34ea7dcb | |||
| 3547c28361 | |||
| 6c70bfed18 | |||
| 95a75e63f4 | |||
| 2a1dbee290 | |||
| 384a321bf3 | |||
| f4e97d389b | |||
| 590810ff0a | |||
| a01c872982 | |||
| 905a245070 | |||
| 860ce35a8f | |||
| 7f3f5dca40 | |||
| 4a0a4e35ed | |||
| ae055dca1e | |||
| 26e8e1492f | |||
| 8fa722011c | |||
| 17d47d6200 | |||
| e178f58fe2 | |||
| cd0cf5726d | |||
| 8e4bea53da | |||
| 7a22d8ba73 | |||
| 2df4a7a80b | |||
| b451f697bc | |||
| 7e9c867f34 | |||
| 130e85f8a1 | |||
| 9e713cb49d | |||
| 51dfa97900 | |||
| 87c6b599b4 | |||
| f50855a822 | |||
| 8d416317a8 | |||
| 22aa071476 | |||
| a83bd09f8f | |||
| 96035f864a | |||
| fd7dfd7e7f | |||
| 190b93bee8 | |||
| c5f19cc961 | |||
| c56c0ca605 | |||
| c6e721aa37 | |||
| 6572e17f00 | |||
| be6365807c | |||
| d2fff7dfde | |||
| f66bc22c95 | |||
| 3367d0e7be | |||
| 352ec6ccb0 | |||
| 6a989255a3 | |||
| 878b21573f | |||
| 0602db1436 | |||
| 2f8bd1f98c | |||
| 75231ccbba | |||
| 1249a80da5 | |||
| 00c78eb42f | |||
| 35767adba1 | |||
| 47a2a159ef | |||
| 95548223cd | |||
| 2debdc5312 | |||
| 207cc47f1b | |||
| b6338bce8e | |||
| 2cfa2adcf5 | |||
| d7e19abfc9 | |||
| c843703ee7 | |||
| 133ea6b1ba | |||
| 0df977ae81 | |||
| 3e39006822 | |||
| 3ec1a71406 | |||
| 16009f1623 |
43
build.gradle
43
build.gradle
@@ -3,6 +3,7 @@ plugins {
|
||||
id 'org.springframework.boot' version '3.5.7'
|
||||
id 'io.spring.dependency-management' version '1.1.7'
|
||||
id 'com.diffplug.spotless' version '6.25.0'
|
||||
id 'idea'
|
||||
}
|
||||
|
||||
group = 'com.kamco.cd'
|
||||
@@ -21,11 +22,23 @@ configurations {
|
||||
}
|
||||
}
|
||||
|
||||
// QueryDSL 생성된 소스 디렉토리 정의
|
||||
def generatedSourcesDir = file("$buildDir/generated/sources/annotationProcessor/java/main")
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
maven { url "https://repo.osgeo.org/repository/release/" }
|
||||
}
|
||||
|
||||
// Gradle이 생성된 소스를 컴파일 경로에 포함하도록 설정
|
||||
sourceSets {
|
||||
main {
|
||||
java {
|
||||
srcDirs += generatedSourcesDir
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
|
||||
implementation 'org.springframework.boot:spring-boot-starter-web'
|
||||
@@ -84,7 +97,22 @@ dependencies {
|
||||
implementation 'org.reflections:reflections:0.10.2'
|
||||
|
||||
implementation 'com.jcraft:jsch:0.1.55'
|
||||
implementation 'org.apache.commons:commons-csv:1.10.0'
|
||||
}
|
||||
|
||||
// IntelliJ가 생성된 소스를 인식하도록 설정
|
||||
idea {
|
||||
module {
|
||||
// 소스 디렉토리로 인식
|
||||
sourceDirs += generatedSourcesDir
|
||||
|
||||
// Generated Sources Root로 마킹 (IntelliJ에서 특별 처리)
|
||||
generatedSourceDirs += generatedSourcesDir
|
||||
|
||||
// 소스 및 Javadoc 다운로드
|
||||
downloadJavadoc = true
|
||||
downloadSources = true
|
||||
}
|
||||
}
|
||||
|
||||
configurations.configureEach {
|
||||
@@ -95,6 +123,21 @@ tasks.named('test') {
|
||||
useJUnitPlatform()
|
||||
}
|
||||
|
||||
// 컴파일 전 생성된 소스 디렉토리 생성 보장
|
||||
tasks.named('compileJava') {
|
||||
doFirst {
|
||||
generatedSourcesDir.mkdirs()
|
||||
}
|
||||
}
|
||||
|
||||
// 생성된 소스 정리 태스크
|
||||
tasks.register('cleanGeneratedSources', Delete) {
|
||||
delete generatedSourcesDir
|
||||
}
|
||||
|
||||
tasks.named('clean') {
|
||||
dependsOn 'cleanGeneratedSources'
|
||||
}
|
||||
|
||||
bootJar {
|
||||
archiveFileName = 'ROOT.jar'
|
||||
|
||||
@@ -14,6 +14,7 @@ services:
|
||||
- /mnt/nfs_share/images:/app/original-images
|
||||
- /mnt/nfs_share/model_output:/app/model-outputs
|
||||
- /mnt/nfs_share/train_dataset:/app/train-dataset
|
||||
- /home/kcomu/data:/home/kcomu/data
|
||||
networks:
|
||||
- kamco-cds
|
||||
restart: unless-stopped
|
||||
|
||||
@@ -2,8 +2,10 @@ package com.kamco.cd.training;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.scheduling.annotation.EnableAsync;
|
||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
||||
|
||||
@EnableAsync
|
||||
@SpringBootApplication
|
||||
@EnableScheduling
|
||||
public class KamcoTrainingApplication {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.common.dto;
|
||||
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
@@ -14,6 +15,10 @@ public class HyperParam {
|
||||
// -------------------------
|
||||
// Important
|
||||
// -------------------------
|
||||
|
||||
@Schema(description = "모델", example = "large")
|
||||
private ModelType model; // backbone
|
||||
|
||||
@Schema(description = "백본 네트워크", example = "large")
|
||||
private String backbone; // backbone
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.kamco.cd.training.common.enums;
|
||||
|
||||
import com.kamco.cd.training.common.utils.enums.CodeExpose;
|
||||
import com.kamco.cd.training.common.utils.enums.EnumType;
|
||||
import java.util.Arrays;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@@ -15,6 +16,10 @@ public enum ModelType implements EnumType {
|
||||
|
||||
private String desc;
|
||||
|
||||
public static ModelType getValueData(String modelNo) {
|
||||
return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getId() {
|
||||
return name();
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.kamco.cd.training.common.utils;
|
||||
|
||||
import static java.lang.String.CASE_INSENSITIVE_ORDER;
|
||||
|
||||
import com.jcraft.jsch.ChannelExec;
|
||||
import com.jcraft.jsch.ChannelSftp;
|
||||
import com.jcraft.jsch.JSch;
|
||||
import com.jcraft.jsch.Session;
|
||||
@@ -505,11 +506,15 @@ public class FIleChecker {
|
||||
|
||||
try {
|
||||
File dir = new File(targetPath);
|
||||
log.info("targetPath={}", targetPath);
|
||||
log.info("absolute targetPath={}", dir.getAbsolutePath());
|
||||
|
||||
if (!dir.exists()) {
|
||||
dir.mkdirs();
|
||||
}
|
||||
|
||||
File dest = new File(dir, String.valueOf(chunkIndex));
|
||||
log.info("real save path = {}", dest.getAbsolutePath());
|
||||
|
||||
log.info("chunkIndex={}, uploadSize={}", chunkIndex, mfile.getSize());
|
||||
log.info("savedSize={}", dest.length());
|
||||
@@ -521,6 +526,9 @@ public class FIleChecker {
|
||||
log.info("after delete={}", dest.length());
|
||||
mfile.transferTo(dest);
|
||||
|
||||
log.info("after transfer size={}", dest.length());
|
||||
log.info("after transfer exists={}", dest.exists());
|
||||
|
||||
return true;
|
||||
} catch (IOException e) {
|
||||
log.error("chunk save error", e);
|
||||
@@ -706,12 +714,17 @@ public class FIleChecker {
|
||||
}
|
||||
|
||||
public static void unzip(String fileName, String destDirectory) throws IOException {
|
||||
File destDir = new File(destDirectory);
|
||||
if (!destDir.exists()) {
|
||||
destDir.mkdirs(); // 대상 폴더가 없으면 생성
|
||||
}
|
||||
String zipFilePath = destDirectory + File.separator + fileName;
|
||||
|
||||
String zipFilePath = destDirectory + "/" + fileName;
|
||||
// zip 이름으로 폴더 생성 (확장자 제거)
|
||||
String folderName =
|
||||
fileName.endsWith(".zip") ? fileName.substring(0, fileName.length() - 4) : fileName;
|
||||
|
||||
File destDir = new File(destDirectory, folderName);
|
||||
|
||||
if (!destDir.exists()) {
|
||||
destDir.mkdirs();
|
||||
}
|
||||
|
||||
try (ZipInputStream zis = new ZipInputStream(new FileInputStream(zipFilePath))) {
|
||||
ZipEntry zipEntry = zis.getNextEntry();
|
||||
@@ -800,4 +813,97 @@ public class FIleChecker {
|
||||
if (session != null) session.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
public static void unzipOn86Server(String zipPath, String targetDir) {
|
||||
|
||||
String host = "192.168.2.86";
|
||||
String user = "kcomu";
|
||||
String password = "Kamco2025!";
|
||||
|
||||
Session session = null;
|
||||
ChannelExec channel = null;
|
||||
|
||||
try {
|
||||
JSch jsch = new JSch();
|
||||
|
||||
session = jsch.getSession(user, host, 22);
|
||||
session.setPassword(password);
|
||||
|
||||
Properties config = new Properties();
|
||||
config.put("StrictHostKeyChecking", "no");
|
||||
session.setConfig(config);
|
||||
|
||||
session.connect(10_000);
|
||||
|
||||
String command = "unzip -o " + zipPath + " -d " + targetDir;
|
||||
|
||||
channel = (ChannelExec) session.openChannel("exec");
|
||||
channel.setCommand(command);
|
||||
channel.setErrStream(System.err);
|
||||
|
||||
InputStream in = channel.getInputStream();
|
||||
channel.connect();
|
||||
|
||||
// 출력 읽기(선택)
|
||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
|
||||
while (br.readLine() != null) {
|
||||
// 필요하면 로그
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
} finally {
|
||||
if (channel != null) channel.disconnect();
|
||||
if (session != null) session.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
public static List<String> execCommandAndReadLines(String command) {
|
||||
|
||||
List<String> result = new ArrayList<>();
|
||||
|
||||
String host = "192.168.2.86";
|
||||
String user = "kcomu";
|
||||
String password = "Kamco2025!";
|
||||
|
||||
Session session = null;
|
||||
ChannelExec channel = null;
|
||||
|
||||
try {
|
||||
JSch jsch = new JSch();
|
||||
|
||||
session = jsch.getSession(user, host, 22);
|
||||
session.setPassword(password);
|
||||
|
||||
Properties config = new Properties();
|
||||
config.put("StrictHostKeyChecking", "no");
|
||||
session.setConfig(config);
|
||||
|
||||
session.connect(10_000);
|
||||
|
||||
channel = (ChannelExec) session.openChannel("exec");
|
||||
channel.setCommand(command);
|
||||
channel.setInputStream(null);
|
||||
|
||||
InputStream in = channel.getInputStream();
|
||||
channel.connect();
|
||||
|
||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
result.add(line);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("remote command failed : " + command, e);
|
||||
|
||||
} finally {
|
||||
if (channel != null) channel.disconnect();
|
||||
if (session != null) session.disconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,15 +14,11 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.validation.Valid;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.core.io.UrlResource;
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
@@ -230,10 +226,15 @@ public class DatasetApiController {
|
||||
throws Exception {
|
||||
|
||||
String path = datasetService.getFilePathByUUIDPathType(uuid, pathType);
|
||||
Path filePath = Paths.get(path);
|
||||
return datasetService.getFilePathByFile(path);
|
||||
}
|
||||
|
||||
Resource resource = new UrlResource(filePath.toUri());
|
||||
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")
|
||||
@GetMapping("/files-to86")
|
||||
public ResponseEntity<Resource> getFileTo86(
|
||||
@RequestParam UUID uuid, @RequestParam String pathType) throws Exception {
|
||||
|
||||
return ResponseEntity.ok().contentType(MediaType.APPLICATION_OCTET_STREAM).body(resource);
|
||||
String path = datasetService.getFilePathByUUIDPathType(uuid, pathType);
|
||||
return datasetService.getFilePathByFile(path);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,6 +261,7 @@ public class DatasetDto {
|
||||
this.datasetId = datasetId;
|
||||
this.uuid = uuid;
|
||||
this.dataType = dataType;
|
||||
this.dataTypeName = getDataTypeName(dataType);
|
||||
this.title = title;
|
||||
this.roundNo = roundNo;
|
||||
this.compareYyyy = compareYyyy;
|
||||
|
||||
@@ -21,6 +21,7 @@ import com.kamco.cd.training.dataset.dto.DatasetObjDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.core.DatasetCoreService;
|
||||
import jakarta.validation.Valid;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
@@ -33,8 +34,13 @@ import java.util.stream.Stream;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@@ -158,6 +164,44 @@ public class DatasetService {
|
||||
}
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
@Transactional
|
||||
public ResponseObj insertDatasetTo86(@Valid AddReq addReq) {
|
||||
|
||||
Long datasetUid = null; // master id 값, 등록하면서 가져올 예정
|
||||
|
||||
// 압축 해제
|
||||
FIleChecker.unzipOn86Server(
|
||||
addReq.getFilePath() + addReq.getFileName(),
|
||||
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
|
||||
|
||||
// 해제한 폴더 읽어서 데이터 저장
|
||||
List<Map<String, Object>> list =
|
||||
getUnzipDatasetFilesTo86(
|
||||
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "train");
|
||||
|
||||
int idx = 0;
|
||||
for (Map<String, Object> map : list) {
|
||||
datasetUid =
|
||||
this.insertTrainTestData(map, addReq, idx, datasetUid, "train"); // train 데이터 insert
|
||||
idx++;
|
||||
}
|
||||
|
||||
List<Map<String, Object>> testList =
|
||||
getUnzipDatasetFilesTo86(
|
||||
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "test");
|
||||
|
||||
int testIdx = 0;
|
||||
for (Map<String, Object> test : testList) {
|
||||
datasetUid =
|
||||
this.insertTrainTestData(test, addReq, testIdx, datasetUid, "test"); // test 데이터 insert
|
||||
testIdx++;
|
||||
}
|
||||
|
||||
datasetCoreService.updateDatasetUploadStatus(datasetUid);
|
||||
return new ResponseObj(ApiResponseCode.OK, "업로드 성공하였습니다.");
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public ResponseObj insertDataset(@Valid AddReq addReq) {
|
||||
|
||||
@@ -179,6 +223,17 @@ public class DatasetService {
|
||||
idx++;
|
||||
}
|
||||
|
||||
List<Map<String, Object>> valList =
|
||||
getUnzipDatasetFiles(
|
||||
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "val");
|
||||
|
||||
int valIdx = 0;
|
||||
for (Map<String, Object> valid : valList) {
|
||||
datasetUid =
|
||||
this.insertTrainTestData(valid, addReq, valIdx, datasetUid, "val"); // val 데이터 insert
|
||||
valIdx++;
|
||||
}
|
||||
|
||||
List<Map<String, Object>> testList =
|
||||
getUnzipDatasetFiles(
|
||||
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "test");
|
||||
@@ -285,6 +340,8 @@ public class DatasetService {
|
||||
|
||||
if (subDir.equals("train")) {
|
||||
datasetCoreService.insertDatasetObj(objRegDto);
|
||||
} else if (subDir.equals("val")) {
|
||||
datasetCoreService.insertDatasetValObj(objRegDto);
|
||||
} else {
|
||||
datasetCoreService.insertDatasetTestObj(objRegDto);
|
||||
}
|
||||
@@ -356,4 +413,112 @@ public class DatasetService {
|
||||
public String getFilePathByUUIDPathType(UUID uuid, String pathType) {
|
||||
return datasetCoreService.getFilePathByUUIDPathType(uuid, pathType);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
private List<Map<String, Object>> getUnzipDatasetFilesTo86(String unzipRootPath, String subDir) {
|
||||
|
||||
// String root = Paths.get(unzipRootPath)
|
||||
// .resolve(subDir)
|
||||
// .toString();
|
||||
//
|
||||
String root = normalizeLinuxPath(unzipRootPath + "/" + subDir);
|
||||
|
||||
Map<String, Map<String, Object>> grouped = new HashMap<>();
|
||||
|
||||
for (String dirName : LABEL_DIRS) {
|
||||
|
||||
String remoteDir = root + "/" + dirName;
|
||||
|
||||
// 1. 86 서버에서 해당 디렉토리의 파일 목록 조회
|
||||
List<String> files = listFilesOn86Server(remoteDir);
|
||||
|
||||
if (files.isEmpty()) {
|
||||
throw new IllegalStateException("폴더가 존재하지 않거나 파일이 없습니다 : " + remoteDir);
|
||||
}
|
||||
|
||||
for (String fullPath : files) {
|
||||
|
||||
String fileName = Paths.get(fullPath).getFileName().toString();
|
||||
String baseName = getBaseName(fileName);
|
||||
|
||||
Map<String, Object> data = grouped.computeIfAbsent(baseName, k -> new HashMap<>());
|
||||
|
||||
data.put("baseName", baseName);
|
||||
|
||||
if ("label-json".equals(dirName)) {
|
||||
|
||||
// 2. json 내용도 86 서버에서 읽어서 가져와야 함
|
||||
String json = readRemoteFileAsString(fullPath);
|
||||
|
||||
data.put("label-json", parseJson(json));
|
||||
data.put("geojson_path", fullPath);
|
||||
|
||||
} else {
|
||||
|
||||
data.put(dirName, fullPath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new ArrayList<>(grouped.values());
|
||||
}
|
||||
|
||||
private List<String> listFilesOn86Server(String remoteDir) {
|
||||
|
||||
String command = "find " + escape(remoteDir) + " -maxdepth 1 -type f";
|
||||
|
||||
return FIleChecker.execCommandAndReadLines(command);
|
||||
}
|
||||
|
||||
private String readRemoteFileAsString(String remoteFilePath) {
|
||||
|
||||
String command = "cat " + escape(remoteFilePath);
|
||||
|
||||
List<String> lines = FIleChecker.execCommandAndReadLines(command);
|
||||
|
||||
return String.join("\n", lines);
|
||||
}
|
||||
|
||||
private JsonNode parseJson(String json) {
|
||||
try {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
return mapper.readTree(json);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("JSON 파싱 실패", e);
|
||||
}
|
||||
}
|
||||
|
||||
private String escape(String path) {
|
||||
return "'" + path.replace("'", "'\"'\"'") + "'";
|
||||
}
|
||||
|
||||
private static String normalizeLinuxPath(String path) {
|
||||
return path.replace("\\", "/");
|
||||
}
|
||||
|
||||
public ResponseEntity<Resource> getFilePathByFile(String remoteFilePath) {
|
||||
|
||||
try {
|
||||
Path path = Paths.get(remoteFilePath);
|
||||
InputStream inputStream = Files.newInputStream(path);
|
||||
|
||||
InputStreamResource resource =
|
||||
new InputStreamResource(inputStream) {
|
||||
@Override
|
||||
public long contentLength() {
|
||||
return -1; // 알 수 없으면 -1
|
||||
}
|
||||
};
|
||||
|
||||
String fileName = Paths.get(remoteFilePath.replace("\\", "/")).getFileName().toString();
|
||||
|
||||
return ResponseEntity.ok()
|
||||
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"")
|
||||
.contentType(MediaType.APPLICATION_OCTET_STREAM)
|
||||
.body(resource);
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.hyperparam;
|
||||
|
||||
import com.kamco.cd.training.common.dto.HyperParam;
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
|
||||
@@ -65,7 +66,7 @@ public class HyperParamApiController {
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
|
||||
@ApiResponse(responseCode = "422", description = "HPs_0001 수정 불가", content = @Content),
|
||||
@ApiResponse(responseCode = "422", description = "default는 삭제불가", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PutMapping("/{uuid}")
|
||||
@@ -87,8 +88,9 @@ public class HyperParamApiController {
|
||||
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/list")
|
||||
@GetMapping("{model}/list")
|
||||
public ApiResponseDto<Page<List>> getHyperParam(
|
||||
@PathVariable ModelType model,
|
||||
@Parameter(
|
||||
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
|
||||
example = "CREATE_DATE")
|
||||
@@ -98,7 +100,7 @@ public class HyperParamApiController {
|
||||
LocalDate startDate,
|
||||
@Parameter(description = "종료일", example = "2026-02-28") @RequestParam(required = false)
|
||||
LocalDate endDate,
|
||||
@Parameter(description = "버전명", example = "HPs_0001") @RequestParam(required = false)
|
||||
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
|
||||
String hyperVer,
|
||||
@Parameter(
|
||||
description = "정렬",
|
||||
@@ -124,7 +126,7 @@ public class HyperParamApiController {
|
||||
searchReq.setSort(sort);
|
||||
searchReq.setPage(page);
|
||||
searchReq.setSize(size);
|
||||
Page<List> list = hyperParamService.getHyperParamList(searchReq);
|
||||
Page<List> list = hyperParamService.getHyperParamList(model, searchReq);
|
||||
|
||||
return ApiResponseDto.ok(list);
|
||||
}
|
||||
@@ -133,7 +135,7 @@ public class HyperParamApiController {
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
|
||||
@ApiResponse(responseCode = "422", description = "HPs_0001 삭제 불가", content = @Content),
|
||||
@ApiResponse(responseCode = "422", description = "default 삭제 불가", content = @Content),
|
||||
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
||||
})
|
||||
@DeleteMapping("/{uuid}")
|
||||
@@ -179,8 +181,11 @@ public class HyperParamApiController {
|
||||
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/init")
|
||||
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam() {
|
||||
return ApiResponseDto.ok(hyperParamService.getInitHyperParam());
|
||||
@GetMapping("/init/{model}")
|
||||
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(
|
||||
@PathVariable ModelType model
|
||||
|
||||
) {
|
||||
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.hyperparam.dto;
|
||||
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.common.utils.enums.CodeExpose;
|
||||
import com.kamco.cd.training.common.utils.enums.EnumType;
|
||||
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
|
||||
@@ -24,6 +25,7 @@ public class HyperParamDto {
|
||||
@AllArgsConstructor
|
||||
public static class Basic {
|
||||
|
||||
private ModelType model; // 20250212 modeltype추가
|
||||
private UUID uuid;
|
||||
private String hyperVer;
|
||||
@JsonFormatDttm private ZonedDateTime createdDttm;
|
||||
@@ -98,6 +100,8 @@ public class HyperParamDto {
|
||||
private Integer gpuCnt;
|
||||
private String gpuIds;
|
||||
private Integer masterPort;
|
||||
|
||||
private Boolean isDefault;
|
||||
}
|
||||
|
||||
@Getter
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package com.kamco.cd.training.hyperparam.service;
|
||||
|
||||
import com.kamco.cd.training.common.dto.HyperParam;
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -20,11 +22,12 @@ public class HyperParamService {
|
||||
/**
|
||||
* 하이퍼 파라미터 목록 조회
|
||||
*
|
||||
* @param model
|
||||
* @param req
|
||||
* @return 목록
|
||||
*/
|
||||
public Page<List> getHyperParamList(HyperParamDto.SearchReq req) {
|
||||
return hyperParamCoreService.findByHyperVerList(req);
|
||||
public Page<List> getHyperParamList(ModelType model, SearchReq req) {
|
||||
return hyperParamCoreService.findByHyperVerList(model, req);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -59,8 +62,8 @@ public class HyperParamService {
|
||||
}
|
||||
|
||||
/** 하이퍼파라미터 최적화 설정값 조회 */
|
||||
public HyperParamDto.Basic getInitHyperParam() {
|
||||
return hyperParamCoreService.getInitHyperParam();
|
||||
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
||||
return hyperParamCoreService.getInitHyperParam(model);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -3,6 +3,10 @@ package com.kamco.cd.training.model;
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||
import com.kamco.cd.training.model.service.ModelTrainDetailService;
|
||||
@@ -132,4 +136,90 @@ public class ModelTrainDetailApiController {
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
|
||||
}
|
||||
|
||||
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "조회 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = TransferDetailDto.class))),
|
||||
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/metrics/train/{uuid}")
|
||||
public ApiResponseDto<List<ModelTrainMetrics>> getModelTrainMetricResult(
|
||||
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainMetricResult(uuid));
|
||||
}
|
||||
|
||||
@Operation(
|
||||
summary = "모델관리 > 모델 상세 > 성능 정보 (Validation)",
|
||||
description = "모델 상세 > 성능 정보 (Validation) API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "조회 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = TransferDetailDto.class))),
|
||||
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/metrics/validation/{uuid}")
|
||||
public ApiResponseDto<List<ModelValidationMetrics>> getModelValidationMetricResult(
|
||||
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getModelValidationMetricResult(uuid));
|
||||
}
|
||||
|
||||
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Test)", description = "모델 상세 > 성능 정보 (Test) API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "조회 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = TransferDetailDto.class))),
|
||||
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/metrics/test/{uuid}")
|
||||
public ApiResponseDto<List<ModelTestMetrics>> getModelTestMetricResult(
|
||||
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getModelTestMetricResult(uuid));
|
||||
}
|
||||
|
||||
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Test)", description = "모델 상세 > 성능 정보 (Test) API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "조회 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = TransferDetailDto.class))),
|
||||
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/best-epoch/{uuid}")
|
||||
public ApiResponseDto<ModelBestEpoch> getModelTrainBestEpoch(
|
||||
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainBestEpoch(uuid));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,13 +74,12 @@ public class ModelTrainMngApiController {
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
|
||||
@ApiResponse(responseCode = "409", description = "HPs_0001 삭제 불가", content = @Content)
|
||||
@ApiResponse(responseCode = "409", description = "G1_000001 삭제 불가", content = @Content)
|
||||
})
|
||||
@DeleteMapping("/{uuid}")
|
||||
public ApiResponseDto<Void> deleteModelTrain(
|
||||
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
@PathVariable UUID uuid) {
|
||||
modelTrainMngService.deleteModelTrain(uuid);
|
||||
return ApiResponseDto.ok(null);
|
||||
}
|
||||
@@ -92,9 +91,8 @@ public class ModelTrainMngApiController {
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping
|
||||
public ApiResponseDto<String> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
|
||||
modelTrainMngService.createModelTrain(req);
|
||||
return ApiResponseDto.ok("ok");
|
||||
public ApiResponseDto<UUID> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
|
||||
return ApiResponseDto.ok(modelTrainMngService.createModelTrain(req));
|
||||
}
|
||||
|
||||
@Operation(summary = "모델학습 config 정보 조회", description = "모델학습 config 정보 조회 API")
|
||||
@@ -150,4 +148,22 @@ public class ModelTrainMngApiController {
|
||||
req.setDataType(selectType);
|
||||
return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req));
|
||||
}
|
||||
|
||||
@Operation(summary = "모델학습 1단계 실행중인 것이 있는지 count", description = "모델학습 1단계 실행중인 것이 있는지 count")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "검색 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = Long.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/ing-training-cnt")
|
||||
public ApiResponseDto<Long> findModelStep1InProgressCnt() {
|
||||
return ApiResponseDto.ok(modelTrainMngService.findModelStep1InProgressCnt());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,4 +180,69 @@ public class ModelTrainDetailDto {
|
||||
private TransferHyperSummary modelTrainHyper;
|
||||
private List<SelectDataSet> modelTrainDataset;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class ModelTrainMetrics {
|
||||
private Integer epoch;
|
||||
private Long iteration;
|
||||
private Double loss;
|
||||
private Double lr;
|
||||
private Float durationTime;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class ModelValidationMetrics {
|
||||
|
||||
private Integer epoch;
|
||||
private Float aAcc;
|
||||
private Float mFscore;
|
||||
private Float mPrecision;
|
||||
private Float mRecall;
|
||||
private Float mIou;
|
||||
private Float mAcc;
|
||||
private Float changedFscore;
|
||||
private Float changedPrecision;
|
||||
private Float changedRecall;
|
||||
private Float unchangedFscore;
|
||||
private Float unchangedPrecision;
|
||||
private Float unchangedRecall;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class ModelTestMetrics {
|
||||
private String model;
|
||||
private Long tp;
|
||||
private Long fp;
|
||||
private Long fn;
|
||||
private Float precision;
|
||||
private Float recall;
|
||||
private Float f1Score;
|
||||
private Float accuracy;
|
||||
private Float iou;
|
||||
private Long detectionCount;
|
||||
private Long gtCount;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class ModelBestEpoch {
|
||||
private Integer epoch;
|
||||
private Double loss;
|
||||
private Float f1Score;
|
||||
private Float precision;
|
||||
private Float recall;
|
||||
private Float iou;
|
||||
private Float accuracy;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ public class ModelTrainMngDto {
|
||||
private String statusCd;
|
||||
private String trainType;
|
||||
private String modelNo;
|
||||
private Long currentAttemptId;
|
||||
|
||||
public String getStatusName() {
|
||||
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
||||
@@ -154,6 +155,17 @@ public class ModelTrainMngDto {
|
||||
ModelConfig modelConfig;
|
||||
}
|
||||
|
||||
@Schema(name = "addReq", description = "모델학습 관리 등록 파라미터")
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class UpdateReq {
|
||||
|
||||
private String requestPath;
|
||||
private String responsePath;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public static class TrainingDataset {
|
||||
|
||||
@@ -6,6 +6,10 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||
@@ -96,4 +100,20 @@ public class ModelTrainDetailService {
|
||||
|
||||
return transferDetailDto;
|
||||
}
|
||||
|
||||
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
|
||||
return modelTrainDetailCoreService.getModelTrainMetricResult(uuid);
|
||||
}
|
||||
|
||||
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
|
||||
return modelTrainDetailCoreService.getModelValidationMetricResult(uuid);
|
||||
}
|
||||
|
||||
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
|
||||
return modelTrainDetailCoreService.getModelTestMetricResult(uuid);
|
||||
}
|
||||
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
return modelTrainDetailCoreService.getModelTrainBestEpoch(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -29,6 +31,7 @@ public class ModelTrainMngService {
|
||||
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final HyperParamCoreService hyperParamCoreService;
|
||||
private final TmpDatasetService tmpDatasetService;
|
||||
|
||||
/**
|
||||
* 모델학습 조회
|
||||
@@ -57,13 +60,13 @@ public class ModelTrainMngService {
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public void createModelTrain(ModelTrainMngDto.AddReq req) {
|
||||
public UUID createModelTrain(ModelTrainMngDto.AddReq req) {
|
||||
HyperParam hyperParam = req.getHyperParam();
|
||||
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
|
||||
|
||||
// 전이 학습은 모델 선택 필수
|
||||
if (req.getTrainType().equals(TrainType.TRANSFER.getId())) {
|
||||
if (req.getBeforeModelId() != null) {
|
||||
if (TrainType.TRANSFER.getId().equals(req.getTrainType())) {
|
||||
if (req.getBeforeModelId() == null) {
|
||||
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요.");
|
||||
}
|
||||
}
|
||||
@@ -76,7 +79,10 @@ public class ModelTrainMngService {
|
||||
}
|
||||
|
||||
// 모델학습 테이블 저장
|
||||
Long modelId = modelTrainMngCoreService.saveModel(req);
|
||||
ModelTrainMngDto.Basic modelDto = modelTrainMngCoreService.saveModel(req);
|
||||
|
||||
Long modelId = modelDto.getId();
|
||||
UUID modelUuid = modelDto.getUuid();
|
||||
|
||||
// 모델학습 데이터셋 저장
|
||||
modelTrainMngCoreService.saveModelDataset(modelId, req);
|
||||
@@ -87,6 +93,24 @@ public class ModelTrainMngService {
|
||||
|
||||
// 모델 config 저장
|
||||
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
||||
|
||||
UUID tmpUuid = UUID.randomUUID();
|
||||
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
|
||||
|
||||
List<String> uids =
|
||||
modelTrainMngCoreService.findDatasetUid(req.getTrainingDataset().getDatasetList());
|
||||
|
||||
try {
|
||||
// 데이터셋 심볼링크 생성
|
||||
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||
updateReq.setRequestPath(path.toString());
|
||||
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return modelUuid;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -112,4 +136,8 @@ public class ModelTrainMngService {
|
||||
return modelTrainMngCoreService.getDatasetSelectG2G3List(req);
|
||||
}
|
||||
}
|
||||
|
||||
public Long findModelStep1InProgressCnt() {
|
||||
return modelTrainMngCoreService.findModelStep1InProgressCnt();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package com.kamco.cd.training.model.service;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.*;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class TmpDatasetService {
|
||||
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public Path buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
||||
|
||||
// 환경에 맞게 yml로 빼는 걸 추천
|
||||
Path BASE = Paths.get(requestDir);
|
||||
Path tmp = BASE.resolve("tmp").resolve(uid);
|
||||
|
||||
// mkdir -p "$TMP"/train/{input1,input2,label} ...
|
||||
for (String type : List.of("train", "val")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
Files.createDirectories(tmp.resolve(type).resolve(part));
|
||||
}
|
||||
}
|
||||
|
||||
for (String id : datasetUids) {
|
||||
Path srcRoot = BASE.resolve(id);
|
||||
|
||||
for (String type : List.of("train", "val")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
|
||||
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||
|
||||
// zsh NULL_GLOB: 폴더가 없으면 그냥 continue
|
||||
if (!Files.isDirectory(srcDir)) continue;
|
||||
|
||||
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
|
||||
for (Path f : stream) {
|
||||
if (!Files.isRegularFile(f)) continue;
|
||||
|
||||
String dstName = id + "__" + f.getFileName();
|
||||
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
|
||||
|
||||
// 이미 있으면 스킵(원하면 덮어쓰기 로직으로 바꿀 수 있음)
|
||||
if (Files.exists(dst)) continue;
|
||||
|
||||
// ln -s "$f" "$dst" 와 동일
|
||||
Files.createSymbolicLink(dst, f.toAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.info("tmp dataset created: {}", tmp);
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
@@ -242,4 +242,8 @@ public class DatasetCoreService
|
||||
|
||||
entity.setStatus(LearnDataRegister.COMPLETED.getId());
|
||||
}
|
||||
|
||||
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
|
||||
datasetObjRepository.insertDatasetValObj(objRegDto);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.common.dto.HyperParam;
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.common.exception.CustomApiException;
|
||||
import com.kamco.cd.training.common.utils.UserUtil;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.Basic;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
|
||||
import java.time.ZonedDateTime;
|
||||
@@ -17,6 +19,7 @@ import org.springframework.stereotype.Service;
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class HyperParamCoreService {
|
||||
|
||||
private final HyperParamRepository hyperParamRepository;
|
||||
private final UserUtil userUtil;
|
||||
|
||||
@@ -27,7 +30,7 @@ public class HyperParamCoreService {
|
||||
* @return 등록된 버전명
|
||||
*/
|
||||
public Basic createHyperParam(HyperParam createReq) {
|
||||
String firstVersion = getFirstHyperParamVersion();
|
||||
String firstVersion = getFirstHyperParamVersion(createReq.getModel());
|
||||
|
||||
ModelHyperParamEntity entity = new ModelHyperParamEntity();
|
||||
entity.setHyperVer(firstVersion);
|
||||
@@ -47,17 +50,17 @@ public class HyperParamCoreService {
|
||||
/**
|
||||
* 하이퍼파라미터 수정
|
||||
*
|
||||
* @param uuid uuid
|
||||
* @param uuid uuid
|
||||
* @param createReq 등록 요청
|
||||
* @return ver
|
||||
*/
|
||||
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
if (entity.getHyperVer().equals("HPs_0001")) {
|
||||
if (entity.getIsDefault()) {
|
||||
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
applyHyperParam(entity, createReq);
|
||||
@@ -69,11 +72,112 @@ public class HyperParamCoreService {
|
||||
return entity.getHyperVer();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 삭제
|
||||
*
|
||||
* @param uuid
|
||||
*/
|
||||
public void deleteHyperParam(UUID uuid) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
// if (entity.getHyperVer().equals("HPs_0001")) {
|
||||
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
// }
|
||||
|
||||
//디폴트면 삭제불가
|
||||
if (entity.getIsDefault()) {
|
||||
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
entity.setDelYn(true);
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 최적화 설정값 조회
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.getHyperparamByType(model)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 상세 조회
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public HyperParamDto.Basic getHyperParam(UUID uuid) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 목록 조회
|
||||
*
|
||||
* @param model
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
public Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req) {
|
||||
return hyperParamRepository.findByHyperVerList(model, req);
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 버전 조회
|
||||
*
|
||||
* @param model 모델 타입
|
||||
* @return ver
|
||||
*/
|
||||
public String getFirstHyperParamVersion(ModelType model) {
|
||||
return hyperParamRepository
|
||||
.findHyperParamVerByModelType(model)
|
||||
.map(ModelHyperParamEntity::getHyperVer)
|
||||
.map(ver -> increase(ver, model))
|
||||
.orElse(model.name() + "_000001");
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼 파라미터의 버전을 증가시킨다.
|
||||
*
|
||||
* @param hyperVer 현재 버전
|
||||
* @param modelType 모델 타입
|
||||
* @return 증가된 버전
|
||||
*/
|
||||
private String increase(String hyperVer, ModelType modelType) {
|
||||
String prefix = modelType.name() + "_";
|
||||
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
|
||||
return prefix + String.format("%06d", num + 1);
|
||||
}
|
||||
|
||||
private void applyHyperParam(ModelHyperParamEntity entity, HyperParam src) {
|
||||
|
||||
ModelType model = src.getModel();
|
||||
|
||||
// 하드코딩 모델별로 다른경우 250212 bbn 하드코딩
|
||||
if (model == ModelType.G3) {
|
||||
entity.setCropSize("512,512");
|
||||
} else {
|
||||
entity.setCropSize("256,256");
|
||||
}
|
||||
// entity.setCropSize(src.getCropSize());
|
||||
|
||||
// Important
|
||||
entity.setModelType(model); // 20250212 modeltype추가
|
||||
entity.setBackbone(src.getBackbone());
|
||||
entity.setInputSize(src.getInputSize());
|
||||
entity.setCropSize(src.getCropSize());
|
||||
entity.setBatchSize(src.getBatchSize());
|
||||
|
||||
// Data
|
||||
@@ -111,78 +215,4 @@ public class HyperParamCoreService {
|
||||
entity.setMemo(src.getMemo());
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 삭제
|
||||
*
|
||||
* @param uuid
|
||||
*/
|
||||
public void deleteHyperParam(UUID uuid) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
if (entity.getHyperVer().equals("HPs_0001")) {
|
||||
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
entity.setDelYn(true);
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 최적화 설정값 조회
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public HyperParamDto.Basic getInitHyperParam() {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByHyperVer("HPs_0001")
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 상세 조회
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public HyperParamDto.Basic getHyperParam(UUID uuid) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 목록 조회
|
||||
*
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
|
||||
return hyperParamRepository.findByHyperVerList(req);
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 버전 조회
|
||||
*
|
||||
* @return ver
|
||||
*/
|
||||
public String getFirstHyperParamVersion() {
|
||||
return hyperParamRepository
|
||||
.findHyperParamVer()
|
||||
.map(ModelHyperParamEntity::getHyperVer)
|
||||
.map(this::increase)
|
||||
.orElse("HPs_0001");
|
||||
}
|
||||
|
||||
private String increase(String hyperVer) {
|
||||
String prefix = "HPs_";
|
||||
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
|
||||
return prefix + String.format("%04d", num + 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTestMetricsJobCoreService {
|
||||
|
||||
private final ModelTestMetricsJobRepository modelTestMetricsJobRepository;
|
||||
|
||||
@Transactional
|
||||
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
|
||||
modelTestMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo);
|
||||
}
|
||||
|
||||
// Test 로직 시작
|
||||
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
|
||||
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
|
||||
}
|
||||
|
||||
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
||||
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,10 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
@@ -77,4 +81,20 @@ public class ModelTrainDetailCoreService {
|
||||
public ModelConfigDto.Basic findModelConfig(Long modelId) {
|
||||
return modelConfigRepository.findModelConfigByModelId(modelId).orElse(null);
|
||||
}
|
||||
|
||||
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
|
||||
return modelDetailRepository.getModelTrainMetricResult(uuid);
|
||||
}
|
||||
|
||||
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
|
||||
return modelDetailRepository.getModelValidationMetricResult(uuid);
|
||||
}
|
||||
|
||||
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
|
||||
return modelDetailRepository.getModelTestMetricResult(uuid);
|
||||
}
|
||||
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
return modelDetailRepository.getModelTrainBestEpoch(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional(readOnly = true)
|
||||
public class ModelTrainJobCoreService {
|
||||
|
||||
private final ModelTrainJobRepository modelTrainJobRepository;
|
||||
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
return modelTrainJobRepository.findMaxAttemptNo(modelId);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobDto> findLatestByModelId(Long modelId) {
|
||||
return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto);
|
||||
}
|
||||
|
||||
public Optional<ModelTrainJobDto> findById(Long jobId) {
|
||||
return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto);
|
||||
}
|
||||
|
||||
/** QUEUED Job 생성 */
|
||||
@Transactional
|
||||
public Long createQueuedJob(
|
||||
Long modelId, int attemptNo, Map<String, Object> paramsJson, ZonedDateTime queuedDttm) {
|
||||
|
||||
ModelTrainJobEntity job = new ModelTrainJobEntity();
|
||||
job.setModelId(modelId);
|
||||
job.setAttemptNo(attemptNo);
|
||||
job.setStatusCd("QUEUED");
|
||||
job.setParamsJson(paramsJson);
|
||||
job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now());
|
||||
|
||||
modelTrainJobRepository.save(job);
|
||||
return job.getId();
|
||||
}
|
||||
|
||||
/** 실행 시작 처리 */
|
||||
@Transactional
|
||||
public void markRunning(
|
||||
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("RUNNING");
|
||||
job.setContainerName(containerName);
|
||||
job.setLogPath(logPath);
|
||||
job.setStartedDttm(ZonedDateTime.now());
|
||||
job.setLockedDttm(ZonedDateTime.now());
|
||||
job.setLockedBy(lockedBy);
|
||||
|
||||
if (totalEpoch != null) {
|
||||
job.setTotalEpoch(totalEpoch);
|
||||
}
|
||||
}
|
||||
|
||||
/** 성공 처리 */
|
||||
@Transactional
|
||||
public void markSuccess(Long jobId, int exitCode) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("SUCCESS");
|
||||
job.setExitCode(exitCode);
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** 실패 처리 */
|
||||
@Transactional
|
||||
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("FAILED");
|
||||
job.setExitCode(exitCode);
|
||||
job.setErrorMessage(errorMessage);
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** 취소 처리 */
|
||||
@Transactional
|
||||
public void markCanceled(Long jobId) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
job.setStatusCd("STOPPED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainMetricsJobCoreService {
|
||||
|
||||
private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository;
|
||||
|
||||
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
|
||||
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
|
||||
}
|
||||
|
||||
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
||||
modelTrainMetricsJobRepository.insertModelMetricsTrain(batchArgs);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
|
||||
modelTrainMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo);
|
||||
}
|
||||
|
||||
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
||||
modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs);
|
||||
}
|
||||
}
|
||||
@@ -23,17 +23,22 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Propagation;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainMngCoreService {
|
||||
|
||||
private final ModelMngRepository modelMngRepository;
|
||||
private final ModelDatasetRepository modelDatasetRepository;
|
||||
private final ModelDatasetMappRepository modelDatasetMapRepository;
|
||||
@@ -60,9 +65,9 @@ public class ModelTrainMngCoreService {
|
||||
*/
|
||||
public void deleteModel(UUID uuid) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
entity.setDelYn(true);
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
@@ -74,17 +79,19 @@ public class ModelTrainMngCoreService {
|
||||
* @param addReq
|
||||
* @return
|
||||
*/
|
||||
public Long saveModel(ModelTrainMngDto.AddReq addReq) {
|
||||
public ModelTrainMngDto.Basic saveModel(ModelTrainMngDto.AddReq addReq) {
|
||||
ModelMasterEntity entity = new ModelMasterEntity();
|
||||
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
||||
|
||||
// 최적화 파라미터는 HPs_0001 사용
|
||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||
hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||
|
||||
} else {
|
||||
hyperParamEntity =
|
||||
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
||||
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
||||
}
|
||||
|
||||
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
||||
@@ -92,8 +99,8 @@ public class ModelTrainMngCoreService {
|
||||
}
|
||||
|
||||
String modelVer =
|
||||
String.join(
|
||||
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
|
||||
String.join(
|
||||
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
|
||||
entity.setModelVer(modelVer);
|
||||
entity.setHyperParamId(hyperParamEntity.getId());
|
||||
entity.setModelNo(addReq.getModelNo());
|
||||
@@ -108,18 +115,24 @@ public class ModelTrainMngCoreService {
|
||||
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
|
||||
} else {
|
||||
entity.setStatusCd(TrainStatusType.READY.getId());
|
||||
entity.setStep1State(TrainStatusType.READY.getId());
|
||||
}
|
||||
|
||||
entity.setCreatedUid(userUtil.getId());
|
||||
ModelMasterEntity resultEntity = modelMngRepository.save(entity);
|
||||
return resultEntity.getId();
|
||||
|
||||
ModelTrainMngDto.Basic result = new ModelTrainMngDto.Basic();
|
||||
result.setId(resultEntity.getId());
|
||||
result.setUuid(resultEntity.getUuid());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* data set 저장
|
||||
*
|
||||
* @param modelId 저장한 모델 학습 id
|
||||
* @param addReq 요청 파라미터
|
||||
* @param addReq 요청 파라미터
|
||||
*/
|
||||
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
|
||||
TrainingDataset dataset = addReq.getTrainingDataset();
|
||||
@@ -144,10 +157,30 @@ public class ModelTrainMngCoreService {
|
||||
modelDatasetRepository.save(datasetEntity);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습모델 수정
|
||||
*
|
||||
* @param modelId
|
||||
* @param req
|
||||
*/
|
||||
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
|
||||
entity.setRequestPath(req.getRequestPath());
|
||||
}
|
||||
|
||||
if (req.getResponsePath() != null && !req.getResponsePath().isEmpty()) {
|
||||
entity.setRequestPath(req.getResponsePath());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델 데이터셋 mapping 테이블 저장
|
||||
*
|
||||
* @param modelId 모델학습 id
|
||||
* @param modelId 모델학습 id
|
||||
* @param datasetList 선택한 data set
|
||||
*/
|
||||
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
|
||||
@@ -164,7 +197,7 @@ public class ModelTrainMngCoreService {
|
||||
* 모델학습 config 저장
|
||||
*
|
||||
* @param modelId 모델학습 id
|
||||
* @param req 요청 파라미터
|
||||
* @param req 요청 파라미터
|
||||
* @return
|
||||
*/
|
||||
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
|
||||
@@ -184,7 +217,7 @@ public class ModelTrainMngCoreService {
|
||||
/**
|
||||
* 데이터셋 매핑 생성
|
||||
*
|
||||
* @param modelUid 모델 UID
|
||||
* @param modelUid 모델 UID
|
||||
* @param datasetIds 데이터셋 ID 목록
|
||||
*/
|
||||
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
|
||||
@@ -206,13 +239,27 @@ public class ModelTrainMngCoreService {
|
||||
public ModelMasterEntity findByUuid(UUID uuid) {
|
||||
try {
|
||||
return modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* uuid로 model id 조회
|
||||
*
|
||||
* @param uuid
|
||||
* @return
|
||||
*/
|
||||
public Long findModelIdByUuid(UUID uuid) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.getId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델학습 아이디로 config정보 조회
|
||||
*
|
||||
@@ -222,8 +269,8 @@ public class ModelTrainMngCoreService {
|
||||
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
|
||||
ModelMasterEntity modelEntity = findByUuid(uuid);
|
||||
return modelConfigRepository
|
||||
.findModelConfigByModelId(modelEntity.getId())
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
.findModelConfigByModelId(modelEntity.getId())
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -245,4 +292,243 @@ public class ModelTrainMngCoreService {
|
||||
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
|
||||
return datasetRepository.getDatasetSelectG2G3List(req);
|
||||
}
|
||||
|
||||
/**
|
||||
* 모델관리 조회
|
||||
*
|
||||
* @param id
|
||||
* @return
|
||||
*/
|
||||
public ModelTrainMngDto.Basic findModelById(Long id) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(id)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/**
|
||||
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
|
||||
*/
|
||||
@Transactional
|
||||
public void markInProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
master.setCurrentAttemptId(jobId);
|
||||
|
||||
// 필요하면 시작시간도 여기서 찍어줌
|
||||
}
|
||||
|
||||
/**
|
||||
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
|
||||
*/
|
||||
@Transactional
|
||||
public void clearLastError(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setLastError(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
|
||||
*/
|
||||
@Transactional
|
||||
public void markStopped(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* 완료 처리(옵션) - Worker가 성공 시 호출
|
||||
*/
|
||||
@Transactional
|
||||
public void markCompleted(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*/
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setStep1State(TrainStatusType.ERROR.getId());
|
||||
master.setLastError(errorMessage);
|
||||
master.setUpdatedUid(userUtil.getId());
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/**
|
||||
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*/
|
||||
@Transactional
|
||||
public void markStep2Error(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setStep2State(TrainStatusType.ERROR.getId());
|
||||
master.setLastError(errorMessage);
|
||||
master.setUpdatedUid(userUtil.getId());
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void markSuccess(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
// 모델 상태 완료 처리
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
|
||||
// (선택) 마지막 에러 메시지 비우기
|
||||
master.setLastError(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 실행에 필요한 파라미터 조회
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||
return modelMngRepository.findTrainRunRequest(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* step1 진행중 처리
|
||||
*
|
||||
* @param modelId
|
||||
* @param jobId
|
||||
*/
|
||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||
public void markStep1InProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setCurrentAttemptId(jobId);
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* step2 진행중 처리
|
||||
*
|
||||
* @param modelId
|
||||
*/
|
||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||
public void markStep2InProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStep2StrtDttm(ZonedDateTime.now());
|
||||
entity.setStep2State(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setCurrentAttemptId(jobId);
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* step1 완료처리
|
||||
*
|
||||
* @param modelId
|
||||
*/
|
||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||
public void markStep1Success(Long modelId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep1EndDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* step2 완료처리
|
||||
*
|
||||
* @param modelId
|
||||
*/
|
||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||
public void markStep2Success(Long modelId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep2EndDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
}
|
||||
|
||||
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setBestEpoch(epoch);
|
||||
}
|
||||
|
||||
/**
|
||||
* 데이터셋 uid 조회
|
||||
*
|
||||
* @param datasetIds
|
||||
* @return
|
||||
*/
|
||||
public List<String> findDatasetUid(List<Long> datasetIds) {
|
||||
return datasetRepository.findDatasetUid(datasetIds);
|
||||
}
|
||||
|
||||
public List<Long> findModelDatasetMapp(Long modelId) {
|
||||
List<Long> datasetUids = new ArrayList<>();
|
||||
List<ModelDatasetMappEntity> entities = modelDatasetMapRepository.findByModelUid(modelId);
|
||||
for (ModelDatasetMappEntity entity : entities) {
|
||||
datasetUids.add(entity.getDatasetUid());
|
||||
}
|
||||
|
||||
return datasetUids;
|
||||
}
|
||||
|
||||
public Long findModelStep1InProgressCnt() {
|
||||
return modelMngRepository.findModelStep1InProgressCnt();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import com.kamco.cd.training.dataset.dto.DatasetObjDto.Basic;
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
import jakarta.persistence.GenerationType;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Table;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.UUID;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.hibernate.annotations.ColumnDefault;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.locationtech.jts.geom.Geometry;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@Entity
|
||||
@Table(name = "tb_dataset_val_obj")
|
||||
public class DatasetValObjEntity {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@Column(name = "obj_id", nullable = false)
|
||||
private Long objId;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "dataset_uid", nullable = false)
|
||||
private Long datasetUid;
|
||||
|
||||
@Column(name = "target_yyyy")
|
||||
private Integer targetYyyy;
|
||||
|
||||
@Size(max = 255)
|
||||
@Column(name = "target_class_cd")
|
||||
private String targetClassCd;
|
||||
|
||||
@Column(name = "compare_yyyy")
|
||||
private Integer compareYyyy;
|
||||
|
||||
@Size(max = 255)
|
||||
@Column(name = "compare_class_cd")
|
||||
private String compareClassCd;
|
||||
|
||||
@Size(max = 255)
|
||||
@Column(name = "target_path")
|
||||
private String targetPath;
|
||||
|
||||
@Size(max = 255)
|
||||
@Column(name = "compare_path")
|
||||
private String comparePath;
|
||||
|
||||
@Size(max = 255)
|
||||
@Column(name = "label_path")
|
||||
private String labelPath;
|
||||
|
||||
@Size(max = 255)
|
||||
@Column(name = "geojson_path")
|
||||
private String geojsonPath;
|
||||
|
||||
@Size(max = 255)
|
||||
@Column(name = "map_sheet_num")
|
||||
private String mapSheetNum;
|
||||
|
||||
@ColumnDefault("now()")
|
||||
@Column(name = "created_dttm")
|
||||
private ZonedDateTime createdDttm;
|
||||
|
||||
@Column(name = "created_uid")
|
||||
private Long createdUid;
|
||||
|
||||
@ColumnDefault("false")
|
||||
@Column(name = "deleted")
|
||||
private Boolean deleted;
|
||||
|
||||
@Column(name = "uuid")
|
||||
private UUID uuid;
|
||||
|
||||
@Size(max = 32)
|
||||
@Column(name = "uid")
|
||||
private String uid;
|
||||
|
||||
@JdbcTypeCode(SqlTypes.JSON)
|
||||
@Column(name = "geo_jsonb", columnDefinition = "jsonb")
|
||||
private String geoJsonb;
|
||||
|
||||
@Column(name = "file_name")
|
||||
private String fileName;
|
||||
|
||||
@Column(name = "geom", columnDefinition = "geometry")
|
||||
private Geometry geom;
|
||||
|
||||
public Basic toDto() {
|
||||
return new Basic(
|
||||
this.objId,
|
||||
this.datasetUid,
|
||||
this.targetYyyy,
|
||||
this.targetClassCd,
|
||||
this.compareYyyy,
|
||||
this.compareClassCd,
|
||||
this.targetPath,
|
||||
this.comparePath,
|
||||
this.labelPath,
|
||||
this.geojsonPath,
|
||||
this.mapSheetNum,
|
||||
this.createdDttm,
|
||||
this.createdUid,
|
||||
this.deleted,
|
||||
this.uuid,
|
||||
this.geoJsonb);
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import jakarta.persistence.*;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
@@ -311,8 +312,17 @@ public class ModelHyperParamEntity {
|
||||
@Column(name = "m3_use_cnt")
|
||||
private Long m3UseCnt = 0L;
|
||||
|
||||
@Column(name = "model_type")
|
||||
@Enumerated(EnumType.STRING)
|
||||
private ModelType modelType;
|
||||
|
||||
|
||||
@Column(name = "default_param")
|
||||
private Boolean isDefault = false;
|
||||
|
||||
public HyperParamDto.Basic toDto() {
|
||||
return new HyperParamDto.Basic(
|
||||
this.modelType,
|
||||
this.uuid,
|
||||
this.hyperVer,
|
||||
this.createdDttm,
|
||||
@@ -385,6 +395,8 @@ public class ModelHyperParamEntity {
|
||||
// -------------------------
|
||||
this.gpuCnt,
|
||||
this.gpuIds,
|
||||
this.masterPort);
|
||||
this.masterPort
|
||||
, this.isDefault
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +91,27 @@ public class ModelMasterEntity {
|
||||
@Column(name = "before_model_id")
|
||||
private Long beforeModelId;
|
||||
|
||||
@Column(name = "step1_metric_save_yn")
|
||||
private Boolean step1MetricSaveYn;
|
||||
|
||||
@Column(name = "step2_metric_save_yn")
|
||||
private Boolean step2MetricSaveYn;
|
||||
|
||||
@Column(name = "current_attempt_id")
|
||||
private Long currentAttemptId;
|
||||
|
||||
@Column(name = "last_error")
|
||||
private String lastError;
|
||||
|
||||
@Column(name = "best_epoch")
|
||||
private Integer bestEpoch;
|
||||
|
||||
@Column(name = "request_path")
|
||||
private String requestPath;
|
||||
|
||||
@Column(name = "response_path")
|
||||
private String responsePath;
|
||||
|
||||
public ModelTrainMngDto.Basic toDto() {
|
||||
return new ModelTrainMngDto.Basic(
|
||||
this.id,
|
||||
@@ -105,6 +126,7 @@ public class ModelMasterEntity {
|
||||
this.step2State,
|
||||
this.statusCd,
|
||||
this.trainType,
|
||||
this.modelNo);
|
||||
this.modelNo,
|
||||
this.currentAttemptId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,8 +19,8 @@ import org.hibernate.annotations.ColumnDefault;
|
||||
@Getter
|
||||
@Setter
|
||||
@Entity
|
||||
@Table(name = "tb_model_matrics_test")
|
||||
public class ModelMatricsTestEntity {
|
||||
@Table(name = "tb_model_metrics_test")
|
||||
public class ModelMetricsTestEntity {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@@ -45,9 +45,6 @@ public class ModelMatricsTestEntity {
|
||||
@Column(name = "fn")
|
||||
private Long fn;
|
||||
|
||||
@Column(name = "tn")
|
||||
private Long tn;
|
||||
|
||||
@Column(name = "precisions")
|
||||
private Float precisions;
|
||||
|
||||
@@ -63,8 +60,11 @@ public class ModelMatricsTestEntity {
|
||||
@Column(name = "iou")
|
||||
private Float iou;
|
||||
|
||||
@Column(name = "processed_images")
|
||||
private Long processedImages;
|
||||
@Column(name = "detection_count")
|
||||
private Long detectionCount;
|
||||
|
||||
@Column(name = "gt_count")
|
||||
private Long gtCount;
|
||||
|
||||
@ColumnDefault("now()")
|
||||
@Column(name = "created_dttm")
|
||||
@@ -18,8 +18,8 @@ import org.hibernate.annotations.ColumnDefault;
|
||||
@Getter
|
||||
@Setter
|
||||
@Entity
|
||||
@Table(name = "tb_model_matrics_train")
|
||||
public class ModelMatricsTrainEntity {
|
||||
@Table(name = "tb_model_metrics_train")
|
||||
public class ModelMetricsTrainEntity {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@@ -18,8 +18,8 @@ import org.hibernate.annotations.ColumnDefault;
|
||||
@Getter
|
||||
@Setter
|
||||
@Entity
|
||||
@Table(name = "tb_model_matrics_validation")
|
||||
public class ModelMatricsValidationEntity {
|
||||
@Table(name = "tb_model_metrics_validation")
|
||||
public class ModelMetricsValidationEntity {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@@ -0,0 +1,103 @@
|
||||
package com.kamco.cd.training.postgres.entity;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
import jakarta.persistence.GenerationType;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Table;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.hibernate.annotations.ColumnDefault;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@Entity
|
||||
@Table(name = "tb_model_train_job")
|
||||
public class ModelTrainJobEntity {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
@Column(name = "id", nullable = false)
|
||||
private Long id;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "model_id", nullable = false)
|
||||
private Long modelId;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "attempt_no", nullable = false)
|
||||
private Integer attemptNo;
|
||||
|
||||
@Size(max = 30)
|
||||
@NotNull
|
||||
@Column(name = "status_cd", nullable = false, length = 30)
|
||||
private String statusCd;
|
||||
|
||||
@NotNull
|
||||
@Column(name = "params_json", nullable = false)
|
||||
@JdbcTypeCode(SqlTypes.JSON)
|
||||
private Map<String, Object> paramsJson;
|
||||
|
||||
@Size(max = 200)
|
||||
@Column(name = "container_name", length = 200)
|
||||
private String containerName;
|
||||
|
||||
@Size(max = 500)
|
||||
@Column(name = "log_path", length = 500)
|
||||
private String logPath;
|
||||
|
||||
@Column(name = "exit_code")
|
||||
private Integer exitCode;
|
||||
|
||||
@Size(max = 2000)
|
||||
@Column(name = "error_message", length = 2000)
|
||||
private String errorMessage;
|
||||
|
||||
@ColumnDefault("now()")
|
||||
@Column(name = "queued_dttm")
|
||||
private ZonedDateTime queuedDttm;
|
||||
|
||||
@Column(name = "started_dttm")
|
||||
private ZonedDateTime startedDttm;
|
||||
|
||||
@Column(name = "finished_dttm")
|
||||
private ZonedDateTime finishedDttm;
|
||||
|
||||
@Column(name = "locked_dttm")
|
||||
private ZonedDateTime lockedDttm;
|
||||
|
||||
@Size(max = 100)
|
||||
@Column(name = "locked_by", length = 100)
|
||||
private String lockedBy;
|
||||
|
||||
@Column(name = "total_epoch")
|
||||
private Integer totalEpoch;
|
||||
|
||||
@Column(name = "current_epoch")
|
||||
private Integer currentEpoch;
|
||||
|
||||
public ModelTrainJobDto toDto() {
|
||||
return new ModelTrainJobDto(
|
||||
this.id,
|
||||
this.modelId,
|
||||
this.attemptNo,
|
||||
this.statusCd,
|
||||
this.exitCode,
|
||||
this.errorMessage,
|
||||
this.containerName,
|
||||
this.paramsJson,
|
||||
this.queuedDttm,
|
||||
this.startedDttm,
|
||||
this.finishedDttm,
|
||||
this.totalEpoch,
|
||||
this.currentEpoch);
|
||||
}
|
||||
}
|
||||
@@ -22,4 +22,6 @@ public interface DatasetObjRepositoryCustom {
|
||||
String getFilePathByUUIDPathType(UUID uuid, String pathType);
|
||||
|
||||
void insertDatasetTestObj(DatasetObjRegDto objRegDto);
|
||||
|
||||
void insertDatasetValObj(DatasetObjRegDto objRegDto);
|
||||
}
|
||||
|
||||
@@ -97,6 +97,49 @@ public class DatasetObjRepositoryImpl implements DatasetObjRepositoryCustom {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
String json;
|
||||
String geometryJson;
|
||||
try {
|
||||
json = objectMapper.writeValueAsString(objRegDto.getGeojson());
|
||||
geometryJson =
|
||||
objectMapper.writeValueAsString(
|
||||
objRegDto.getGeojson().path("features").get(0).path("geometry"));
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
try {
|
||||
em.createNativeQuery(
|
||||
"""
|
||||
insert into tb_dataset_val_obj
|
||||
(dataset_uid, target_yyyy, target_class_cd,
|
||||
compare_yyyy, compare_class_cd,
|
||||
target_path, compare_path, label_path, geo_jsonb, map_sheet_num, file_name, geom, geojson_path)
|
||||
values
|
||||
(?, ?, ?, ?, ?, ?, ?, ?, cast(? as jsonb), ?, ?, ST_SetSRID(ST_GeomFromGeoJSON(?), 5186), ?)
|
||||
""")
|
||||
.setParameter(1, objRegDto.getDatasetUid())
|
||||
.setParameter(2, objRegDto.getTargetYyyy())
|
||||
.setParameter(3, objRegDto.getTargetClassCd())
|
||||
.setParameter(4, objRegDto.getCompareYyyy())
|
||||
.setParameter(5, objRegDto.getCompareClassCd())
|
||||
.setParameter(6, objRegDto.getTargetPath())
|
||||
.setParameter(7, objRegDto.getComparePath())
|
||||
.setParameter(8, objRegDto.getLabelPath())
|
||||
.setParameter(9, json)
|
||||
.setParameter(10, objRegDto.getMapSheetNum())
|
||||
.setParameter(11, objRegDto.getFileName())
|
||||
.setParameter(12, geometryJson)
|
||||
.setParameter(13, objRegDto.getGeojsonPath())
|
||||
.executeUpdate();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Page<DatasetObjEntity> searchDatasetObjectList(SearchReq searchReq) {
|
||||
Pageable pageable = searchReq.toPageable();
|
||||
|
||||
@@ -22,4 +22,6 @@ public interface DatasetRepositoryCustom {
|
||||
Long getDatasetMaxStage(int compareYyyy, int targetYyyy);
|
||||
|
||||
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
|
||||
|
||||
List<String> findDatasetUid(List<Long> datasetIds);
|
||||
}
|
||||
|
||||
@@ -242,4 +242,9 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
|
||||
.where(dataset.uid.eq(mngRegDto.getUid()))
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> findDatasetUid(List<Long> datasetIds) {
|
||||
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.kamco.cd.training.postgres.repository.hyperparam;
|
||||
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
@@ -13,11 +15,22 @@ public interface HyperParamRepositoryCustom {
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Deprecated
|
||||
Optional<ModelHyperParamEntity> findHyperParamVer();
|
||||
|
||||
/**
|
||||
* 모델 타입별 마지막 버전 조회
|
||||
*
|
||||
* @param modelType 모델 타입
|
||||
* @return
|
||||
*/
|
||||
Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType);
|
||||
|
||||
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
|
||||
|
||||
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
|
||||
|
||||
Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req);
|
||||
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
||||
|
||||
Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
||||
}
|
||||
|
||||
@@ -2,8 +2,10 @@ package com.kamco.cd.training.postgres.repository.hyperparam;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import com.querydsl.core.BooleanBuilder;
|
||||
import com.querydsl.core.types.Projections;
|
||||
@@ -41,6 +43,23 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
.fetchOne());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType) {
|
||||
|
||||
return Optional.ofNullable(
|
||||
queryFactory
|
||||
.select(modelHyperParamEntity)
|
||||
.from(modelHyperParamEntity)
|
||||
.where(
|
||||
modelHyperParamEntity
|
||||
.delYn
|
||||
.isFalse()
|
||||
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||
.orderBy(modelHyperParamEntity.hyperVer.desc())
|
||||
.limit(1)
|
||||
.fetchOne());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) {
|
||||
|
||||
@@ -68,10 +87,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
|
||||
public Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req) {
|
||||
Pageable pageable = req.toPageable();
|
||||
|
||||
BooleanBuilder builder = new BooleanBuilder();
|
||||
builder.and(modelHyperParamEntity.modelType.eq(model));
|
||||
builder.and(modelHyperParamEntity.delYn.isFalse());
|
||||
|
||||
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
||||
@@ -161,4 +181,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
|
||||
return new PageImpl<>(content, pageable, totalCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
||||
return Optional.ofNullable(
|
||||
queryFactory
|
||||
.select(modelHyperParamEntity)
|
||||
.from(modelHyperParamEntity)
|
||||
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||
.fetchOne());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,4 +6,5 @@ import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
public interface ModelDatasetMappRepository
|
||||
extends JpaRepository<ModelDatasetMappEntity, ModelDatasetMappEntity.ModelDatasetMappId> {}
|
||||
extends JpaRepository<ModelDatasetMappEntity, ModelDatasetMappEntity.ModelDatasetMappId>,
|
||||
ModelDatasetMappRepositoryCustom {}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
|
||||
import java.util.List;
|
||||
|
||||
public interface ModelDatasetMappRepositoryCustom {
|
||||
List<ModelDatasetMappEntity> findByModelUid(Long modelId);
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositoryCustom {
|
||||
|
||||
private final JPAQueryFactory queryFactory;
|
||||
|
||||
@Override
|
||||
public List<ModelDatasetMappEntity> findByModelUid(Long modelId) {
|
||||
queryFactory
|
||||
.select(modelDatasetMappEntity)
|
||||
.from(modelDatasetMappEntity)
|
||||
.where(modelDatasetMappEntity.modelUid.eq(modelId));
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,10 @@ package com.kamco.cd.training.postgres.repository.model;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import java.util.List;
|
||||
@@ -22,4 +26,12 @@ public interface ModelDetailRepositoryCustom {
|
||||
List<MappingDataset> getByModelMappingDataset(UUID uuid);
|
||||
|
||||
ModelMasterEntity findByModelByUUID(UUID uuid);
|
||||
|
||||
List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid);
|
||||
|
||||
List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid);
|
||||
|
||||
List<ModelTestMetrics> getModelTestMetricResult(UUID uuid);
|
||||
|
||||
ModelBestEpoch getModelTrainBestEpoch(UUID uuid);
|
||||
}
|
||||
|
||||
@@ -5,10 +5,17 @@ import static com.kamco.cd.training.postgres.entity.QModelDatasetEntity.modelDat
|
||||
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
|
||||
@@ -20,8 +27,10 @@ import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Slf4j
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
||||
@@ -154,4 +163,110 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
||||
.where(modelMasterEntity.uuid.eq(uuid))
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
|
||||
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
|
||||
if (modelMasterEntity == null) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelTrainMetrics.class,
|
||||
modelMetricsTrainEntity.epoch,
|
||||
modelMetricsTrainEntity.iteration,
|
||||
modelMetricsTrainEntity.loss,
|
||||
modelMetricsTrainEntity.lr,
|
||||
modelMetricsTrainEntity.durationTime))
|
||||
.from(modelMetricsTrainEntity)
|
||||
.where(modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId()))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
|
||||
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
|
||||
if (modelMasterEntity == null) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelValidationMetrics.class,
|
||||
modelMetricsValidationEntity.epoch,
|
||||
modelMetricsValidationEntity.aAcc,
|
||||
modelMetricsValidationEntity.mFscore,
|
||||
modelMetricsValidationEntity.mPrecision,
|
||||
modelMetricsValidationEntity.mRecall,
|
||||
modelMetricsValidationEntity.mIou,
|
||||
modelMetricsValidationEntity.mAcc,
|
||||
modelMetricsValidationEntity.changedFscore,
|
||||
modelMetricsValidationEntity.changedPrecision,
|
||||
modelMetricsValidationEntity.changedRecall,
|
||||
modelMetricsValidationEntity.unchangedFscore,
|
||||
modelMetricsValidationEntity.unchangedPrecision,
|
||||
modelMetricsValidationEntity.unchangedRecall))
|
||||
.from(modelMetricsValidationEntity)
|
||||
.where(modelMetricsValidationEntity.model.id.eq(modelMasterEntity.getId()))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
|
||||
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
|
||||
if (modelMasterEntity == null) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelTestMetrics.class,
|
||||
modelMetricsTestEntity.model1,
|
||||
modelMetricsTestEntity.tp,
|
||||
modelMetricsTestEntity.fp,
|
||||
modelMetricsTestEntity.fn,
|
||||
modelMetricsTestEntity.precisions,
|
||||
modelMetricsTestEntity.recall,
|
||||
modelMetricsTestEntity.f1Score,
|
||||
modelMetricsTestEntity.accuracy,
|
||||
modelMetricsTestEntity.iou,
|
||||
modelMetricsTestEntity.detectionCount,
|
||||
modelMetricsTestEntity.gtCount))
|
||||
.from(modelMetricsTestEntity)
|
||||
.where(modelMetricsTestEntity.model.id.eq(modelMasterEntity.getId()))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
|
||||
if (modelMasterEntity == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelBestEpoch.class,
|
||||
modelMetricsTrainEntity.epoch,
|
||||
modelMetricsTrainEntity.loss,
|
||||
modelMetricsValidationEntity.mFscore,
|
||||
modelMetricsValidationEntity.mPrecision,
|
||||
modelMetricsValidationEntity.mRecall,
|
||||
modelMetricsValidationEntity.mIou,
|
||||
modelMetricsValidationEntity.mAcc))
|
||||
.from(modelMetricsTrainEntity)
|
||||
.leftJoin(modelMetricsValidationEntity)
|
||||
.on(
|
||||
modelMetricsTrainEntity.model.eq(modelMetricsValidationEntity.model),
|
||||
modelMetricsTrainEntity.epoch.eq(modelMetricsValidationEntity.epoch))
|
||||
.where(
|
||||
modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId()),
|
||||
modelMetricsTrainEntity.epoch.eq(modelMasterEntity.getBestEpoch()))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.springframework.data.domain.Page;
|
||||
@@ -19,4 +20,8 @@ public interface ModelMngRepositoryCustom {
|
||||
Optional<ModelMasterEntity> findByUuid(UUID uuid);
|
||||
|
||||
Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn);
|
||||
|
||||
TrainRunRequest findTrainRunRequest(Long modelId);
|
||||
|
||||
Long findModelStep1InProgressCnt();
|
||||
}
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.querydsl.core.BooleanBuilder;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.core.types.dsl.Expressions;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -40,6 +46,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
builder.and(modelMasterEntity.modelNo.eq(req.getModelNo()));
|
||||
}
|
||||
|
||||
builder.and(modelMasterEntity.delYn.isFalse());
|
||||
|
||||
List<ModelMasterEntity> content =
|
||||
queryFactory
|
||||
.selectFrom(modelMasterEntity)
|
||||
@@ -82,4 +90,71 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
// Always returns an empty Optional; neither statusCd nor delYn is read and no query is issued.
// NOTE(review): looks like an unimplemented placeholder — confirm whether a real lookup by
// status code / deletion flag is still pending before relying on this method.
public Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TrainRunRequest findTrainRunRequest(Long modelId) {
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
TrainRunRequest.class,
|
||||
modelMasterEntity.requestPath, // datasetFolder
|
||||
modelMasterEntity.uuid, // outputFolder
|
||||
modelHyperParamEntity.inputSize,
|
||||
modelHyperParamEntity.cropSize,
|
||||
modelHyperParamEntity.batchSize,
|
||||
modelHyperParamEntity.gpuIds,
|
||||
modelHyperParamEntity.gpuCnt,
|
||||
modelHyperParamEntity.learningRate,
|
||||
modelHyperParamEntity.backbone,
|
||||
modelConfigEntity.epochCount,
|
||||
modelHyperParamEntity.trainNumWorkers,
|
||||
modelHyperParamEntity.valNumWorkers,
|
||||
modelHyperParamEntity.testNumWorkers,
|
||||
modelHyperParamEntity.trainShuffle,
|
||||
modelHyperParamEntity.trainPersistent,
|
||||
modelHyperParamEntity.valPersistent,
|
||||
modelHyperParamEntity.dropPathRate,
|
||||
modelHyperParamEntity.frozenStages,
|
||||
modelHyperParamEntity.neckPolicy,
|
||||
modelHyperParamEntity.classWeight,
|
||||
modelHyperParamEntity.decoderChannels,
|
||||
modelHyperParamEntity.weightDecay,
|
||||
modelHyperParamEntity.layerDecayRate,
|
||||
modelHyperParamEntity.ignoreIndex,
|
||||
modelHyperParamEntity.ddpFindUnusedParams,
|
||||
modelHyperParamEntity.numLayers,
|
||||
modelHyperParamEntity.metrics,
|
||||
modelHyperParamEntity.saveBest,
|
||||
modelHyperParamEntity.saveBestRule,
|
||||
modelHyperParamEntity.valInterval,
|
||||
modelHyperParamEntity.logInterval,
|
||||
modelHyperParamEntity.visInterval,
|
||||
modelHyperParamEntity.rotProb,
|
||||
modelHyperParamEntity.rotDegree,
|
||||
modelHyperParamEntity.flipProb,
|
||||
modelHyperParamEntity.exchangeProb,
|
||||
modelHyperParamEntity.brightnessDelta,
|
||||
modelHyperParamEntity.contrastRange,
|
||||
modelHyperParamEntity.saturationRange,
|
||||
modelHyperParamEntity.hueDelta,
|
||||
Expressions.nullExpression(Integer.class),
|
||||
Expressions.nullExpression(String.class),
|
||||
modelHyperParamEntity.uuid))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelHyperParamEntity)
|
||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||
.leftJoin(modelConfigEntity)
|
||||
.on(modelConfigEntity.model.id.eq(modelMasterEntity.id))
|
||||
.where(modelMasterEntity.id.eq(modelId))
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long findModelStep1InProgressCnt() {
|
||||
return queryFactory
|
||||
.select(modelMasterEntity.id.count())
|
||||
.from(modelMasterEntity)
|
||||
.where(modelMasterEntity.step1State.eq(TrainStatusType.IN_PROGRESS.getId()))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
public interface ModelTestMetricsJobRepository
|
||||
extends JpaRepository<ModelMetricsTestEntity, Long>, ModelTestMetricsJobRepositoryCustom {}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
|
||||
public interface ModelTestMetricsJobRepositoryCustom {
|
||||
|
||||
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
||||
|
||||
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
|
||||
|
||||
void insertModelMetricsTest(List<Object[]> batchArgs);
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
||||
implements ModelTestMetricsJobRepositoryCustom {
|
||||
|
||||
private final JPAQueryFactory queryFactory;
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
|
||||
public ModelTestMetricsJobRepositoryImpl(
|
||||
JPAQueryFactory queryFactory, JdbcTemplate jdbcTemplate) {
|
||||
super(ModelMetricsTestEntity.class);
|
||||
this.queryFactory = queryFactory;
|
||||
this.jdbcTemplate = jdbcTemplate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
|
||||
queryFactory
|
||||
.update(modelMasterEntity)
|
||||
.set(
|
||||
stepNo.equals("step1")
|
||||
? modelMasterEntity.step1MetricSaveYn
|
||||
: modelMasterEntity.step2MetricSaveYn,
|
||||
true)
|
||||
.where(modelMasterEntity.id.eq(modelId))
|
||||
.execute();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity.step2EndDttm.isNotNull(),
|
||||
modelMasterEntity.step2State.eq(TrainStatusType.COMPLETED.getId()),
|
||||
modelMasterEntity
|
||||
.step2MetricSaveYn
|
||||
.isNull()
|
||||
.or(modelMasterEntity.step2MetricSaveYn.isFalse()))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
||||
String sql =
|
||||
"""
|
||||
insert into tb_model_metrics_test
|
||||
(model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
|
||||
detection_count, gt_count
|
||||
)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""";
|
||||
|
||||
jdbcTemplate.batchUpdate(sql, batchArgs);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
public interface ModelTrainJobRepository
|
||||
extends JpaRepository<ModelTrainJobEntity, Long>, ModelTrainJobRepositoryCustom {}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import java.util.Optional;
|
||||
|
||||
public interface ModelTrainJobRepositoryCustom {
|
||||
int findMaxAttemptNo(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import jakarta.persistence.EntityManager;
|
||||
import java.util.Optional;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
|
||||
|
||||
private final JPAQueryFactory queryFactory;
|
||||
|
||||
public ModelTrainJobRepositoryImpl(EntityManager em) {
|
||||
this.queryFactory = new JPAQueryFactory(em);
|
||||
}
|
||||
|
||||
/** modelId의 attempt_no 최대값. (없으면 0) */
|
||||
@Override
|
||||
public int findMaxAttemptNo(Long modelId) {
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
Integer max =
|
||||
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
|
||||
|
||||
return max != null ? max : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* modelId의 최신 job 1건 (보통 id desc / queuedDttm desc 등) - attemptNo 기준으로도 가능하지만, 여기선 id desc가 가장
|
||||
* 단순.
|
||||
*/
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
public interface ModelTrainMetricsJobRepository
|
||||
extends JpaRepository<ModelMetricsTrainEntity, Long>, ModelTrainMetricsJobRepositoryCustom {}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
|
||||
public interface ModelTrainMetricsJobRepositoryCustom {
|
||||
|
||||
List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
|
||||
|
||||
void insertModelMetricsTrain(List<Object[]> batchArgs);
|
||||
|
||||
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
||||
|
||||
void insertModelMetricsValidation(List<Object[]> batchArgs);
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
||||
implements ModelTrainMetricsJobRepositoryCustom {
|
||||
|
||||
private final JPAQueryFactory queryFactory;
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
|
||||
public ModelTrainMetricsJobRepositoryImpl(
|
||||
JPAQueryFactory queryFactory, JdbcTemplate jdbcTemplate) {
|
||||
super(ModelMetricsTrainEntity.class);
|
||||
this.queryFactory = queryFactory;
|
||||
this.jdbcTemplate = jdbcTemplate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity.step1EndDttm.isNotNull(),
|
||||
modelMasterEntity.step1State.eq(TrainStatusType.COMPLETED.getId()),
|
||||
modelMasterEntity
|
||||
.step1MetricSaveYn
|
||||
.isNull()
|
||||
.or(modelMasterEntity.step1MetricSaveYn.isFalse()))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
|
||||
String sql =
|
||||
"""
|
||||
insert into tb_model_metrics_train
|
||||
(model_id, epoch, iteration, loss, lr, duration_time)
|
||||
values (?, ?, ?, ?, ?, ?)
|
||||
""";
|
||||
|
||||
jdbcTemplate.batchUpdate(sql, batchArgs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
|
||||
queryFactory
|
||||
.update(modelMasterEntity)
|
||||
.set(
|
||||
stepNo.equals("step1")
|
||||
? modelMasterEntity.step1MetricSaveYn
|
||||
: modelMasterEntity.step2MetricSaveYn,
|
||||
true)
|
||||
.where(modelMasterEntity.id.eq(modelId))
|
||||
.execute();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
||||
String sql =
|
||||
"""
|
||||
insert into tb_model_metrics_validation
|
||||
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
|
||||
unchanged_fscore, unchanged_precision, unchanged_recall
|
||||
)
|
||||
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""";
|
||||
|
||||
jdbcTemplate.batchUpdate(sql, batchArgs);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package com.kamco.cd.training.train;
|
||||
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.train.service.TestJobService;
|
||||
import com.kamco.cd.training.train.service.TrainJobService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.media.Content;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
|
||||
@RequiredArgsConstructor
|
||||
@RestController
|
||||
@RequestMapping("/api/train")
|
||||
public class TrainApiController {
|
||||
|
||||
private final TrainJobService trainJobService;
|
||||
private final TestJobService testJobService;
|
||||
|
||||
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "실행 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/run/{uuid}")
|
||||
public ApiResponseDto<String> run(
|
||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
trainJobService.enqueue(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "학습 재실행", description = "학습 재실행 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "재실행 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/restart/{uuid}")
|
||||
public ApiResponseDto<String> restart(
|
||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
Long jobId = trainJobService.restart(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "학습 이어하기", description = "학습 이어하기 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "이어하기 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/resume/{uuid}")
|
||||
public ApiResponseDto<String> resume(
|
||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
Long jobId = trainJobService.resume(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "학습 취소", description = "학습 취소 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "취소 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/cancel/{uuid}")
|
||||
public ApiResponseDto<String> cancel(
|
||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
trainJobService.cancel(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "test 실행", description = "test 실행 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "test 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/test/run/{epoch}/{uuid}")
|
||||
public ApiResponseDto<String> run(
|
||||
@Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
|
||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
testJobService.enqueue(modelId, uuid, epoch);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "test 학습 취소", description = "학습 취소 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "취소 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/test/cancel/{uuid}")
|
||||
public ApiResponseDto<String> cancelTest(
|
||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
testJobService.cancel(modelId);
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@Operation(summary = "데이터셋 tmp 파일생성", description = "데이터셋 tmp 파일생성 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "데이터셋 tmp 파일생성 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@PostMapping("/create-tmp/{uuid}")
|
||||
public ApiResponseDto<UUID> createTmpFile(
|
||||
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
|
||||
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class EvalRunRequest {
|
||||
private String uuid;
|
||||
private int epoch; // best_changed_fscore_epoch_1.pth
|
||||
private Integer timeoutSeconds;
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public class ModelTrainJobDto {
|
||||
|
||||
private Long id;
|
||||
private Long modelId;
|
||||
private Integer attemptNo;
|
||||
private String statusCd;
|
||||
private Integer exitCode;
|
||||
private String errorMessage;
|
||||
private String containerName;
|
||||
private Map<String, Object> paramsJson;
|
||||
private ZonedDateTime queuedDttm;
|
||||
private ZonedDateTime startedDttm;
|
||||
private ZonedDateTime finishedDttm;
|
||||
private Integer totalEpoch;
|
||||
private Integer currentEpoch;
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
/** Event object announcing that a training run has been scheduled. */
public class ModelTrainJobQueuedEvent {

  /** Id of the job that was queued. */
  private final Long jobId;

  /**
   * Creates the event for a queued job.
   *
   * @param jobId id of the queued job; stored as given
   */
  public ModelTrainJobQueuedEvent(Long jobId) {
    this.jobId = jobId;
  }

  /**
   * Returns the id this event was created with.
   *
   * @return the queued job's id
   */
  public Long getJobId() {
    return this.jobId;
  }
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
public class ModelTrainMetricsDto {
|
||||
|
||||
@Schema(name = "ResponsePathDto", description = "AI 결과 저장된 path 경로 정보")
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class ResponsePathDto {
|
||||
|
||||
private Long modelId;
|
||||
private String responsePath;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import java.util.UUID;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class TrainRunRequest {
|
||||
|
||||
// ========================
|
||||
// 기본
|
||||
// ========================
|
||||
private String datasetFolder;
|
||||
private UUID outputFolder;
|
||||
private String inputSize;
|
||||
private String cropSize;
|
||||
private Integer batchSize;
|
||||
private String gpuIds;
|
||||
private Integer gpus;
|
||||
private Double learningRate;
|
||||
private String backbone;
|
||||
private Integer epochs;
|
||||
|
||||
// ========================
|
||||
// Data
|
||||
// ========================
|
||||
private Integer trainNumWorkers;
|
||||
private Integer valNumWorkers;
|
||||
private Integer testNumWorkers;
|
||||
private Boolean trainShuffle;
|
||||
private Boolean trainPersistent;
|
||||
private Boolean valPersistent;
|
||||
|
||||
// ========================
|
||||
// Model Architecture
|
||||
// ========================
|
||||
private Double dropPathRate;
|
||||
private Integer frozenStages;
|
||||
private String neckPolicy;
|
||||
private String classWeight;
|
||||
private String decoderChannels;
|
||||
|
||||
// ========================
|
||||
// Loss & Optimization
|
||||
// ========================
|
||||
private Double weightDecay;
|
||||
private Double layerDecayRate;
|
||||
private Integer ignoreIndex;
|
||||
private Boolean ddpFindUnusedParams;
|
||||
private Integer numLayers;
|
||||
|
||||
// ========================
|
||||
// Evaluation
|
||||
// ========================
|
||||
private String metrics;
|
||||
private String saveBest;
|
||||
private String saveBestRule;
|
||||
private Integer valInterval;
|
||||
private Integer logInterval;
|
||||
private Integer visInterval;
|
||||
|
||||
// ========================
|
||||
// Augmentation
|
||||
// ========================
|
||||
private Double rotProb;
|
||||
private String rotDegree;
|
||||
private Double flipProb;
|
||||
private Double exchangeProb;
|
||||
private Integer brightnessDelta;
|
||||
private String contrastRange;
|
||||
private String saturationRange;
|
||||
private Integer hueDelta;
|
||||
|
||||
// ========================
|
||||
// 실행 타임아웃
|
||||
// ========================
|
||||
private Integer timeoutSeconds;
|
||||
private String resumeFrom;
|
||||
|
||||
private UUID uuid;
|
||||
|
||||
public String getOutputFolder() {
|
||||
return String.valueOf(this.outputFolder);
|
||||
}
|
||||
|
||||
public String getUuid() {
|
||||
return String.valueOf(this.uuid);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
/** 학습 실행 결과 반환 객체 */
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class TrainRunResult {
|
||||
|
||||
private String jobId;
|
||||
private String containerName;
|
||||
private int exitCode;
|
||||
private String status;
|
||||
private String logs;
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Log4j2
|
||||
@Service
|
||||
public class DockerTrainService {
|
||||
|
||||
// 실행할 Docker 이미지명
|
||||
@Value("${train.docker.image}")
|
||||
private String image;
|
||||
|
||||
// 학습 요청 데이터가 위치한 호스트 디렉토리
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
// 컨테이너 이름 prefix
|
||||
@Value("${train.docker.containerPrefix}")
|
||||
private String containerPrefix;
|
||||
|
||||
// 공유메모리 사이즈 설정 (대용량 학습시 필요)
|
||||
@Value("${train.docker.shmSize:16g}")
|
||||
private String shmSize;
|
||||
|
||||
// IPC host 사용 여부
|
||||
@Value("${train.docker.ipcHost:true}")
|
||||
private boolean ipcHost;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||
|
||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||
|
||||
log.info("=== Docker Train Command ===");
|
||||
log.info("Container: {}", containerName);
|
||||
log.info("Command: {}", String.join(" ", cmd));
|
||||
log.info("================================");
|
||||
|
||||
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||
pb.redirectErrorStream(true);
|
||||
|
||||
Process p = pb.start();
|
||||
|
||||
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||
StringBuilder logBuilder = new StringBuilder();
|
||||
|
||||
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
|
||||
|
||||
Thread logThread =
|
||||
new Thread(
|
||||
() -> {
|
||||
try (BufferedReader br =
|
||||
new BufferedReader(
|
||||
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
|
||||
// 1) 로그 누적
|
||||
synchronized (logBuilder) {
|
||||
logBuilder.append(line).append('\n');
|
||||
}
|
||||
|
||||
// 2) epoch 감지 + DB 업데이트
|
||||
Matcher m = epochPattern.matcher(line);
|
||||
if (m.find()) {
|
||||
int currentEpoch = Integer.parseInt(m.group(1));
|
||||
int totalEpoch = Integer.parseInt(m.group(2));
|
||||
|
||||
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
|
||||
|
||||
// TODO 실행중인 에폭 저장 필요하면 만들어야함
|
||||
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
|
||||
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
|
||||
// currentEpoch, totalEpoch);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("logThread error: {}", e.toString());
|
||||
}
|
||||
},
|
||||
"train-log-" + containerName);
|
||||
// new Thread(
|
||||
// () -> {
|
||||
// try (BufferedReader br =
|
||||
// new BufferedReader(
|
||||
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
// String line;
|
||||
// while ((line = br.readLine()) != null) {
|
||||
// synchronized (log) {
|
||||
// log.append(line).append('\n');
|
||||
// }
|
||||
// }
|
||||
// } catch (Exception ignored) {
|
||||
// }
|
||||
// },
|
||||
// "train-log-" + containerName);
|
||||
|
||||
logThread.setDaemon(true);
|
||||
logThread.start();
|
||||
|
||||
int timeoutSeconds = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
|
||||
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
|
||||
|
||||
if (!finished) {
|
||||
// docker run 프로세스도 같이 끊어야 readLine이 풀림
|
||||
p.destroy();
|
||||
if (!p.waitFor(2, TimeUnit.SECONDS)) {
|
||||
p.destroyForcibly();
|
||||
}
|
||||
killContainer(containerName);
|
||||
|
||||
String logs;
|
||||
synchronized (logBuilder) {
|
||||
logs = logBuilder.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(
|
||||
null, // jobId (없으면 null)
|
||||
containerName,
|
||||
-1,
|
||||
"TIMEOUT",
|
||||
logs);
|
||||
}
|
||||
|
||||
int exit = p.exitValue();
|
||||
|
||||
// 로그 스레드가 마무리할 시간을 조금 줌(없어도 되지만 로그 누락 방지용)
|
||||
logThread.join(500);
|
||||
|
||||
String logs;
|
||||
synchronized (logBuilder) {
|
||||
logs = logBuilder.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 docker run command
|
||||
*
|
||||
* @param containerName
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
private List<String> buildDockerRunCommand(String containerName, TrainRunRequest req) {
|
||||
|
||||
List<String> c = new ArrayList<>();
|
||||
|
||||
c.add("docker");
|
||||
c.add("run");
|
||||
|
||||
// 컨테이너 이름 지정
|
||||
c.add("--name");
|
||||
c.add(containerName + "-" + req.getUuid().substring(0, 8));
|
||||
|
||||
// 실행 종료 시 자동 삭제
|
||||
c.add("--rm");
|
||||
|
||||
// GPU 전체 사용
|
||||
c.add("--gpus");
|
||||
c.add("all");
|
||||
|
||||
// IPC host 사용 여부
|
||||
if (ipcHost) {
|
||||
c.add("--ipc=host");
|
||||
}
|
||||
|
||||
// 공유메모리 설정
|
||||
c.add("--shm-size=" + shmSize);
|
||||
|
||||
// 메모리 관련 ulimit 설정
|
||||
c.add("--ulimit");
|
||||
c.add("memlock=-1");
|
||||
c.add("--ulimit");
|
||||
c.add("stack=67108864");
|
||||
|
||||
// 환경변수 설정
|
||||
c.add("-e");
|
||||
c.add("OPENCV_LOG_LEVEL=ERROR");
|
||||
c.add("-e");
|
||||
c.add("NCCL_DEBUG=INFO");
|
||||
c.add("-e");
|
||||
c.add("NCCL_IB_DISABLE=1");
|
||||
c.add("-e");
|
||||
c.add("NCCL_P2P_DISABLE=0");
|
||||
c.add("-e");
|
||||
c.add("NCCL_SOCKET_IFNAME=eth0");
|
||||
|
||||
// 요청/결과 디렉토리 볼륨 마운트
|
||||
c.add("-v");
|
||||
c.add(requestDir + ":/data");
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
// 표준입력 유지 (-it 대신 -i만 사용)
|
||||
c.add("-i");
|
||||
|
||||
// 사용할 이미지
|
||||
c.add(image);
|
||||
|
||||
// ===== 컨테이너 내부 실행 명령 =====
|
||||
c.add("python");
|
||||
c.add("/workspace/change-detection-code/train_wrapper.py");
|
||||
|
||||
// ===== 기본 파라미터 =====
|
||||
addArg(c, "--dataset-folder", req.getDatasetFolder());
|
||||
addArg(c, "--output-folder", req.getOutputFolder());
|
||||
addArg(c, "--input-size", req.getInputSize());
|
||||
addArg(c, "--crop-size", req.getCropSize());
|
||||
addArg(c, "--batch-size", req.getBatchSize());
|
||||
addArg(c, "--gpu-ids", req.getGpuIds());
|
||||
// addArg(c, "--gpus", req.getGpus());
|
||||
addArg(c, "--lr", req.getLearningRate());
|
||||
addArg(c, "--backbone", req.getBackbone());
|
||||
addArg(c, "--epochs", req.getEpochs());
|
||||
|
||||
// ===== Data =====
|
||||
addArg(c, "--train-num-workers", req.getTrainNumWorkers());
|
||||
addArg(c, "--val-num-workers", req.getValNumWorkers());
|
||||
addArg(c, "--test-num-workers", req.getTestNumWorkers());
|
||||
addArg(c, "--train-shuffle", req.getTrainShuffle());
|
||||
addArg(c, "--train-persistent", req.getTrainPersistent());
|
||||
addArg(c, "--val-persistent", req.getValPersistent());
|
||||
|
||||
// ===== Model Architecture =====
|
||||
addArg(c, "--drop-path-rate", req.getDropPathRate());
|
||||
addArg(c, "--frozen-stages", req.getFrozenStages());
|
||||
addArg(c, "--neck-policy", req.getNeckPolicy());
|
||||
addArg(c, "--class-weight", req.getClassWeight());
|
||||
addArg(c, "--decoder-channels", req.getDecoderChannels());
|
||||
|
||||
// ===== Loss & Optimization =====
|
||||
addArg(c, "--weight-decay", req.getWeightDecay());
|
||||
addArg(c, "--layer-decay-rate", req.getLayerDecayRate());
|
||||
addArg(c, "--ignore-index", req.getIgnoreIndex());
|
||||
addArg(c, "--ddp-find-unused-params", req.getDdpFindUnusedParams());
|
||||
addArg(c, "--num-layers", req.getNumLayers());
|
||||
|
||||
// ===== Evaluation =====
|
||||
addArg(c, "--metrics", req.getMetrics());
|
||||
addArg(c, "--save-best", req.getSaveBest());
|
||||
addArg(c, "--save-best-rule", req.getSaveBestRule());
|
||||
addArg(c, "--val-interval", req.getValInterval());
|
||||
addArg(c, "--log-interval", req.getLogInterval());
|
||||
addArg(c, "--vis-interval", req.getVisInterval());
|
||||
|
||||
// ===== Augmentation =====
|
||||
addArg(c, "--rot-prob", req.getRotProb());
|
||||
addArg(c, "--rot-degree", req.getRotDegree());
|
||||
addArg(c, "--flip-prob", req.getFlipProb());
|
||||
addArg(c, "--exchange-prob", req.getExchangeProb());
|
||||
addArg(c, "--brightness-delta", req.getBrightnessDelta());
|
||||
addArg(c, "--contrast-range", req.getContrastRange());
|
||||
addArg(c, "--saturation-range", req.getSaturationRange());
|
||||
addArg(c, "--hue-delta", req.getHueDelta());
|
||||
|
||||
addArg(c, "--resume-from", req.getResumeFrom());
|
||||
return c;
|
||||
}
|
||||
|
||||
/** 인자 추가(키 + 값) - null / blank면 아예 추가 안 함 */
|
||||
private void addArg(List<String> c, String key, Object value) {
|
||||
if (value == null) return;
|
||||
String s = String.valueOf(value).trim();
|
||||
if (s.isEmpty()) return;
|
||||
|
||||
c.add(key + "=" + s);
|
||||
}
|
||||
|
||||
/** 컨테이너 강제 종료 및 제거 */
|
||||
public void killContainer(String containerName) {
|
||||
try {
|
||||
new ProcessBuilder("docker", "rm", "-f", containerName)
|
||||
.redirectErrorStream(true)
|
||||
.start()
|
||||
.waitFor(10, TimeUnit.SECONDS);
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
}
|
||||
|
||||
public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception {
|
||||
|
||||
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
||||
|
||||
log.info("=== Docker Test Command ===");
|
||||
log.info("Container: {}", containerName);
|
||||
log.info("Command: {}", String.join(" ", cmd));
|
||||
log.info("================================");
|
||||
|
||||
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||
pb.redirectErrorStream(true);
|
||||
|
||||
Process p = pb.start();
|
||||
|
||||
StringBuilder log = new StringBuilder();
|
||||
Thread logThread =
|
||||
new Thread(
|
||||
() -> {
|
||||
try (BufferedReader br =
|
||||
new BufferedReader(
|
||||
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
synchronized (log) {
|
||||
log.append(line).append('\n');
|
||||
}
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
});
|
||||
|
||||
logThread.setDaemon(true);
|
||||
logThread.start();
|
||||
|
||||
int timeout = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
|
||||
boolean finished = p.waitFor(timeout, TimeUnit.SECONDS);
|
||||
|
||||
if (!finished) {
|
||||
p.destroyForcibly();
|
||||
killContainer(containerName);
|
||||
|
||||
String logs;
|
||||
synchronized (log) {
|
||||
logs = log.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs);
|
||||
}
|
||||
|
||||
int exit = p.exitValue();
|
||||
logThread.join(500);
|
||||
|
||||
String logs;
|
||||
synchronized (log) {
|
||||
logs = log.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
||||
}
|
||||
|
||||
/**
|
||||
* 테스트 docker run command
|
||||
*
|
||||
* @param containerName
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
private List<String> buildDockerEvalCommand(String containerName, EvalRunRequest req) {
|
||||
|
||||
String uuid = req.getUuid();
|
||||
Integer epoch = req.getEpoch();
|
||||
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
|
||||
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
|
||||
|
||||
String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth";
|
||||
|
||||
List<String> c = new ArrayList<>();
|
||||
|
||||
c.add("docker");
|
||||
c.add("run");
|
||||
c.add("--name");
|
||||
c.add(containerName + "=" + req.getUuid().substring(0, 8));
|
||||
c.add("--rm");
|
||||
|
||||
c.add("--gpus");
|
||||
c.add("all");
|
||||
if (ipcHost) c.add("--ipc=host");
|
||||
c.add("--shm-size=" + shmSize);
|
||||
|
||||
c.add("-v");
|
||||
c.add(requestDir + ":/data");
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
c.add(image);
|
||||
|
||||
c.add("python");
|
||||
c.add("/workspace/change-detection-code/run_evaluation_pipeline.py");
|
||||
|
||||
c.add("--dataset_dir");
|
||||
c.add("/data/" + uuid);
|
||||
|
||||
c.add("--model");
|
||||
c.add("/checkpoints/" + uuid + "/" + modelFile);
|
||||
|
||||
return c;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.csv.CSVFormat;
|
||||
import org.apache.commons.csv.CSVParser;
|
||||
import org.apache.commons.csv.CSVRecord;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTestMetricsJobService {
|
||||
|
||||
private final ModelTestMetricsJobCoreService modelTestMetricsJobCoreService;
|
||||
|
||||
@Value("${spring.profiles.active}")
|
||||
private String profile;
|
||||
|
||||
/**
|
||||
* 실행중인 profile
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private boolean isLocalProfile() {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
public void findTestValidMetricCsvFiles() {
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
List<ResponsePathDto> modelIds =
|
||||
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
||||
|
||||
if (modelIds.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (ResponsePathDto modelInfo : modelIds) {
|
||||
|
||||
String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
|
||||
for (CSVRecord record : parser) {
|
||||
|
||||
String model = record.get("model");
|
||||
long TP = Long.parseLong(record.get("TP"));
|
||||
long FP = Long.parseLong(record.get("FP"));
|
||||
long FN = Long.parseLong(record.get("FN"));
|
||||
float precision = Float.parseFloat(record.get("precision"));
|
||||
float recall = Float.parseFloat(record.get("recall"));
|
||||
float f1_score = Float.parseFloat(record.get("f1_score"));
|
||||
float accuracy = Float.parseFloat(record.get("accuracy"));
|
||||
float iou = Float.parseFloat(record.get("iou"));
|
||||
long detection_count = Long.parseLong(record.get("detection_count"));
|
||||
long gt_count = Long.parseLong(record.get("gt_count"));
|
||||
|
||||
batchArgs.add(
|
||||
new Object[] {
|
||||
modelInfo.getModelId(),
|
||||
model,
|
||||
TP,
|
||||
FP,
|
||||
FN,
|
||||
precision,
|
||||
recall,
|
||||
f1_score,
|
||||
accuracy,
|
||||
iou,
|
||||
detection_count,
|
||||
gt_count
|
||||
});
|
||||
}
|
||||
|
||||
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.csv.CSVFormat;
|
||||
import org.apache.commons.csv.CSVParser;
|
||||
import org.apache.commons.csv.CSVRecord;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ModelTrainMetricsJobService {
|
||||
|
||||
private final ModelTrainMetricsJobCoreService modelTrainMetricsJobCoreService;
|
||||
|
||||
@Value("${spring.profiles.active}")
|
||||
private String profile;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
/**
|
||||
* 실행중인 profile
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private boolean isLocalProfile() {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
public void findTrainValidMetricCsvFiles() {
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
List<ResponsePathDto> modelIds =
|
||||
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
||||
|
||||
if (modelIds.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (ResponsePathDto modelInfo : modelIds) {
|
||||
|
||||
String trainPath = responseDir + "{uuid}/metrics/train.csv"; // TODO
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
|
||||
for (CSVRecord record : parser) {
|
||||
|
||||
int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기
|
||||
long iteration = Long.parseLong(record.get("Iteration"));
|
||||
double Loss = Double.parseDouble(record.get("Loss"));
|
||||
double LR = Double.parseDouble(record.get("LR"));
|
||||
float time = Float.parseFloat(record.get("Time"));
|
||||
|
||||
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
|
||||
}
|
||||
|
||||
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
|
||||
|
||||
List<Object[]> batchArgs = new ArrayList<>();
|
||||
|
||||
for (CSVRecord record : parser) {
|
||||
|
||||
int epoch = Integer.parseInt(record.get("Epoch"));
|
||||
float aAcc = Float.parseFloat(record.get("aAcc"));
|
||||
float mFscore = Float.parseFloat(record.get("mFscore"));
|
||||
float mPrecision = Float.parseFloat(record.get("mPrecision"));
|
||||
float mRecall = Float.parseFloat(record.get("mRecall"));
|
||||
float mIoU = Float.parseFloat(record.get("mIoU"));
|
||||
float mAcc = Float.parseFloat(record.get("mAcc"));
|
||||
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
|
||||
float changed_precision = Float.parseFloat(record.get("changed_precision"));
|
||||
float changed_recall = Float.parseFloat(record.get("changed_recall"));
|
||||
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
|
||||
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
|
||||
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
|
||||
|
||||
batchArgs.add(
|
||||
new Object[] {
|
||||
modelInfo.getModelId(),
|
||||
epoch,
|
||||
aAcc,
|
||||
mFscore,
|
||||
mPrecision,
|
||||
mRecall,
|
||||
mIoU,
|
||||
mAcc,
|
||||
changed_fscore,
|
||||
changed_precision,
|
||||
changed_recall,
|
||||
unchanged_fscore,
|
||||
unchanged_precision,
|
||||
unchanged_recall
|
||||
});
|
||||
}
|
||||
|
||||
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
|
||||
modelInfo.getModelId(), "step1");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional(readOnly = true)
|
||||
public class TestJobService {
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||
|
||||
// 마스터 확인
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
// best epoch 업데이트
|
||||
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
|
||||
|
||||
Map<String, Object> params = new java.util.LinkedHashMap<>();
|
||||
params.put("jobType", "EVAL");
|
||||
params.put("uuid", String.valueOf(uuid));
|
||||
params.put("epoch", epoch);
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId, nextAttemptNo, params, ZonedDateTime.now());
|
||||
|
||||
// step2 시작으로 마킹
|
||||
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
return jobId;
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
Long jobId = master.getCurrentAttemptId();
|
||||
if (jobId == null) {
|
||||
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||
}
|
||||
|
||||
var job =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalStateException("Job not found"));
|
||||
|
||||
String containerName = job.getContainerName();
|
||||
|
||||
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
|
||||
if (containerName != null && !containerName.isBlank()) {
|
||||
dockerTrainService.killContainer(containerName);
|
||||
}
|
||||
|
||||
// 2) 상태 업데이트 (항상 수행)
|
||||
modelTrainJobCoreService.markCanceled(jobId);
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,239 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.service.TmpDatasetService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional(readOnly = true)
|
||||
public class TrainJobService {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final TmpDatasetService tmpDatasetService;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
public Long getModelIdByUuid(UUID uuid) {
|
||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||
}
|
||||
|
||||
/** 실행 예약 (QUEUE 등록) */
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId) {
|
||||
|
||||
// 마스터 존재 확인(없으면 예외)
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
// 파라미터 조회
|
||||
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
|
||||
|
||||
if (trainRunRequest == null) {
|
||||
throw new IllegalArgumentException("Model not found: " + modelId);
|
||||
}
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
||||
paramsMap.put("jobType", "TRAIN");
|
||||
paramsMap.put("uuid", trainRunRequest.getUuid());
|
||||
paramsMap.put("totalEpoch", trainRunRequest.getEpochs());
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId, nextAttemptNo, paramsMap, ZonedDateTime.now());
|
||||
|
||||
modelTrainMngCoreService.clearLastError(modelId);
|
||||
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
||||
|
||||
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 재시작
|
||||
*
|
||||
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
|
||||
*/
|
||||
@Transactional
|
||||
public Long restart(Long modelId) {
|
||||
return createNextAttempt(modelId, ResumeMode.NONE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 이어하기
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public Long resume(Long modelId) {
|
||||
return createNextAttempt(modelId, ResumeMode.REQUIRE);
|
||||
}
|
||||
|
||||
/**
|
||||
* 중단
|
||||
*
|
||||
* <p>- job 상태 CANCELED - master 상태 STOPPED
|
||||
*
|
||||
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
|
||||
*/
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
Long jobId = master.getCurrentAttemptId();
|
||||
if (jobId == null) {
|
||||
throw new IllegalStateException("실행중인 작업이 없습니다.");
|
||||
}
|
||||
|
||||
var job =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalStateException("Job not found"));
|
||||
|
||||
String containerName = job.getContainerName();
|
||||
|
||||
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
|
||||
if (containerName != null && !containerName.isBlank()) {
|
||||
dockerTrainService.killContainer(containerName);
|
||||
}
|
||||
|
||||
// 2) 상태 업데이트 (항상 수행)
|
||||
modelTrainJobCoreService.markCanceled(jobId);
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
}
|
||||
|
||||
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
if (TrainStatusType.IN_PROGRESS.getId().equals(master.getStatusCd())) {
|
||||
throw new IllegalStateException("이미 진행중입니다.");
|
||||
}
|
||||
|
||||
var lastJob =
|
||||
modelTrainJobCoreService
|
||||
.findLatestByModelId(modelId)
|
||||
.orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다."));
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
// 이전 params_json 재사용 (재현성)
|
||||
Map<String, Object> params = lastJob.getParamsJson();
|
||||
if (params == null || params.isEmpty()) {
|
||||
throw new IllegalStateException("이전 실행 params_json이 없습니다.");
|
||||
}
|
||||
|
||||
// mode에 따라 resume 옵션 주입/제거
|
||||
Map<String, Object> nextParams = new java.util.LinkedHashMap<>(params);
|
||||
|
||||
if (mode == ResumeMode.NONE) {
|
||||
// 이어하기 관련 키가 있다면 제거 (완전 새로 시작 보장)
|
||||
nextParams.remove("resumeFrom");
|
||||
nextParams.remove("resume");
|
||||
} else if (mode == ResumeMode.REQUIRE) {
|
||||
// 체크포인트 탐지해서 resumeFrom 세팅
|
||||
String resumeFrom = findResumeFromOrNull(nextParams);
|
||||
if (resumeFrom == null) {
|
||||
throw new IllegalStateException("이어하기 체크포인트가 없습니다.");
|
||||
}
|
||||
nextParams.put("resumeFrom", resumeFrom);
|
||||
nextParams.put("resume", true);
|
||||
}
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
modelId, nextAttemptNo, nextParams, ZonedDateTime.now());
|
||||
|
||||
modelTrainMngCoreService.clearLastError(modelId);
|
||||
modelTrainMngCoreService.markInProgress(modelId, jobId);
|
||||
|
||||
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
|
||||
return jobId;
|
||||
}
|
||||
|
||||
private enum ResumeMode {
|
||||
NONE, // 새로 시작
|
||||
REQUIRE // 이어하기
|
||||
}
|
||||
|
||||
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
||||
if (paramsJson == null) return null;
|
||||
|
||||
Object out = paramsJson.get("outputFolder");
|
||||
if (out == null) return null;
|
||||
|
||||
String outputFolder = String.valueOf(out).trim(); // uuid
|
||||
if (outputFolder.isEmpty()) return null;
|
||||
|
||||
// 호스트 기준 경로
|
||||
Path outDir = Paths.get(responseDir, outputFolder);
|
||||
|
||||
Path last = outDir.resolve("last_checkpoint");
|
||||
if (!Files.isRegularFile(last)) return null;
|
||||
|
||||
try {
|
||||
String ckptFile = Files.readString(last).trim(); // epoch_10.pth
|
||||
if (ckptFile.isEmpty()) return null;
|
||||
|
||||
Path ckptHost = outDir.resolve(ckptFile);
|
||||
if (!Files.isRegularFile(ckptHost)) return null;
|
||||
|
||||
// 컨테이너 경로 반환
|
||||
return "/checkpoints/" + outputFolder + "/" + ckptFile;
|
||||
|
||||
} catch (Exception e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public UUID createTmpFile(UUID modelUuid) {
|
||||
UUID tmpUuid = UUID.randomUUID();
|
||||
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
|
||||
|
||||
Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
|
||||
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
|
||||
|
||||
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
|
||||
|
||||
try {
|
||||
// 데이터셋 심볼링크 생성
|
||||
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||
updateReq.setRequestPath(path.toString());
|
||||
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return modelUuid;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
import java.util.Map;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.transaction.event.TransactionPhase;
|
||||
import org.springframework.transaction.event.TransactionalEventListener;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class TrainJobWorker {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Async
|
||||
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
|
||||
public void handle(ModelTrainJobQueuedEvent event) {
|
||||
|
||||
Long jobId = event.getJobId();
|
||||
|
||||
ModelTrainJobDto job =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
if (TrainStatusType.STOPPED.getId().equals(job.getStatusCd())) {
|
||||
return;
|
||||
}
|
||||
|
||||
Long modelId = job.getModelId();
|
||||
Map<String, Object> params = job.getParamsJson();
|
||||
|
||||
String jobType = params != null ? String.valueOf(params.get("jobType")) : null;
|
||||
|
||||
boolean isEval = "EVAL".equals(jobType);
|
||||
|
||||
String containerName =
|
||||
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
||||
|
||||
Integer totalEpoch = null;
|
||||
if (params.containsKey("totalEpoch")) {
|
||||
if (params.get("totalEpoch") != null) {
|
||||
totalEpoch = Integer.parseInt(params.get("totalEpoch").toString());
|
||||
}
|
||||
}
|
||||
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
|
||||
|
||||
try {
|
||||
TrainRunResult result;
|
||||
|
||||
if (isEval) {
|
||||
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||
String uuid = String.valueOf(params.get("uuid"));
|
||||
int epoch = (int) params.get("epoch");
|
||||
|
||||
EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null);
|
||||
result = dockerTrainService.runEvalSync(evalReq, containerName);
|
||||
|
||||
} else {
|
||||
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||
TrainRunRequest trainReq = toTrainRunRequest(params);
|
||||
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
||||
}
|
||||
|
||||
ModelTrainJobDto latest =
|
||||
modelTrainJobCoreService
|
||||
.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
|
||||
|
||||
if (TrainStatusType.STOPPED.getId().equals(latest.getStatusCd())) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (result.getExitCode() == 0) {
|
||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||
|
||||
if (isEval) {
|
||||
modelTrainMngCoreService.markStep2Success(modelId);
|
||||
} else {
|
||||
modelTrainMngCoreService.markStep1Success(modelId);
|
||||
}
|
||||
|
||||
} else {
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||
|
||||
if (isEval) {
|
||||
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
||||
} else {
|
||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
modelTrainJobCoreService.markFailed(jobId, null, e.toString());
|
||||
|
||||
if ("EVAL".equals(params.get("jobType"))) {
|
||||
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
||||
} else {
|
||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private TrainRunRequest toTrainRunRequest(Map<String, Object> paramsJson) {
|
||||
if (paramsJson == null || paramsJson.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return objectMapper.convertValue(paramsJson, TrainRunRequest.class);
|
||||
}
|
||||
}
|
||||
@@ -199,9 +199,6 @@ public class UploadDto {
|
||||
private String fileName;
|
||||
|
||||
public double getUploadRate() {
|
||||
if (this.chunkTotalIndex == 0) {
|
||||
return 0.0;
|
||||
}
|
||||
return (double) (this.chunkIndex + 1) / (this.chunkTotalIndex + 1) * 100.0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,8 +67,8 @@ public class UploadService {
|
||||
String status = "UPLOADING";
|
||||
|
||||
if (uploadDivi.equals("dataset")) {
|
||||
tmpDataSetDir = datasetTmpDir + uuid + "/";
|
||||
fianlDir = datasetDir + uuid + "/";
|
||||
tmpDataSetDir = datasetTmpDir;
|
||||
fianlDir = datasetDir;
|
||||
}
|
||||
|
||||
// 리턴용 파일 값
|
||||
@@ -234,8 +234,8 @@ public class UploadService {
|
||||
try {
|
||||
FIleChecker.deleteFolder(tmpDir);
|
||||
// 108 에서 86 서버로 이동
|
||||
log.info("################# server move 108 -> 86");
|
||||
FIleChecker.uploadTo86(outputPath);
|
||||
// log.info("################# server move 108 -> 86");
|
||||
// FIleChecker.uploadTo86(outputPath);
|
||||
} catch (Exception e) {
|
||||
log.warn("tmpDir delete failed (merge already succeeded): tmpDir={}", tmpDir, e);
|
||||
}
|
||||
|
||||
@@ -57,3 +57,12 @@ file:
|
||||
|
||||
dataset-dir: /home/kcomu/data/request/
|
||||
dataset-tmp-dir: ${file.dataset-dir}tmp/
|
||||
|
||||
train:
|
||||
docker:
|
||||
image: "kamco-cd-train:love_latest"
|
||||
requestDir: "/home/kcomu/data/request"
|
||||
responseDir: "/home/kcomu/data/response"
|
||||
containerPrefix: "kamco-cd-train"
|
||||
shmSize: "16g"
|
||||
ipcHost: true
|
||||
|
||||
@@ -4,7 +4,7 @@ spring:
|
||||
on-profile: prod
|
||||
|
||||
jpa:
|
||||
show-sql: false
|
||||
show-sql: true
|
||||
hibernate:
|
||||
ddl-auto: validate
|
||||
properties:
|
||||
@@ -12,34 +12,45 @@ spring:
|
||||
default_batch_fetch_size: 100 # ✅ 성능 - N+1 쿼리 방지
|
||||
order_updates: true # ✅ 성능 - 업데이트 순서 정렬로 데드락 방지
|
||||
use_sql_comments: true # ⚠️ 선택 - SQL에 주석 추가 (디버깅용)
|
||||
format_sql: true # ⚠️ 선택 - SQL 포맷팅 (가독성)
|
||||
|
||||
datasource:
|
||||
url: jdbc:postgresql://10.100.0.10:25432/temp
|
||||
username: temp
|
||||
password: temp123!
|
||||
url: jdbc:postgresql://127.0.01:15432/kamco_training_db
|
||||
# url: jdbc:postgresql://localhost:15432/kamco_training_db
|
||||
username: kamco_training_user
|
||||
password: kamco_training_user_2025_!@#
|
||||
hikari:
|
||||
minimum-idle: 10
|
||||
maximum-pool-size: 20
|
||||
connection-timeout: 60000 # 60초 연결 타임아웃
|
||||
idle-timeout: 300000 # 5분 유휴 타임아웃
|
||||
max-lifetime: 1800000 # 30분 최대 수명
|
||||
leak-detection-threshold: 60000 # 연결 누수 감지
|
||||
|
||||
transaction:
|
||||
default-timeout: 300 # 5분 트랜잭션 타임아웃
|
||||
|
||||
jwt:
|
||||
secret: "kamco_token_prod_dfc6446d-68fc-4eba-a2ff-c80a14a0bf3a"
|
||||
secret: "kamco_token_dev_dfc6446d-68fc-4eba-a2ff-c80a14a0bf3a"
|
||||
access-token-validity-in-ms: 86400000 # 1일
|
||||
refresh-token-validity-in-ms: 604800000 # 7일
|
||||
|
||||
token:
|
||||
refresh-cookie-name: kamco # 개발용 쿠키 이름
|
||||
refresh-cookie-secure: true # 로컬 http 테스트면 false
|
||||
refresh-cookie-secure: false # 로컬 http 테스트면 false
|
||||
|
||||
springdoc:
|
||||
swagger-ui:
|
||||
persist-authorization: true # 스웨거 새로고침해도 토큰 유지, 로컬스토리지에 저장
|
||||
|
||||
member:
|
||||
init_password: kamco1234!
|
||||
|
||||
swagger:
|
||||
local-port: 9080
|
||||
|
||||
file:
|
||||
sync-root-dir: /app/original-images/
|
||||
sync-tmp-dir: ${file.sync-root-dir}tmp/
|
||||
sync-file-extention: tfw,tif
|
||||
|
||||
dataset-dir: /home/kcomu/data/request/
|
||||
dataset-tmp-dir: ${file.dataset-dir}tmp/
|
||||
train:
|
||||
docker:
|
||||
image: "kamco-cd-train:latest"
|
||||
requestDir: "/home/kcomu/data/request"
|
||||
responseDir: "/home/kcomu/data/response"
|
||||
containerPrefix: "kamco-cd-train"
|
||||
shmSize: "16g"
|
||||
ipcHost: true
|
||||
|
||||
@@ -9667,7 +9667,7 @@ INSERT INTO public.tb_audit_log VALUES (1813, 3, 'CREATE', 'SUCCESS', 'SYSTEM',
|
||||
INSERT INTO public.tb_audit_log VALUES (1814, 3, 'READ', 'SUCCESS', 'SYSTEM', '127.0.0.1', NULL, '2025-12-26 00:44:49.832926+00', NULL, '', '/api/models/train');
|
||||
INSERT INTO public.tb_audit_log VALUES (1815, NULL, 'CREATE', 'FAILED', 'SYSTEM', '127.0.0.1', NULL, '2025-12-26 00:53:40.899603+00', 467, '{
|
||||
"username": "1234567",
|
||||
"password":"****"
|
||||
"password":"****"
|
||||
}', '/api/auth/signin');
|
||||
INSERT INTO public.tb_audit_log VALUES (1816, NULL, 'CREATE', 'SUCCESS', 'SYSTEM', '127.0.0.1', NULL, '2025-12-26 00:55:27.731595+00', NULL, '{
|
||||
"username": "1234567",
|
||||
@@ -30396,6 +30396,8 @@ ALTER TABLE ONLY public.tb_menu
|
||||
ADD CONSTRAINT fksw914diut87r7lfykekc7xm2a FOREIGN KEY (parent_menu_uid) REFERENCES public.tb_menu(menu_uid);
|
||||
|
||||
|
||||
|
||||
|
||||
-- Completed on 2025-12-26 16:11:11 KST
|
||||
|
||||
--
|
||||
@@ -30404,3 +30406,5 @@ ALTER TABLE ONLY public.tb_menu
|
||||
|
||||
\unrestrict IYrUYfSgA4Fo2gubHcb84jDXfbBZEIiOZnyLtZgnMi641GaRQa5QDogarpTr7IG
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user