128 Commits

Author SHA1 Message Date
265813e6f7 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-03-10 14:19:40 +09:00
8190a6e9c8 unzip 2026-03-10 14:19:22 +09:00
e9f8bb37fa spotless 적용 2026-03-03 23:06:45 +09:00
f3c822587f spotless 적용 2026-03-03 23:02:53 +09:00
335f0dbb9b spotless 적용 2026-03-03 23:01:22 +09:00
69eaba1a83 하드링크 수정 2026-03-03 22:51:10 +09:00
365ad81cad 리커버리 삭제 2026-02-28 01:24:34 +09:00
9dfa54fbf9 리커버리 추가 2026-02-28 01:01:38 +09:00
12f6bb7154 하드링크 수정 2026-02-27 23:31:04 +09:00
aa3af4e9d0 하드링크 로그 추가 2026-02-27 23:12:00 +09:00
7ca37bf1e4 하드링크 로그 추가 2026-02-27 22:51:27 +09:00
901dde066d 하이파라미터 상세조회 수정 2026-02-24 16:54:58 +09:00
cb0a38274a val_interval 기본값 1로 수정 2026-02-24 16:24:16 +09:00
b8194df9ae 학습 실패여부 확인 기능 추가 2026-02-24 15:43:35 +09:00
7c5f07683e 학습 실패여부 확인 기능 추가 2026-02-24 15:10:48 +09:00
159fb281d4 최근 사용일시 업데이트 2026-02-23 15:40:24 +09:00
97192ff811 최근 사용일시 업데이트 2026-02-23 15:39:25 +09:00
4f3fb675be 하이퍼 파라미터 사용회수 카운트 기능 추가 및 조회 수정 2026-02-23 15:37:28 +09:00
e6caea05b3 하이퍼 파라미터 사용회수 카운트 기능 추가 및 조회 수정 2026-02-23 15:19:40 +09:00
fd63824edc 하이퍼 파라미터 미사용 컬럼제거, 사용횟수 컬럼 추가 2026-02-23 14:30:29 +09:00
8a44df26b8 하이퍼 파라미터 상세조회 삭제여부 조건 제거 2026-02-23 14:17:34 +09:00
cb97c5e59e 하이퍼 파라미터 dto 주석 수정 2026-02-23 14:13:25 +09:00
8f75b16dc6 학습실행 주석 추가 2026-02-23 12:30:54 +09:00
c2978e41c2 전이학습 상세 수정 2026-02-20 18:34:32 +09:00
07429dbe8e 전이학습 상세 수정 2026-02-20 18:22:19 +09:00
83859bb9fe 전이학습 상세 - before dataset 추가 2026-02-20 16:05:29 +09:00
564a99448c best epoch 파일 선택 수정 2026-02-20 15:41:34 +09:00
38ae6e5575 best epoch 파일 선택 수정 2026-02-20 15:31:33 +09:00
40fe98ae0c best epoch 파일 선택 수정 2026-02-20 15:15:12 +09:00
255ff10a56 Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-20 14:30:42 +09:00
f674f73330 중복 수정 제거 2026-02-20 14:30:34 +09:00
db2bc32e7d test metrics insert 로직 수정 2026-02-20 14:24:23 +09:00
37786a1e44 선택한 테스트 에폭 로그 추가 2026-02-20 14:13:49 +09:00
901ea83fb7 test 선택한 에폭 log 확인 추가 2026-02-20 14:13:22 +09:00
832e1b5681 tmp 하드링크 수정 2026-02-20 13:36:48 +09:00
4f16355cda tmp 하드링크 수정 2026-02-20 12:29:57 +09:00
df46a8f79f Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-20 12:21:50 +09:00
fcd48831c5 tmp 하드링크 수정 2026-02-20 12:21:41 +09:00
62c9d73b94 test json 수정 2026-02-20 12:20:15 +09:00
68c0e634c5 ing-cnt 로직에 step2도 추가, transactional 2026-02-20 12:05:20 +09:00
ad421e3c74 비밀번호 변경 security 로직 수정 2026-02-20 11:36:21 +09:00
46db1512a6 test 실행 시 회차별 데이터 적재하기 2026-02-19 18:18:12 +09:00
2034a8fcb2 LogErrorLevel -> CodeExpose 추가 2026-02-19 17:35:15 +09:00
bf212842d8 모델학습관리 > 모델별 진행 상황 API 추가 2026-02-19 17:17:11 +09:00
d2ca94ea55 모델학습관리 > 목록 API 메모,작성자 추가로 인한 수정 2026-02-19 15:34:18 +09:00
5ddf6dfeeb 모델학습 2단계 패키징 시작,종료일시,상태 로직 추가 2026-02-19 14:43:14 +09:00
5e13c0b396 공통코드 common-code 로 prefix 변경 2026-02-19 11:38:56 +09:00
435f60dcac 로그관리 로직 커밋 2026-02-19 11:13:40 +09:00
5f5eabca19 압축해제 시, 동일 폴더가 있으면 삭제 후 재업로드 2026-02-18 16:36:46 +09:00
413631840f 학습데이터 다운로드 security 제외하기 2026-02-18 16:29:28 +09:00
c7f63d1ad1 압축 해제한 폴더의 갯수 맞는지 log 찍기 + 갯수 맞지 않으면 exception 리턴 2026-02-18 16:22:32 +09:00
7529d23488 업로드 시 uid로 중복체크 -> 삭제인 row는 제외하기 2026-02-18 15:40:14 +09:00
cb3e51d712 업로드 시 exception 메세지 처리, 에폭 10 이상으로 실행되게 수정 2026-02-18 15:28:29 +09:00
99a4597b5f train 결과 +1 했던 거 제거하기 2026-02-18 14:50:53 +09:00
d9da0d4610 1단계 실행 시, 시작시간 update 추가 2026-02-18 13:05:59 +09:00
22c481556c 하이퍼 파라미터 수정 2026-02-13 15:00:30 +09:00
0798b352c7 하이퍼 파라미터 수정 2026-02-13 14:46:59 +09:00
5b074bdb81 하이퍼 파라미터 수정 2026-02-13 14:42:39 +09:00
28919345c2 하이퍼 파라미터 수정 2026-02-13 14:37:44 +09:00
aa0552aaa7 하이퍼 파라미터 수정 2026-02-13 14:30:45 +09:00
5d0aca14a6 사용가능 용량 API 수정 2026-02-13 14:19:09 +09:00
af8d59ddfa 이어하기 수정 2026-02-13 14:08:35 +09:00
4f24e09c57 이어하기 수정 2026-02-13 14:04:33 +09:00
4da477706f 이어하기 수정 2026-02-13 13:57:01 +09:00
a070566048 이어하기 로그 수정 2026-02-13 13:26:54 +09:00
a5b3ae613f 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 13:23:53 +09:00
979af088be 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 13:17:27 +09:00
e5a1cab36b Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-13 12:53:04 +09:00
bb67996742 text 2026-02-13 12:53:00 +09:00
1981d6d1ce Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-13 12:53:00 +09:00
47f4ffd4db 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 12:52:11 +09:00
195856b846 flush 추가해보기 2026-02-13 12:46:54 +09:00
124da48e51 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 12:38:55 +09:00
02724e9508 주석한거 원복 2026-02-13 12:24:50 +09:00
7ed91ccab9 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-13 12:23:12 +09:00
a7c13b985d responsePath 셋팅 삭제 2026-02-13 12:23:01 +09:00
352a28b87f 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 12:20:48 +09:00
bf8515163c 주석 처리 2026-02-13 12:10:08 +09:00
2691f6ce16 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 12:00:42 +09:00
7e5aa5e713 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 11:57:48 +09:00
060a815e1c 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 11:55:35 +09:00
1eb4d04779 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-13 10:50:35 +09:00
f30c0c6d45 다운로드 시 Access-Control-Expose-Headers 추가 2026-02-13 10:50:28 +09:00
12994aab60 파일 count 기능 추가 2026-02-13 10:44:49 +09:00
11d3afe295 파일 count 기능 추가 2026-02-13 10:38:24 +09:00
1e62a8b097 학습실행 step1 할 때 best epoch 업데이트 2026-02-13 10:15:04 +09:00
26a4623aa8 학습데이터 목록 파일 단위 MB 나오게 하기 2026-02-13 09:42:36 +09:00
ce6e4f5aea tmp 파일 링크 수정 2026-02-13 09:10:45 +09:00
c2215836c0 tmp 파일 링크 수정 2026-02-13 08:44:17 +09:00
8c19c996f7 tmp 파일 링크 수정 2026-02-13 08:33:36 +09:00
b5ce3ab1fb 이어하기 수정 2026-02-12 23:01:56 +09:00
e1ceb769dd 학습데이터 다운로드 파일 정보 API 추가 2026-02-12 22:47:13 +09:00
4219b88fb3 학습데이터 다운로드 API 추가 2026-02-12 22:25:55 +09:00
4f94c99b64 이어하기 수정 2026-02-12 22:09:54 +09:00
d42e1afbd4 스케줄러 api 수동 호출 2026-02-12 21:51:53 +09:00
b3b8016673 csv 결과 받아오는 것 변경 2026-02-12 21:45:38 +09:00
79e8259f28 파라미터 변경 2026-02-12 21:30:03 +09:00
032c82c2f0 file 경로 넣기 2026-02-12 21:17:10 +09:00
6204a6e5fa Merge remote-tracking branch 'origin/develop' into feat/training_260202
# Conflicts:
#	src/main/resources/application-prod.yml
2026-02-12 21:15:21 +09:00
4d9c9a86b4 패키징 zip파일 만들기 커밋 2026-02-12 21:09:40 +09:00
83204abfe9 Merge pull request '성공시 csv 파일 테이블에 저장 연결' (#74) from feat/training_260202 into develop
Reviewed-on: #74
2026-02-12 21:01:48 +09:00
5b682c1386 성공시 csv 파일 테이블에 저장 연결 2026-02-12 21:01:24 +09:00
452494d44d Merge pull request '테스트 실행 경로 수정' (#73) from feat/training_260202 into develop
Reviewed-on: #73
2026-02-12 20:49:30 +09:00
8ada26448b 테스트 실행 경로 수정 2026-02-12 20:49:14 +09:00
e442f105bc Merge pull request '도커명 변경' (#72) from feat/training_260202 into develop
Reviewed-on: #72
2026-02-12 20:37:00 +09:00
5e0a771848 도커명 변경 2026-02-12 20:36:38 +09:00
b4c2685059 Merge pull request '도커 설정 추가' (#71) from feat/training_260202 into develop
Reviewed-on: #71
2026-02-12 20:25:16 +09:00
e238f3ca88 도커 설정 추가 2026-02-12 20:24:57 +09:00
97b06eb3b3 Merge pull request '임시파일생성 경로 수정' (#70) from feat/training_260202 into develop
Reviewed-on: #70
2026-02-12 20:03:32 +09:00
ad32ca18ca 임시파일생성 경로 수정 2026-02-12 20:03:03 +09:00
98a1283ebe Merge pull request '임시파일생성 경로 수정' (#69) from feat/training_260202 into develop
Reviewed-on: #69
2026-02-12 19:36:06 +09:00
a10fccaae3 임시파일생성 경로 수정 2026-02-12 19:35:47 +09:00
c3c9191d9d Merge pull request 'hyperparam_with_modeltype' (#68) from feat/dean/hyperparam_with_modelType-bug into develop
Reviewed-on: #68
2026-02-12 19:30:29 +09:00
9fd5a15a72 hyperparam_with_modeltype 2026-02-12 19:30:08 +09:00
12f9de7367 hyperparam_with_modeltype 2026-02-12 19:16:24 +09:00
5455da1e96 hyperparam_with_modeltype 2026-02-12 19:16:13 +09:00
9e803661cd Merge pull request 'feat/training_260202' (#67) from feat/training_260202 into develop
Reviewed-on: #67
2026-02-12 19:14:39 +09:00
b0cf9e77ec Merge branch 'develop' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into develop 2026-02-12 19:14:10 +09:00
c92426aefc Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 19:14:09 +09:00
d5b2b8ecec hyperparam_with_modeltype 2026-02-12 19:14:01 +09:00
6185a18a7c 모델목록 검색 조건 상태값 변경 2026-02-12 19:13:56 +09:00
49d3e37458 Merge pull request '임시파일생성 경로 수정' (#66) from feat/training_260202 into develop
Reviewed-on: #66
2026-02-12 19:12:37 +09:00
1fb10830b9 임시파일생성 경로 수정 2026-02-12 19:11:51 +09:00
d7766edd24 Merge pull request 'return 형식 수정' (#65) from feat/training_260202 into develop
Reviewed-on: #65
2026-02-12 18:59:37 +09:00
0bc4453c9c hyperparam_with_modeltype 2026-02-12 18:56:32 +09:00
ae0d30e5da return 형식 수정 2026-02-12 18:55:42 +09:00
37d776dd2c Merge pull request 'hyperparam_with_modeltype' (#64) from feat/dean/hyperparam_with_modelType into develop
Reviewed-on: #64
2026-02-12 18:50:32 +09:00
0c34ea7dcb hyperparam_with_modeltype 2026-02-12 18:48:14 +09:00
89 changed files with 4231 additions and 617 deletions

View File

@@ -1,6 +1,11 @@
# Stage 1: Build stage (gradle build는 Jenkins에서 이미 수행) # Stage 1: Build stage (gradle build는 Jenkins에서 이미 수행)
FROM eclipse-temurin:21-jre-jammy FROM eclipse-temurin:21-jre-jammy
# docker CLI 설치 (컨테이너에서 호스트 Docker 제어용) 260212 추가
RUN apt-get update && \
apt-get install -y --no-install-recommends docker.io ca-certificates && \
rm -rf /var/lib/apt/lists/*
# 작업 디렉토리 설정 # 작업 디렉토리 설정
WORKDIR /app WORKDIR /app

View File

@@ -3,6 +3,7 @@ plugins {
id 'org.springframework.boot' version '3.5.7' id 'org.springframework.boot' version '3.5.7'
id 'io.spring.dependency-management' version '1.1.7' id 'io.spring.dependency-management' version '1.1.7'
id 'com.diffplug.spotless' version '6.25.0' id 'com.diffplug.spotless' version '6.25.0'
id 'idea'
} }
group = 'com.kamco.cd' group = 'com.kamco.cd'
@@ -21,11 +22,23 @@ configurations {
} }
} }
// QueryDSL 생성된 소스 디렉토리 정의
def generatedSourcesDir = file("$buildDir/generated/sources/annotationProcessor/java/main")
repositories { repositories {
mavenCentral() mavenCentral()
maven { url "https://repo.osgeo.org/repository/release/" } maven { url "https://repo.osgeo.org/repository/release/" }
} }
// Gradle이 생성된 소스를 컴파일 경로에 포함하도록 설정
sourceSets {
main {
java {
srcDirs += generatedSourcesDir
}
}
}
dependencies { dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-jpa' implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
implementation 'org.springframework.boot:spring-boot-starter-web' implementation 'org.springframework.boot:spring-boot-starter-web'
@@ -87,6 +100,21 @@ dependencies {
implementation 'org.apache.commons:commons-csv:1.10.0' implementation 'org.apache.commons:commons-csv:1.10.0'
} }
// IntelliJ가 생성된 소스를 인식하도록 설정
idea {
module {
// 소스 디렉토리로 인식
sourceDirs += generatedSourcesDir
// Generated Sources Root로 마킹 (IntelliJ에서 특별 처리)
generatedSourceDirs += generatedSourcesDir
// 소스 및 Javadoc 다운로드
downloadJavadoc = true
downloadSources = true
}
}
configurations.configureEach { configurations.configureEach {
exclude group: 'javax.media', module: 'jai_core' exclude group: 'javax.media', module: 'jai_core'
} }
@@ -95,6 +123,21 @@ tasks.named('test') {
useJUnitPlatform() useJUnitPlatform()
} }
// 컴파일 전 생성된 소스 디렉토리 생성 보장
tasks.named('compileJava') {
doFirst {
generatedSourcesDir.mkdirs()
}
}
// 생성된 소스 정리 태스크
tasks.register('cleanGeneratedSources', Delete) {
delete generatedSourcesDir
}
tasks.named('clean') {
dependsOn 'cleanGeneratedSources'
}
bootJar { bootJar {
archiveFileName = 'ROOT.jar' archiveFileName = 'ROOT.jar'

View File

@@ -15,6 +15,7 @@ services:
- /mnt/nfs_share/model_output:/app/model-outputs - /mnt/nfs_share/model_output:/app/model-outputs
- /mnt/nfs_share/train_dataset:/app/train-dataset - /mnt/nfs_share/train_dataset:/app/train-dataset
- /home/kcomu/data:/home/kcomu/data - /home/kcomu/data:/home/kcomu/data
- /var/run/docker.sock:/var/run/docker.sock
networks: networks:
- kamco-cds - kamco-cds
restart: unless-stopped restart: unless-stopped

View File

@@ -2,10 +2,8 @@ package com.kamco.cd.training;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.scheduling.annotation.EnableScheduling;
@EnableAsync
@SpringBootApplication @SpringBootApplication
@EnableScheduling @EnableScheduling
public class KamcoTrainingApplication { public class KamcoTrainingApplication {

View File

@@ -23,7 +23,8 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
private final UserDetailsService userDetailsService; private final UserDetailsService userDetailsService;
private static final AntPathMatcher PATH_MATCHER = new AntPathMatcher(); private static final AntPathMatcher PATH_MATCHER = new AntPathMatcher();
private static final String[] EXCLUDE_PATHS = { private static final String[] EXCLUDE_PATHS = {
"/api/auth/signin", "/api/auth/refresh", "/api/auth/logout", "/api/members/*/password" // "/api/auth/signin", "/api/auth/refresh", "/api/auth/logout", "/api/members/*/password"
"/api/auth/signin", "/api/auth/refresh", "/api/auth/logout"
}; };
@Override @Override

View File

@@ -20,7 +20,7 @@ import org.springframework.web.bind.annotation.*;
@Tag(name = "공통코드 관리", description = "공통코드 관리 API") @Tag(name = "공통코드 관리", description = "공통코드 관리 API")
@RestController @RestController
@RequiredArgsConstructor @RequiredArgsConstructor
@RequestMapping("/api/code") @RequestMapping("/api/common-code")
public class CommonCodeApiController { public class CommonCodeApiController {
private final CommonCodeService commonCodeService; private final CommonCodeService commonCodeService;

View File

@@ -0,0 +1,48 @@
package com.kamco.cd.training.common.download;
import com.kamco.cd.training.common.download.dto.DownloadSpec;
import com.kamco.cd.training.common.utils.UserUtil;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
@Service
@RequiredArgsConstructor
public class DownloadExecutor {
private final UserUtil userUtil;
public ResponseEntity<StreamingResponseBody> stream(DownloadSpec spec) throws IOException {
if (!Files.isReadable(spec.filePath())) {
return ResponseEntity.notFound().build();
}
StreamingResponseBody body =
os -> {
try (InputStream in = Files.newInputStream(spec.filePath())) {
in.transferTo(os);
os.flush();
} catch (Exception e) {
// 고용량은 중간 끊김 흔하니까 throw 금지
}
};
String fileName =
spec.downloadName() != null
? spec.downloadName()
: spec.filePath().getFileName().toString();
return ResponseEntity.ok()
.contentType(
spec.contentType() != null ? spec.contentType() : MediaType.APPLICATION_OCTET_STREAM)
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"")
.body(body);
}
}

View File

@@ -0,0 +1,19 @@
package com.kamco.cd.training.common.download;
import org.springframework.util.AntPathMatcher;
public final class DownloadPaths {
private DownloadPaths() {}
public static final String[] PATTERNS = {
"/api/inference/download/**", "/api/training-data/stage/download/**"
};
public static boolean matches(String uri) {
AntPathMatcher m = new AntPathMatcher();
for (String p : PATTERNS) {
if (m.match(p, uri)) return true;
}
return false;
}
}

View File

@@ -0,0 +1,81 @@
package com.kamco.cd.training.common.download;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.ResourceRegion;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRange;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
@Component
public class RangeDownloadResponder {
public ResponseEntity<?> buildZipResponse(
Path filePath, String downloadFileName, HttpServletRequest request) throws IOException {
if (!Files.isRegularFile(filePath)) {
return ResponseEntity.notFound().build();
}
long totalSize = Files.size(filePath);
Resource resource = new FileSystemResource(filePath);
String disposition = "attachment; filename=\"" + downloadFileName + "\"";
String rangeHeader = request.getHeader(HttpHeaders.RANGE);
// 🔥 공통 헤더 (여기 고정)
ResponseEntity.BodyBuilder base =
ResponseEntity.ok()
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.header(HttpHeaders.CONTENT_DISPOSITION, disposition)
.header(HttpHeaders.ACCEPT_RANGES, "bytes")
.header("Access-Control-Expose-Headers", "Content-Disposition")
.header("X-Accel-Buffering", "no");
if (rangeHeader == null || rangeHeader.isBlank()) {
return base.contentLength(totalSize).body(resource);
}
List<HttpRange> ranges;
try {
ranges = HttpRange.parseRanges(rangeHeader);
} catch (IllegalArgumentException ex) {
return ResponseEntity.status(416)
.header(HttpHeaders.CONTENT_RANGE, "bytes */" + totalSize)
.header("X-Accel-Buffering", "no")
.build();
}
HttpRange range = ranges.get(0);
long start = range.getRangeStart(totalSize);
long end = range.getRangeEnd(totalSize);
if (start >= totalSize) {
return ResponseEntity.status(416)
.header(HttpHeaders.CONTENT_RANGE, "bytes */" + totalSize)
.header("X-Accel-Buffering", "no")
.build();
}
long regionLength = end - start + 1;
ResourceRegion region = new ResourceRegion(resource, start, regionLength);
return ResponseEntity.status(206)
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.header(HttpHeaders.CONTENT_DISPOSITION, disposition)
.header(HttpHeaders.ACCEPT_RANGES, "bytes")
.header("Access-Control-Expose-Headers", "Content-Disposition")
.header("X-Accel-Buffering", "no")
.header(HttpHeaders.CONTENT_RANGE, "bytes " + start + "-" + end + "/" + totalSize)
.contentLength(regionLength)
.body(region);
}
}

View File

@@ -0,0 +1,12 @@
package com.kamco.cd.training.common.download.dto;
import java.nio.file.Path;
import java.util.UUID;
import org.springframework.http.MediaType;
public record DownloadSpec(
UUID uuid, // 다운로드 식별(로그/정책용)
Path filePath, // 실제 파일 경로
String downloadName, // 사용자에게 보일 파일명
MediaType contentType // 보통 OCTET_STREAM
) {}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.common.dto; package com.kamco.cd.training.common.dto;
import com.kamco.cd.training.common.enums.ModelType;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
@@ -11,9 +12,14 @@ import lombok.Setter;
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class HyperParam { public class HyperParam {
@Schema(description = "모델", example = "G1")
private ModelType model; // G1, G2, G3
// ------------------------- // -------------------------
// Important // Important
// ------------------------- // -------------------------
@Schema(description = "백본 네트워크", example = "large") @Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone private String backbone; // backbone

View File

@@ -0,0 +1,27 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public enum JobStatusType implements EnumType {
QUEUED("대기중"),
RUNNING("실행중"),
SUCCESS("성공"),
FAILED("실패"),
CANCELED("취소");
private final String desc;
@Override
public String getId() {
return name();
}
@Override
public String getText() {
return desc;
}
}

View File

@@ -0,0 +1,24 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public enum JobType implements EnumType {
TRAIN("학습"),
TEST("테스트");
private final String desc;
@Override
public String getId() {
return name();
}
@Override
public String getText() {
return desc;
}
}

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.CodeExpose; import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType; import com.kamco.cd.training.common.utils.enums.EnumType;
import java.util.Arrays;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
@@ -15,6 +16,13 @@ public enum ModelType implements EnumType {
private String desc; private String desc;
public static ModelType getValueData(String modelNo) {
return Arrays.stream(ModelType.values())
.filter(m -> m.getId().equals(modelNo))
.findFirst()
.orElse(G1);
}
@Override @Override
public String getId() { public String getId() {
return name(); return name();

View File

@@ -3,9 +3,10 @@ package com.kamco.cd.training.common.utils;
import static java.lang.String.CASE_INSENSITIVE_ORDER; import static java.lang.String.CASE_INSENSITIVE_ORDER;
import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session; import com.jcraft.jsch.Session;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.config.api.ApiResponseDto.ApiResponseCode;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.File; import java.io.File;
@@ -15,6 +16,7 @@ import java.io.FileReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
@@ -39,6 +41,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.geotools.coverage.grid.GridCoverage2D; import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.gce.geotiff.GeoTiffReader; import org.geotools.gce.geotiff.GeoTiffReader;
import org.springframework.http.HttpStatus;
import org.springframework.util.FileSystemUtils; import org.springframework.util.FileSystemUtils;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
@@ -716,13 +719,26 @@ public class FIleChecker {
public static void unzip(String fileName, String destDirectory) throws IOException { public static void unzip(String fileName, String destDirectory) throws IOException {
String zipFilePath = destDirectory + File.separator + fileName; String zipFilePath = destDirectory + File.separator + fileName;
log.info("fileName : {}", fileName);
log.info("destDirectory : {}", destDirectory);
log.info("zipFilePath : {}", zipFilePath);
// zip 이름으로 폴더 생성 (확장자 제거) // zip 이름으로 폴더 생성 (확장자 제거)
String folderName = String folderName =
fileName.endsWith(".zip") ? fileName.substring(0, fileName.length() - 4) : fileName; fileName.endsWith(".zip") ? fileName.substring(0, fileName.length() - 4) : fileName;
log.info("folderName : {}", folderName);
File destDir = new File(destDirectory, folderName); File destDir = new File(destDirectory, folderName);
log.info("destDir : {}", destDir);
// 동일 폴더가 이미 있으면 삭제
log.info("111 destDir.exists() : {}", destDir.exists());
if (destDir.exists()) {
deleteDirectoryRecursively(destDir.toPath());
}
log.info("222 destDir.exists() : {}", destDir.exists());
if (!destDir.exists()) { if (!destDir.exists()) {
log.info("mkdirs : {}", destDir.exists());
destDir.mkdirs(); destDir.mkdirs();
} }
@@ -757,6 +773,11 @@ public class FIleChecker {
zipEntry = zis.getNextEntry(); zipEntry = zis.getNextEntry();
} }
zis.closeEntry(); zis.closeEntry();
} catch (IOException e) {
throw new CustomApiException(
ApiResponseCode.INTERNAL_SERVER_ERROR.getId(),
HttpStatus.INTERNAL_SERVER_ERROR,
"압축 해제 중 오류가 발생했습니다: " + e.getMessage());
} }
} }
@@ -773,92 +794,6 @@ public class FIleChecker {
return destFile; return destFile;
} }
public static void uploadTo86(Path localFile) {
String host = "192.168.2.86";
int port = 22;
String username = "kcomu";
String password = "Kamco2025!";
String remoteDir = "/home/kcomu/data/request";
Session session = null;
ChannelSftp channel = null;
try {
JSch jsch = new JSch();
session = jsch.getSession(username, host, port);
session.setPassword(password);
Properties config = new Properties();
config.put("StrictHostKeyChecking", "no");
session.setConfig(config);
session.connect(10_000);
channel = (ChannelSftp) session.openChannel("sftp");
channel.connect(10_000);
// 목적지 디렉토리 이동
channel.cd(remoteDir);
// 업로드
channel.put(localFile.toString(), localFile.getFileName().toString());
} catch (Exception e) {
throw new RuntimeException("SFTP upload failed", e);
} finally {
if (channel != null) channel.disconnect();
if (session != null) session.disconnect();
}
}
public static void unzipOn86Server(String zipPath, String targetDir) {
String host = "192.168.2.86";
String user = "kcomu";
String password = "Kamco2025!";
Session session = null;
ChannelExec channel = null;
try {
JSch jsch = new JSch();
session = jsch.getSession(user, host, 22);
session.setPassword(password);
Properties config = new Properties();
config.put("StrictHostKeyChecking", "no");
session.setConfig(config);
session.connect(10_000);
String command = "unzip -o " + zipPath + " -d " + targetDir;
channel = (ChannelExec) session.openChannel("exec");
channel.setCommand(command);
channel.setErrStream(System.err);
InputStream in = channel.getInputStream();
channel.connect();
// 출력 읽기(선택)
try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
while (br.readLine() != null) {
// 필요하면 로그
}
}
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
if (channel != null) channel.disconnect();
if (session != null) session.disconnect();
}
}
public static List<String> execCommandAndReadLines(String command) { public static List<String> execCommandAndReadLines(String command) {
List<String> result = new ArrayList<>(); List<String> result = new ArrayList<>();
@@ -906,4 +841,22 @@ public class FIleChecker {
if (session != null) session.disconnect(); if (session != null) session.disconnect();
} }
} }
/** ✅ 폴더 재귀 삭제 */
private static void deleteDirectoryRecursively(Path path) throws IOException {
if (!Files.exists(path)) return;
// 하위부터 지워야 하므로 reverse order
Files.walk(path)
.sorted(Comparator.reverseOrder())
.forEach(
p -> {
try {
Files.deleteIfExists(p);
} catch (IOException e) {
// 여기서 바로 RuntimeException으로 올려서 상위 catch(IOException)로 잡히게 함
throw new UncheckedIOException("폴더 삭제 실패: " + p.toAbsolutePath(), e);
}
});
}
} }

View File

@@ -0,0 +1,23 @@
package com.kamco.cd.training.common.utils;
import jakarta.servlet.http.HttpServletRequest;
public final class HeaderUtil {
private HeaderUtil() {}
/** 특정 Header 값 조회 */
public static String get(HttpServletRequest request, String headerName) {
if (request == null || headerName == null) {
return null;
}
String value = request.getHeader(headerName);
return (value != null && !value.isBlank()) ? value : null;
}
/** 필수 Header 조회 (없으면 null) */
public static String getRequired(HttpServletRequest request, String headerName) {
return get(request, headerName);
}
}

View File

@@ -0,0 +1,23 @@
package com.kamco.cd.training.config;
import java.util.concurrent.Executor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
@Configuration
@EnableAsync
public class AsyncConfig {
@Bean(name = "trainJobExecutor")
public Executor trainJobExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(4); // 동시에 4개 실행
executor.setMaxPoolSize(8); // 최대 8개
executor.setQueueCapacity(200); // 대기 큐
executor.setThreadNamePrefix("train-job-");
executor.initialize();
return executor;
}
}

View File

@@ -76,11 +76,13 @@ public class SecurityConfig {
"/api/auth/logout", "/api/auth/logout",
"/swagger-ui/**", "/swagger-ui/**",
"/v3/api-docs/**", "/v3/api-docs/**",
"/api/members/*/password",
"/api/upload/chunk-upload-dataset", "/api/upload/chunk-upload-dataset",
"/api/upload/chunk-upload-complete") "/api/upload/chunk-upload-complete",
"/download_progress_test.html",
"/api/models/download/**")
.permitAll() .permitAll()
.requestMatchers("/api/members/*/password")
.authenticated()
// default // default
.anyRequest() .anyRequest()
.authenticated()) .authenticated())

View File

@@ -5,11 +5,14 @@ import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.menu.dto.MenuDto; import com.kamco.cd.training.menu.dto.MenuDto;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.util.ContentCachingRequestWrapper; import org.springframework.web.util.ContentCachingRequestWrapper;
@Slf4j
public class ApiLogFunction { public class ApiLogFunction {
// 클라이언트 IP 추출 // 클라이언트 IP 추출
@@ -34,6 +37,14 @@ public class ApiLogFunction {
return ip; return ip;
} }
public static String getXFowardedForIp(HttpServletRequest request) {
String ip = request.getHeader("X-Forwarded-For");
if (ip != null) {
ip = ip.split(",")[0].trim();
}
return ip;
}
// 사용자 ID 추출 예시 (Spring Security 기준) // 사용자 ID 추출 예시 (Spring Security 기준)
public static String getUserId(HttpServletRequest request) { public static String getUserId(HttpServletRequest request) {
try { try {
@@ -47,20 +58,20 @@ public class ApiLogFunction {
String method = request.getMethod().toUpperCase(); String method = request.getMethod().toUpperCase();
String uri = request.getRequestURI().toLowerCase(); String uri = request.getRequestURI().toLowerCase();
// URL 기반 DOWNLOAD/PRINT 분류 // URL 기반 DOWNLOAD/PRINT 분류 -> /download는 FileDownloadInterceptor로 옮김
if (uri.contains("/download") || uri.contains("/export")) { if (uri.contains("/download") || uri.contains("/export")) {
return EventType.DOWNLOAD; return EventType.DOWNLOAD;
} }
if (uri.contains("/print")) { if (uri.contains("/print")) {
return EventType.PRINT; return EventType.OTHER;
} }
// 일반 CRUD // 일반 CRUD
return switch (method) { return switch (method) {
case "POST" -> EventType.CREATE; case "POST" -> EventType.ADDED;
case "GET" -> EventType.READ; case "GET" -> EventType.LIST;
case "DELETE" -> EventType.DELETE; case "DELETE" -> EventType.REMOVE;
case "PUT", "PATCH" -> EventType.UPDATE; case "PUT", "PATCH" -> EventType.MODIFIED;
default -> EventType.OTHER; default -> EventType.OTHER;
}; };
} }
@@ -121,12 +132,22 @@ public class ApiLogFunction {
public static String getUriMenuInfo(List<MenuDto.Basic> menuList, String uri) { public static String getUriMenuInfo(List<MenuDto.Basic> menuList, String uri) {
MenuDto.Basic m = String normalizedUri = uri.replace("/api", "");
MenuDto.Basic basic =
menuList.stream() menuList.stream()
.filter(menu -> menu.getMenuApiUrl() != null && uri.contains(menu.getMenuApiUrl())) .filter(
.findFirst() menu -> menu.getMenuUrl() != null && normalizedUri.startsWith(menu.getMenuUrl()))
.max(Comparator.comparingInt(m -> m.getMenuUrl().length()))
.orElse(null); .orElse(null);
return m != null ? m.getMenuUid() : "SYSTEM"; return basic != null ? basic.getMenuUid() : "SYSTEM";
}
public static String cutRequestBody(String value) {
int MAX_LEN = 255;
if (value == null) {
return null;
}
return value.length() <= MAX_LEN ? value : value.substring(0, MAX_LEN);
} }
} }

View File

@@ -2,10 +2,17 @@ package com.kamco.cd.training.config.api;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.auth.CustomUserDetails; import com.kamco.cd.training.auth.CustomUserDetails;
import com.kamco.cd.training.common.utils.HeaderUtil;
import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.menu.dto.MenuDto;
import com.kamco.cd.training.menu.service.MenuService; import com.kamco.cd.training.menu.service.MenuService;
import com.kamco.cd.training.postgres.entity.AuditLogEntity; import com.kamco.cd.training.postgres.entity.AuditLogEntity;
import com.kamco.cd.training.postgres.repository.log.AuditLogRepository; import com.kamco.cd.training.postgres.repository.log.AuditLogRepository;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Optional;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@@ -23,6 +30,7 @@ import org.springframework.web.util.ContentCachingRequestWrapper;
* *
* <p>createOK() → 201 CREATED ok() → 200 OK deleteOk() → 204 NO_CONTENT * <p>createOK() → 201 CREATED ok() → 200 OK deleteOk() → 204 NO_CONTENT
*/ */
@Slf4j
@RestControllerAdvice @RestControllerAdvice
public class ApiResponseAdvice implements ResponseBodyAdvice<Object> { public class ApiResponseAdvice implements ResponseBodyAdvice<Object> {
@@ -61,12 +69,27 @@ public class ApiResponseAdvice implements ResponseBodyAdvice<Object> {
if (body instanceof ApiResponseDto<?> apiResponse) { if (body instanceof ApiResponseDto<?> apiResponse) {
response.setStatusCode(apiResponse.getHttpStatus()); response.setStatusCode(apiResponse.getHttpStatus());
String ip = ApiLogFunction.getClientIp(servletRequest); String actionType = HeaderUtil.get(servletRequest, "kamco-action-type");
Long userid = null; // actionType 이 없으면 로그 저장하지 않기 || download 는 FileDownloadInterceptor 에서 하기
// (file down URL prefix 추가는 WebConfig.java 에 하기)
if (actionType == null || actionType.equalsIgnoreCase("download")) {
return body;
}
if (servletRequest.getUserPrincipal() instanceof UsernamePasswordAuthenticationToken auth String ip =
&& auth.getPrincipal() instanceof CustomUserDetails customUserDetails) { Optional.ofNullable(HeaderUtil.get(servletRequest, "kamco-user-ip"))
userid = customUserDetails.getMember().getId(); .orElseGet(() -> ApiLogFunction.getXFowardedForIp(servletRequest));
Long userid = null;
String loginAttemptId = null;
// 로그인 시도할 때
if (servletRequest.getRequestURI().contains("/api/auth/signin")) {
loginAttemptId = HeaderUtil.get(servletRequest, "kamco-login-attempt-id");
} else {
if (servletRequest.getUserPrincipal() instanceof UsernamePasswordAuthenticationToken auth
&& auth.getPrincipal() instanceof CustomUserDetails customUserDetails) {
userid = customUserDetails.getMember().getId();
}
} }
String requestBody; String requestBody;
@@ -84,17 +107,33 @@ public class ApiResponseAdvice implements ResponseBodyAdvice<Object> {
requestBody = maskSensitiveFields(requestBody); requestBody = maskSensitiveFields(requestBody);
} }
List<?> list = menuService.getFindAll();
List<MenuDto.Basic> result =
list.stream()
.map(
item -> {
if (item instanceof LinkedHashMap<?, ?> map) {
return objectMapper.convertValue(map, MenuDto.Basic.class);
} else if (item instanceof MenuDto.Basic dto) {
return dto;
} else {
throw new IllegalStateException("Unsupported cache type: " + item.getClass());
}
})
.toList();
AuditLogEntity log = AuditLogEntity log =
new AuditLogEntity( new AuditLogEntity(
userid, userid,
ApiLogFunction.getEventType(servletRequest), EventType.fromName(actionType),
ApiLogFunction.isSuccessFail(apiResponse), ApiLogFunction.isSuccessFail(apiResponse),
ApiLogFunction.getUriMenuInfo( ApiLogFunction.getUriMenuInfo(result, servletRequest.getRequestURI()),
menuService.getFindAll(), servletRequest.getRequestURI()),
ip, ip,
servletRequest.getRequestURI(), servletRequest.getRequestURI(),
requestBody, ApiLogFunction.cutRequestBody(requestBody),
apiResponse.getErrorLogUid()); apiResponse.getErrorLogUid(),
null,
loginAttemptId);
auditLogRepository.save(log); auditLogRepository.save(log);
} }

View File

@@ -14,6 +14,10 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses; import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import java.io.IOException;
import java.nio.file.FileStore;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -208,8 +212,15 @@ public class DatasetApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@GetMapping("/usable-bytes") @GetMapping("/usable-bytes")
public ApiResponseDto<DatasetStorage> getUsableBytes() { public ApiResponseDto<DatasetStorage> getUsableBytes() throws IOException {
return ApiResponseDto.ok(datasetService.getUsableBytes()); FileStore store = Files.getFileStore(Path.of("."));
long usable = store.getUsableSpace();
DatasetStorage storage = new DatasetStorage();
storage.setUsableBytes(String.valueOf(usable));
// datasetService.getUsableBytes();
return ApiResponseDto.ok(storage);
} }
@Operation(summary = "학습데이터 zip파일 등록", description = "학습데이터 zip파일 등록 합니다.") @Operation(summary = "학습데이터 zip파일 등록", description = "학습데이터 zip파일 등록 합니다.")
@@ -217,7 +228,7 @@ public class DatasetApiController {
public ApiResponseDto<ApiResponseDto.ResponseObj> insertDataset( public ApiResponseDto<ApiResponseDto.ResponseObj> insertDataset(
@RequestBody @Valid DatasetDto.AddReq addReq) { @RequestBody @Valid DatasetDto.AddReq addReq) {
return ApiResponseDto.ok(datasetService.insertDataset(addReq)); return ApiResponseDto.okObject(datasetService.insertDataset(addReq));
} }
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회") @Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")

View File

@@ -1,7 +1,6 @@
package com.kamco.cd.training.dataset.dto; package com.kamco.cd.training.dataset.dto;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.kamco.cd.training.common.enums.LearnDataRegister; import com.kamco.cd.training.common.enums.LearnDataRegister;
import com.kamco.cd.training.common.enums.LearnDataType; import com.kamco.cd.training.common.enums.LearnDataType;
import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.enums.ModelType;
@@ -77,9 +76,16 @@ public class DatasetDto {
} }
public String getTotalSize(Long totalSize) { public String getTotalSize(Long totalSize) {
if (totalSize == null) return "0G"; if (totalSize == null || totalSize <= 0) return "0M";
double giga = totalSize / (1024.0 * 1024 * 1024); double giga = totalSize / (1024.0 * 1024 * 1024);
return String.format("%.2fG", giga);
if (giga >= 1) {
return String.format("%.2fG", giga);
} else {
double mega = totalSize / (1024.0 * 1024);
return String.format("%.2fM", mega);
}
} }
public String getStatus(String status) { public String getStatus(String status) {
@@ -227,7 +233,6 @@ public class DatasetDto {
@Getter @Getter
@Setter @Setter
@NoArgsConstructor @NoArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class SelectDataSet { public static class SelectDataSet {
private String modelNo; // G1, G2, G3 모델 타입 private String modelNo; // G1, G2, G3 모델 타입
@@ -310,6 +315,183 @@ public class DatasetDto {
} }
} }
@Schema(name = "SelectTransferDataSet", description = "전이학습 데이터셋 선택 리스트")
@Getter
@Setter
@NoArgsConstructor
public static class SelectTransferDataSet {
private String modelNo; // G1, G2, G3 모델 타입
private Long datasetId;
private UUID uuid;
private String dataType;
private String title;
private Long roundNo;
private Integer compareYyyy;
private Integer targetYyyy;
private String memo;
@JsonIgnore private Long classCount;
private Integer buildingCnt;
private Integer containerCnt;
private String dataTypeName;
private Long wasteCnt;
private Long landCoverCnt;
private String beforeModelNo; // G1, G2, G3 모델 타입
private Long beforeDatasetId;
private UUID beforeUuid;
private String beforeDataType;
private String beforeTitle;
private Long beforeRoundNo;
private Integer beforeCompareYyyy;
private Integer beforeTargetYyyy;
private String beforeMemo;
@JsonIgnore private Long beforeClassCount;
private Integer beforeBuildingCnt;
private Integer beforeContainerCnt;
private String beforeDataTypeName;
private Long beforeWasteCnt;
private Long beforeLandCoverCnt;
public SelectTransferDataSet(
// 현재
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Long classCount,
// 이전(before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Long beforeClassCount) {
// 현재
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.classCount = classCount;
if (modelNo != null && modelNo.equals(ModelType.G2.getId())) {
this.wasteCnt = classCount;
} else if (modelNo != null && modelNo.equals(ModelType.G3.getId())) {
this.landCoverCnt = classCount;
}
// 이전(before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeClassCount = beforeClassCount;
if (beforeModelNo != null && beforeModelNo.equals(ModelType.G2.getId())) {
this.beforeWasteCnt = beforeClassCount;
} else if (beforeModelNo != null && beforeModelNo.equals(ModelType.G3.getId())) {
this.beforeLandCoverCnt = beforeClassCount;
}
}
public SelectTransferDataSet(
// 현재
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Integer buildingCnt,
Integer containerCnt,
// 이전(before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Integer beforeBuildingCnt,
Integer beforeContainerCnt) {
// 현재
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.buildingCnt = buildingCnt;
this.containerCnt = containerCnt;
// 이전(before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeBuildingCnt = beforeBuildingCnt;
this.beforeContainerCnt = beforeContainerCnt;
}
public String getDataTypeName(String groupTitleCd) {
LearnDataType type = Enums.fromId(LearnDataType.class, groupTitleCd);
return type == null ? null : type.getText();
}
public String getYear() {
return this.compareYyyy + "-" + this.targetYyyy;
}
public String getBeforeYear() {
if (this.beforeCompareYyyy == null || this.beforeTargetYyyy == null) {
return null;
}
return this.beforeCompareYyyy + "-" + this.beforeTargetYyyy;
}
}
@Getter @Getter
@Setter @Setter
@NoArgsConstructor @NoArgsConstructor

View File

@@ -26,10 +26,14 @@ import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -56,6 +60,8 @@ public class DatasetService {
private String datasetDir; private String datasetDir;
private static final List<String> LABEL_DIRS = List.of("label-json", "label", "input1", "input2"); private static final List<String> LABEL_DIRS = List.of("label-json", "label", "input1", "input2");
private static final List<String> REQUIRED_DIRS = Arrays.asList("train", "val", "test");
private static final List<String> CHECK_DIRS = List.of("label", "input1", "input2");
/** /**
* 데이터셋 목록 조회 * 데이터셋 목록 조회
@@ -164,44 +170,6 @@ public class DatasetService {
} }
} }
@Deprecated
@Transactional
public ResponseObj insertDatasetTo86(@Valid AddReq addReq) {
Long datasetUid = null; // master id 값, 등록하면서 가져올 예정
// 압축 해제
FIleChecker.unzipOn86Server(
addReq.getFilePath() + addReq.getFileName(),
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
// 해제한 폴더 읽어서 데이터 저장
List<Map<String, Object>> list =
getUnzipDatasetFilesTo86(
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "train");
int idx = 0;
for (Map<String, Object> map : list) {
datasetUid =
this.insertTrainTestData(map, addReq, idx, datasetUid, "train"); // train 데이터 insert
idx++;
}
List<Map<String, Object>> testList =
getUnzipDatasetFilesTo86(
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "test");
int testIdx = 0;
for (Map<String, Object> test : testList) {
datasetUid =
this.insertTrainTestData(test, addReq, testIdx, datasetUid, "test"); // test 데이터 insert
testIdx++;
}
datasetCoreService.updateDatasetUploadStatus(datasetUid);
return new ResponseObj(ApiResponseCode.OK, "업로드 성공하였습니다.");
}
@Transactional @Transactional
public ResponseObj insertDataset(@Valid AddReq addReq) { public ResponseObj insertDataset(@Valid AddReq addReq) {
@@ -218,6 +186,12 @@ public class DatasetService {
// 압축 해제 // 압축 해제
FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath()); FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath());
// 압축 해제한 폴더 하위에 train,val,test 폴더 모두 존재하는지 확인
validateTrainValTestDirs(addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
// 압축 해제한 폴더의 갯수 맞는지 log 찍기
validateDirFileCount(addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
// 해제한 폴더 읽어서 데이터 저장 // 해제한 폴더 읽어서 데이터 저장
List<Map<String, Object>> list = List<Map<String, Object>> list =
getUnzipDatasetFiles( getUnzipDatasetFiles(
@@ -367,7 +341,10 @@ public class DatasetService {
Path dir = root.resolve(dirName); Path dir = root.resolve(dirName);
if (!Files.isDirectory(dir)) { if (!Files.isDirectory(dir)) {
throw new IllegalStateException("폴더가 존재하지 않습니다 : " + dir); throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
"폴더가 존재하지 않습니다. 업로드 된 파일을 확인하세요. : " + dir);
} }
try (Stream<Path> stream = Files.list(dir)) { try (Stream<Path> stream = Files.list(dir)) {
@@ -421,62 +398,6 @@ public class DatasetService {
return datasetCoreService.getFilePathByUUIDPathType(uuid, pathType); return datasetCoreService.getFilePathByUUIDPathType(uuid, pathType);
} }
@Deprecated
private List<Map<String, Object>> getUnzipDatasetFilesTo86(String unzipRootPath, String subDir) {
// String root = Paths.get(unzipRootPath)
// .resolve(subDir)
// .toString();
//
String root = normalizeLinuxPath(unzipRootPath + "/" + subDir);
Map<String, Map<String, Object>> grouped = new HashMap<>();
for (String dirName : LABEL_DIRS) {
String remoteDir = root + "/" + dirName;
// 1. 86 서버에서 해당 디렉토리의 파일 목록 조회
List<String> files = listFilesOn86Server(remoteDir);
if (files.isEmpty()) {
throw new IllegalStateException("폴더가 존재하지 않거나 파일이 없습니다 : " + remoteDir);
}
for (String fullPath : files) {
String fileName = Paths.get(fullPath).getFileName().toString();
String baseName = getBaseName(fileName);
Map<String, Object> data = grouped.computeIfAbsent(baseName, k -> new HashMap<>());
data.put("baseName", baseName);
if ("label-json".equals(dirName)) {
// 2. json 내용도 86 서버에서 읽어서 가져와야 함
String json = readRemoteFileAsString(fullPath);
data.put("label-json", parseJson(json));
data.put("geojson_path", fullPath);
} else {
data.put(dirName, fullPath);
}
}
}
return new ArrayList<>(grouped.values());
}
private List<String> listFilesOn86Server(String remoteDir) {
String command = "find " + escape(remoteDir) + " -maxdepth 1 -type f";
return FIleChecker.execCommandAndReadLines(command);
}
private String readRemoteFileAsString(String remoteFilePath) { private String readRemoteFileAsString(String remoteFilePath) {
String command = "cat " + escape(remoteFilePath); String command = "cat " + escape(remoteFilePath);
@@ -496,6 +417,7 @@ public class DatasetService {
} }
private String escape(String path) { private String escape(String path) {
// 쉘 커맨드에서 안전하게 사용할 수 있도록 문자열을 작은따옴표로 감싸면서, 내부의 작은따옴표를 이스케이프 처리
return "'" + path.replace("'", "'\"'\"'") + "'"; return "'" + path.replace("'", "'\"'\"'") + "'";
} }
@@ -528,4 +450,78 @@ public class DatasetService {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
/** unzipRootDir: 압축 해제된 폴더 경로 (ex: /data/xxx/myzipname) */
public static void validateTrainValTestDirs(String unzipRootDir) {
Path root = Paths.get(unzipRootDir);
// 루트 폴더 자체 존재 확인
if (!Files.exists(root) || !Files.isDirectory(root)) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
"압축 해제 폴더가 존재하지 않습니다: " + unzipRootDir);
}
// 필요한 폴더들 존재/디렉토리 여부 확인
List<String> missing =
REQUIRED_DIRS.stream()
.filter(d -> !Files.isDirectory(root.resolve(d)))
.collect(Collectors.toList());
if (!missing.isEmpty()) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
"데이터 폴더 구조가 올바르지 않습니다. 누락된 폴더: "
+ String.join(", ", missing)
+ " (필수: train, val, test)");
}
}
public static void validateDirFileCount(String unzipRootDir) {
Path root = Paths.get(unzipRootDir);
for (String split : REQUIRED_DIRS) {
Path splitPath = root.resolve(split);
Map<String, Long> fileCountMap = new HashMap<>();
for (String subDir : CHECK_DIRS) { // input1, input2, label 폴더만 수행하기
Path subDirPath = splitPath.resolve(subDir);
if (!Files.isDirectory(subDirPath)) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
split + " 폴더 하위에 " + subDir + " 폴더가 존재하지 않습니다.");
}
long count;
try (Stream<Path> files = Files.list(subDirPath)) {
count = files.filter(Files::isRegularFile).count();
log.info("dir: " + subDirPath + ", count: " + count);
} catch (IOException e) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
split + "/" + subDir + " 파일 개수 확인 중 오류 발생");
}
fileCountMap.put(subDir, count);
}
// 모든 폴더 파일 개수가 동일한지 확인
Set<Long> uniqueCounts = new HashSet<>(fileCountMap.values());
if (uniqueCounts.size() != 1) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
split + " 데이터 파일 개수가 일치하지 않습니다. " + fileCountMap.toString());
}
}
}
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.hyperparam; package com.kamco.cd.training.hyperparam;
import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List; import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
@@ -65,7 +66,7 @@ public class HyperParamApiController {
mediaType = "application/json", mediaType = "application/json",
schema = @Schema(implementation = String.class))), schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content), @ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "422", description = "HPs_0001 수정 불가", content = @Content), @ApiResponse(responseCode = "422", description = "default는 삭제불가", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@PutMapping("/{uuid}") @PutMapping("/{uuid}")
@@ -96,10 +97,13 @@ public class HyperParamApiController {
String type, String type,
@Parameter(description = "시작일", example = "2026-02-01") @RequestParam(required = false) @Parameter(description = "시작일", example = "2026-02-01") @RequestParam(required = false)
LocalDate startDate, LocalDate startDate,
@Parameter(description = "종료일", example = "2026-02-28") @RequestParam(required = false) @Parameter(description = "종료일", example = "2026-03-31") @RequestParam(required = false)
LocalDate endDate, LocalDate endDate,
@Parameter(description = "버전명", example = "HPs_0001") @RequestParam(required = false) @Parameter(description = "버전명", example = "G1_000019") @RequestParam(required = false)
String hyperVer, String hyperVer,
@Parameter(description = "모델 타입 (G1, G2, G3 중 하나)", example = "G1")
@RequestParam(required = false)
ModelType model,
@Parameter( @Parameter(
description = "정렬", description = "정렬",
example = "createdDttm desc", example = "createdDttm desc",
@@ -124,7 +128,7 @@ public class HyperParamApiController {
searchReq.setSort(sort); searchReq.setSort(sort);
searchReq.setPage(page); searchReq.setPage(page);
searchReq.setSize(size); searchReq.setSize(size);
Page<List> list = hyperParamService.getHyperParamList(searchReq); Page<List> list = hyperParamService.getHyperParamList(model, searchReq);
return ApiResponseDto.ok(list); return ApiResponseDto.ok(list);
} }
@@ -133,12 +137,12 @@ public class HyperParamApiController {
@ApiResponses( @ApiResponses(
value = { value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content), @ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "422", description = "HPs_0001 삭제 불가", content = @Content), @ApiResponse(responseCode = "422", description = "default 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
}) })
@DeleteMapping("/{uuid}") @DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteHyperParam( public ApiResponseDto<Void> deleteHyperParam(
@Parameter(description = "하이퍼파라미터 uuid", example = "c3b5a285-8f68-42af-84f0-e6d09162deb5") @Parameter(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
hyperParamService.deleteHyperParam(uuid); hyperParamService.deleteHyperParam(uuid);
@@ -160,7 +164,7 @@ public class HyperParamApiController {
}) })
@GetMapping("/{uuid}") @GetMapping("/{uuid}")
public ApiResponseDto<HyperParamDto.Basic> getHyperParam( public ApiResponseDto<HyperParamDto.Basic> getHyperParam(
@Parameter(description = "하이퍼파라미터 uuid", example = "c3b5a285-8f68-42af-84f0-e6d09162deb5") @Parameter(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
return ApiResponseDto.ok(hyperParamService.getHyperParam(uuid)); return ApiResponseDto.ok(hyperParamService.getHyperParam(uuid));
@@ -179,8 +183,9 @@ public class HyperParamApiController {
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@GetMapping("/init") @GetMapping("/init/{model}")
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam() { public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(@PathVariable ModelType model) {
return ApiResponseDto.ok(hyperParamService.getInitHyperParam());
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
} }
} }

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.hyperparam.dto; package com.kamco.cd.training.hyperparam.dto;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.utils.enums.CodeExpose; import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType; import com.kamco.cd.training.common.utils.enums.EnumType;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
@@ -24,9 +25,12 @@ public class HyperParamDto {
@AllArgsConstructor @AllArgsConstructor
public static class Basic { public static class Basic {
private ModelType model; // 20250212 modeltype추가
private UUID uuid; private UUID uuid;
private String hyperVer; private String hyperVer;
@JsonFormatDttm private ZonedDateTime createdDttm; @JsonFormatDttm private ZonedDateTime createdDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Integer totalUseCnt;
// ------------------------- // -------------------------
// Important // Important
@@ -98,6 +102,8 @@ public class HyperParamDto {
private Integer gpuCnt; private Integer gpuCnt;
private String gpuIds; private String gpuIds;
private Integer masterPort; private Integer masterPort;
private Boolean isDefault;
} }
@Getter @Getter
@@ -106,13 +112,12 @@ public class HyperParamDto {
@AllArgsConstructor @AllArgsConstructor
public static class List { public static class List {
private UUID uuid; private UUID uuid;
private ModelType model;
private String hyperVer; private String hyperVer;
@JsonFormatDttm private ZonedDateTime createDttm; @JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm; @JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Long m1UseCnt; private String memo;
private Long m2UseCnt; private Integer totalUseCnt;
private Long m3UseCnt;
private Long totalCnt;
} }
@Getter @Getter

View File

@@ -1,8 +1,10 @@
package com.kamco.cd.training.hyperparam.service; package com.kamco.cd.training.hyperparam.service;
import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List; import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService; import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -20,11 +22,12 @@ public class HyperParamService {
/** /**
* 하이퍼 파라미터 목록 조회 * 하이퍼 파라미터 목록 조회
* *
* @param model
* @param req * @param req
* @return 목록 * @return 목록
*/ */
public Page<List> getHyperParamList(HyperParamDto.SearchReq req) { public Page<List> getHyperParamList(ModelType model, SearchReq req) {
return hyperParamCoreService.findByHyperVerList(req); return hyperParamCoreService.findByHyperVerList(model, req);
} }
/** /**
@@ -59,8 +62,8 @@ public class HyperParamService {
} }
/** 하이퍼파라미터 최적화 설정값 조회 */ /** 하이퍼파라미터 최적화 설정값 조회 */
public HyperParamDto.Basic getInitHyperParam() { public HyperParamDto.Basic getInitHyperParam(ModelType model) {
return hyperParamCoreService.getInitHyperParam(); return hyperParamCoreService.getInitHyperParam(model);
} }
/** /**

View File

@@ -0,0 +1,99 @@
package com.kamco.cd.training.log;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.log.dto.AuditLogDto;
import com.kamco.cd.training.log.service.AuditLogService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.LocalDate;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@Tag(name = "감사 로그", description = "감사 로그 관리 API")
@RequiredArgsConstructor
@RestController
@RequestMapping("/api/logs/audit")
public class AuditLogApiController {
private final AuditLogService auditLogService;
@Operation(summary = "일자별 로그 조회")
@GetMapping("/daily")
public ApiResponseDto<Page<AuditLogDto.DailyAuditList>> getDailyLogs(
@RequestParam(required = false) LocalDate startDate,
@RequestParam(required = false) LocalDate endDate,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.DailyAuditList> result =
auditLogService.getLogByDaily(searchReq, startDate, endDate);
return ApiResponseDto.ok(result);
}
@Operation(summary = "일자별 로그 상세")
@GetMapping("/daily/result")
public ApiResponseDto<Page<AuditLogDto.DailyDetail>> getDailyResultLogs(
@RequestParam LocalDate logDate,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.DailyDetail> result = auditLogService.getLogByDailyResult(searchReq, logDate);
return ApiResponseDto.ok(result);
}
@Operation(summary = "메뉴별 로그 조회")
@GetMapping("/menu")
public ApiResponseDto<Page<AuditLogDto.MenuAuditList>> getMenuLogs(
@RequestParam(required = false) String searchValue,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.MenuAuditList> result = auditLogService.getLogByMenu(searchReq, searchValue);
return ApiResponseDto.ok(result);
}
@Operation(summary = "메뉴별 로그 상세")
@GetMapping("/menu/result")
public ApiResponseDto<Page<AuditLogDto.MenuDetail>> getMenuResultLogs(
@RequestParam String menuId,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.MenuDetail> result = auditLogService.getLogByMenuResult(searchReq, menuId);
return ApiResponseDto.ok(result);
}
@Operation(summary = "사용자별 로그 조회")
@GetMapping("/account")
public ApiResponseDto<Page<AuditLogDto.UserAuditList>> getAccountLogs(
@RequestParam(required = false) String searchValue,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.UserAuditList> result =
auditLogService.getLogByAccount(searchReq, searchValue);
return ApiResponseDto.ok(result);
}
@Operation(summary = "사용자별 로그 상세")
@GetMapping("/account/result")
public ApiResponseDto<Page<AuditLogDto.UserDetail>> getAccountResultLogs(
@RequestParam Long userUid,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.UserDetail> result = auditLogService.getLogByAccountResult(searchReq, userUid);
return ApiResponseDto.ok(result);
}
}

View File

@@ -0,0 +1,40 @@
package com.kamco.cd.training.log;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.log.dto.ErrorLogDto;
import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.log.service.ErrorLogService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.LocalDate;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@Tag(name = "에러 로그", description = "에러 로그 관리 API")
@RequiredArgsConstructor
@RestController
@RequestMapping({"/api/logs/system"})
public class ErrorLogApiController {
private final ErrorLogService errorLogService;
@Operation(summary = "에러로그 조회")
@GetMapping("/error")
public ApiResponseDto<Page<ErrorLogDto.Basic>> getErrorLogs(
@RequestParam(required = false) ErrorLogDto.LogErrorLevel logErrorLevel,
@RequestParam(required = false) EventType eventType,
@RequestParam(required = false) LocalDate startDate,
@RequestParam(required = false) LocalDate endDate,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
ErrorLogDto.ErrorSearchReq searchReq =
new ErrorLogDto.ErrorSearchReq(
logErrorLevel, eventType, startDate, endDate, page, size, "created_dttm,desc");
Page<ErrorLogDto.Basic> result = errorLogService.findLogByError(searchReq);
return ApiResponseDto.ok(result);
}
}

View File

@@ -3,7 +3,9 @@ package com.kamco.cd.training.log.dto;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDate;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.UUID;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
@@ -58,6 +60,7 @@ public class AuditLogDto {
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public static class AuditCommon { public static class AuditCommon {
private int readCount; private int readCount;
private int cudCount; private int cudCount;
private int printCount; private int printCount;
@@ -68,6 +71,7 @@ public class AuditLogDto {
@Schema(name = "DailyAuditList", description = "일자별 목록") @Schema(name = "DailyAuditList", description = "일자별 목록")
@Getter @Getter
public static class DailyAuditList extends AuditCommon { public static class DailyAuditList extends AuditCommon {
private final String baseDate; private final String baseDate;
public DailyAuditList( public DailyAuditList(
@@ -85,6 +89,7 @@ public class AuditLogDto {
@Schema(name = "MenuAuditList", description = "메뉴별 목록") @Schema(name = "MenuAuditList", description = "메뉴별 목록")
@Getter @Getter
public static class MenuAuditList extends AuditCommon { public static class MenuAuditList extends AuditCommon {
private final String menuId; private final String menuId;
private final String menuName; private final String menuName;
@@ -105,6 +110,7 @@ public class AuditLogDto {
@Schema(name = "UserAuditList", description = "사용자별 목록") @Schema(name = "UserAuditList", description = "사용자별 목록")
@Getter @Getter
public static class UserAuditList extends AuditCommon { public static class UserAuditList extends AuditCommon {
private final Long accountId; private final Long accountId;
private final String loginId; private final String loginId;
private final String username; private final String username;
@@ -129,6 +135,7 @@ public class AuditLogDto {
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public static class AuditDetail { public static class AuditDetail {
private Long logId; private Long logId;
private EventType eventType; private EventType eventType;
private LogDetail detail; private LogDetail detail;
@@ -137,9 +144,11 @@ public class AuditLogDto {
@Schema(name = "DailyDetail", description = "일자별 로그 상세") @Schema(name = "DailyDetail", description = "일자별 로그 상세")
@Getter @Getter
public static class DailyDetail extends AuditDetail { public static class DailyDetail extends AuditDetail {
private final String userName; private final String userName;
private final String loginId; private final String loginId;
private final String menuName; private final String menuName;
private final String logDateTime;
public DailyDetail( public DailyDetail(
Long logId, Long logId,
@@ -147,17 +156,20 @@ public class AuditLogDto {
String loginId, String loginId,
String menuName, String menuName,
EventType eventType, EventType eventType,
String logDateTime,
LogDetail detail) { LogDetail detail) {
super(logId, eventType, detail); super(logId, eventType, detail);
this.userName = userName; this.userName = userName;
this.loginId = loginId; this.loginId = loginId;
this.menuName = menuName; this.menuName = menuName;
this.logDateTime = logDateTime;
} }
} }
@Schema(name = "MenuDetail", description = "메뉴별 로그 상세") @Schema(name = "MenuDetail", description = "메뉴별 로그 상세")
@Getter @Getter
public static class MenuDetail extends AuditDetail { public static class MenuDetail extends AuditDetail {
private final String logDateTime; private final String logDateTime;
private final String userName; private final String userName;
private final String loginId; private final String loginId;
@@ -179,6 +191,7 @@ public class AuditLogDto {
@Schema(name = "UserDetail", description = "사용자별 로그 상세") @Schema(name = "UserDetail", description = "사용자별 로그 상세")
@Getter @Getter
public static class UserDetail extends AuditDetail { public static class UserDetail extends AuditDetail {
private final String logDateTime; private final String logDateTime;
private final String menuNm; private final String menuNm;
@@ -194,6 +207,7 @@ public class AuditLogDto {
@Setter @Setter
@AllArgsConstructor @AllArgsConstructor
public static class LogDetail { public static class LogDetail {
String serviceName; String serviceName;
String parentMenuName; String parentMenuName;
String menuName; String menuName;
@@ -226,4 +240,26 @@ public class AuditLogDto {
return PageRequest.of(page, size); return PageRequest.of(page, size);
} }
} }
@Getter
@Setter
public static class DownloadReq {
UUID uuid;
LocalDate startDate;
LocalDate endDate;
String searchValue;
String menuId;
String requestUri;
}
@Getter
@Setter
@AllArgsConstructor
public static class DownloadRes {
String name;
String employeeNo;
@JsonFormatDttm ZonedDateTime downloadDttm;
}
} }

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.log.dto; package com.kamco.cd.training.log.dto;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType; import com.kamco.cd.training.common.utils.enums.EnumType;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDate; import java.time.LocalDate;
@@ -77,6 +78,7 @@ public class ErrorLogDto {
} }
} }
@CodeExpose
public enum LogErrorLevel implements EnumType { public enum LogErrorLevel implements EnumType {
WARNING("Warning"), WARNING("Warning"),
ERROR("Error"), ERROR("Error"),

View File

@@ -1,22 +1,35 @@
package com.kamco.cd.training.log.dto; package com.kamco.cd.training.log.dto;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType; import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
@CodeExpose
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public enum EventType implements EnumType { public enum EventType implements EnumType {
CREATE("생성"), LIST("목록"),
READ("조회"), DETAIL("상세"),
UPDATE("수정"), POPUP("팝업"),
DELETE("삭제"), STATUS("상태"),
ADDED("추가"),
MODIFIED("수정"),
REMOVE("삭제"),
DOWNLOAD("다운로드"), DOWNLOAD("다운로드"),
PRINT("출력"), LOGIN("로그인"),
OTHER("기타"); OTHER("기타");
private final String desc; private final String desc;
public static EventType fromName(String name) {
try {
return EventType.valueOf(name.toUpperCase());
} catch (Exception e) {
return OTHER;
}
}
@Override @Override
public String getId() { public String getId() {
return name(); return name();

View File

@@ -1,26 +1,39 @@
package com.kamco.cd.training.model; package com.kamco.cd.training.model;
import com.kamco.cd.training.common.download.RangeDownloadResponder;
import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.model.service.ModelTrainDetailService; import com.kamco.cd.training.model.service.ModelTrainDetailService;
import com.kamco.cd.training.model.service.ModelTrainMngService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.enums.ParameterIn;
import io.swagger.v3.oas.annotations.media.ArraySchema; import io.swagger.v3.oas.annotations.media.ArraySchema;
import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses; import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.apache.coyote.BadRequestException;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
@@ -32,6 +45,11 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping("/api/models") @RequestMapping("/api/models")
public class ModelTrainDetailApiController { public class ModelTrainDetailApiController {
private final ModelTrainDetailService modelTrainDetailService; private final ModelTrainDetailService modelTrainDetailService;
private final ModelTrainMngService modelTrainMngService;
private final RangeDownloadResponder rangeDownloadResponder;
@Value("${train.docker.responseDir}")
private String responseDir;
@Operation(summary = "모델학습관리> 모델관리 > 상세정보탭 > 학습 진행정보", description = "학습 진행정보, 모델학습 정보 API") @Operation(summary = "모델학습관리> 모델관리 > 상세정보탭 > 학습 진행정보", description = "학습 진행정보, 모델학습 정보 API")
@ApiResponses( @ApiResponses(
@@ -116,26 +134,26 @@ public class ModelTrainDetailApiController {
return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid)); return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid));
} }
@Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API") // @Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
@ApiResponses( // @ApiResponses(
value = { // value = {
@ApiResponse( // @ApiResponse(
responseCode = "200", // responseCode = "200",
description = "조회 성공", // description = "조회 성공",
content = // content =
@Content( // @Content(
mediaType = "application/json", // mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))), // schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content), // @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) // @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) // })
@GetMapping("/transfer/detail/{uuid}") // @GetMapping("/transfer/detail/{uuid}")
public ApiResponseDto<TransferDetailDto> getTransferDetail( // public ApiResponseDto<TransferDetailDto> getTransferDetail(
@Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e") // @Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@PathVariable // @PathVariable
UUID uuid) { // UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid)); // return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
} // }
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API") @Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API")
@ApiResponses( @ApiResponses(
@@ -222,4 +240,90 @@ public class ModelTrainDetailApiController {
UUID uuid) { UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainBestEpoch(uuid)); return ApiResponseDto.ok(modelTrainDetailService.getModelTrainBestEpoch(uuid));
} }
@Operation(
summary = "학습데이터 파일 다운로드",
description = "학습데이터 파일 다운로드",
parameters = {
@Parameter(
name = "kamco-download-uuid",
in = ParameterIn.HEADER,
required = true,
description = "다운로드 요청 UUID",
schema =
@Schema(
type = "string",
format = "uuid",
example = "6d8d49dc-0c9d-4124-adc7-b9ca610cc394"))
})
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "학습데이터 zip파일 다운로드",
content =
@Content(
mediaType = "application/octet-stream",
schema = @Schema(type = "string", format = "binary"))),
@ApiResponse(responseCode = "404", description = "파일 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/download/{uuid}")
public ResponseEntity<?> download(@PathVariable UUID uuid, HttpServletRequest request)
throws IOException {
Basic info = modelTrainDetailService.findByModelByUUID(uuid);
Path zipPath =
Paths.get(responseDir)
.resolve(String.valueOf(info.getUuid()))
.resolve(info.getModelVer() + ".zip");
if (!Files.isRegularFile(zipPath)) {
throw new BadRequestException();
}
return rangeDownloadResponder.buildZipResponse(zipPath, info.getModelVer() + ".zip", request);
}
@Operation(summary = "모델관리 > 모델 상세 > 파일 정보", description = "모델 상세 > 파일 정보 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/file-info/{uuid}")
public ApiResponseDto<ModelFileInfo> getModelTrainFileInfo(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainFileInfo(uuid));
}
@Operation(summary = "모델관리 > 모델별 진행 상황", description = "모델관리 > 모델별 진행 상황 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelProgressStepDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/progress/{uuid}")
public ApiResponseDto<List<ModelProgressStepDto>> findModelTrainProgressInfo(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.findModelTrainProgressInfo(uuid));
}
} }

View File

@@ -6,8 +6,10 @@ import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.model.service.ModelTrainMngService; import com.kamco.cd.training.model.service.ModelTrainMngService;
import com.kamco.cd.training.train.service.ModelTestMetricsJobService;
import com.kamco.cd.training.train.service.ModelTrainMetricsJobService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Content;
@@ -16,6 +18,7 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses; import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -35,6 +38,8 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping("/api/models") @RequestMapping("/api/models")
public class ModelTrainMngApiController { public class ModelTrainMngApiController {
private final ModelTrainMngService modelTrainMngService; private final ModelTrainMngService modelTrainMngService;
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
private final ModelTestMetricsJobService modelTestMetricsJobService;
@Operation(summary = "모델학습 목록 조회", description = "모델학습 목록 조회 API") @Operation(summary = "모델학습 목록 조회", description = "모델학습 목록 조회 API")
@ApiResponses( @ApiResponses(
@@ -50,7 +55,7 @@ public class ModelTrainMngApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@GetMapping("/list") @GetMapping("/list")
public ApiResponseDto<Page<Basic>> findByModelList( public ApiResponseDto<Page<ListDto>> findByModelList(
@Parameter( @Parameter(
description = "상태코드", description = "상태코드",
example = "IN_PROGRESS", example = "IN_PROGRESS",
@@ -74,7 +79,7 @@ public class ModelTrainMngApiController {
@ApiResponses( @ApiResponses(
value = { value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content), @ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "409", description = "HPs_0001 삭제 불가", content = @Content) @ApiResponse(responseCode = "409", description = "G1_000001 삭제 불가", content = @Content)
}) })
@DeleteMapping("/{uuid}") @DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteModelTrain( public ApiResponseDto<Void> deleteModelTrain(
@@ -150,7 +155,9 @@ public class ModelTrainMngApiController {
return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req)); return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req));
} }
@Operation(summary = "모델학습 1단계 실행중인 것이 있는지 count", description = "모델학습 1단계 실행중인 것이 있는지 count") @Operation(
summary = "모델학습 1단계/2단계 실행중인 것이 있는지 count",
description = "모델학습 1단계/2단계 실행중인 것이 있는지 count")
@ApiResponses( @ApiResponses(
value = { value = {
@ApiResponse( @ApiResponse(
@@ -167,4 +174,44 @@ public class ModelTrainMngApiController {
public ApiResponseDto<Long> findModelStep1InProgressCnt() { public ApiResponseDto<Long> findModelStep1InProgressCnt() {
return ApiResponseDto.ok(modelTrainMngService.findModelStep1InProgressCnt()); return ApiResponseDto.ok(modelTrainMngService.findModelStep1InProgressCnt());
} }
@Operation(
summary = "스케줄러 findTrainValidMetricCsvFiles",
description = "스케줄러 findTrainValidMetricCsvFiles")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Long.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/schedule-trainvalid")
public ApiResponseDto<Long> findTrainValidMetricCsvFiles() {
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
return ApiResponseDto.ok(null);
}
@Operation(summary = "스케줄러 findTestMetricCsvFiles", description = "스케줄러 findTestMetricCsvFiles")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Long.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/schedule-test")
public ApiResponseDto<Long> findTestValidMetricCsvFiles() throws IOException {
modelTestMetricsJobService.findTestValidMetricCsvFiles();
return ApiResponseDto.ok(null);
}
} }

View File

@@ -20,4 +20,25 @@ public class ModelConfigDto {
private Float testPercent; private Float testPercent;
private String memo; private String memo;
} }
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class TransferBasic {
private Long configId;
private Long modelId;
private Integer epochCount;
private Float trainPercent;
private Float validationPercent;
private Float testPercent;
private String memo;
private Long beforeConfigId;
private Long beforeModelId;
private Integer beforeEpochCount;
private Float beforeTrainPercent;
private Float beforeValidationPercent;
private Float beforeTestPercent;
private String beforeMemo;
}
} }

View File

@@ -6,7 +6,7 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.enums.TrainType; import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.utils.enums.Enums; import com.kamco.cd.training.common.utils.enums.Enums;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm; import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.time.Duration; import java.time.Duration;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
@@ -35,6 +35,7 @@ public class ModelTrainDetailDto {
@JsonFormatDttm private ZonedDateTime step2EndDttm; @JsonFormatDttm private ZonedDateTime step2EndDttm;
private String statusCd; private String statusCd;
private String trainType; private String trainType;
private UUID beforeUuid;
public String getStatusName() { public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null; if (this.statusCd == null || this.statusCd.isBlank()) return null;
@@ -176,9 +177,10 @@ public class ModelTrainDetailDto {
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public static class TransferDetailDto { public static class TransferDetailDto {
private ModelConfigDto.Basic etcConfig; private ModelConfigDto.TransferBasic etcConfig;
private TransferHyperSummary modelTrainHyper; private TransferHyperSummary modelTrainHyper;
private List<SelectDataSet> modelTrainDataset; private List<SelectTransferDataSet> modelTrainDataset;
// private List<SelectDataSet> beforeTrainDataset;
} }
@Getter @Getter
@@ -245,4 +247,13 @@ public class ModelTrainDetailDto {
private Float iou; private Float iou;
private Float accuracy; private Float accuracy;
} }
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelFileInfo {
private Boolean fileExistsYn;
private String fileName;
}
} }

View File

@@ -41,6 +41,13 @@ public class ModelTrainMngDto {
private String trainType; private String trainType;
private String modelNo; private String modelNo;
private Long currentAttemptId; private Long currentAttemptId;
private String requestPath;
private String packingState;
private ZonedDateTime packingStrtDttm;
private ZonedDateTime packingEndDttm;
private Long beforeModelId;
public String getStatusName() { public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null; if (this.statusCd == null || this.statusCd.isBlank()) return null;
@@ -99,6 +106,10 @@ public class ModelTrainMngDto {
public String getStep2Duration() { public String getStep2Duration() {
return formatDuration(this.step2StrtDttm, this.step2EndDttm); return formatDuration(this.step2StrtDttm, this.step2EndDttm);
} }
public String getPackingDuration() {
return formatDuration(this.packingStrtDttm, this.packingEndDttm);
}
} }
@Schema(name = "searchReq", description = "모델학습 관리 목록조회 파라미터") @Schema(name = "searchReq", description = "모델학습 관리 목록조회 파라미터")
@@ -209,4 +220,111 @@ public class ModelTrainMngDto {
@Schema(description = "메모", example = "메모 입니다.") @Schema(description = "메모", example = "메모 입니다.")
private String memo; private String memo;
} }
@Schema(name = "모델학습관리 목록", description = "모델학습관리 목록")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class ListDto {
private Long id;
private UUID uuid;
private String modelVer;
@JsonFormatDttm private ZonedDateTime startDttm;
@JsonFormatDttm private ZonedDateTime step1StrtDttm;
@JsonFormatDttm private ZonedDateTime step1EndDttm;
@JsonFormatDttm private ZonedDateTime step2StrtDttm;
@JsonFormatDttm private ZonedDateTime step2EndDttm;
private String step1Status;
private String step2Status;
private String statusCd;
private String trainType;
private String modelNo;
private Long currentAttemptId;
private String requestPath;
private String packingState;
private ZonedDateTime packingStrtDttm;
private ZonedDateTime packingEndDttm;
private String memo;
private String userNm;
private UUID beforeUuid;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.statusCd).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.statusCd; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getStep1StatusName() {
if (this.step1Status == null || this.step1Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step1Status).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.step1Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getStep2StatusName() {
if (this.step2Status == null || this.step2Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.step2Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getTrainTypeName() {
if (this.trainType == null || this.trainType.isBlank()) return null;
try {
return TrainType.valueOf(this.trainType).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.trainType; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
private String formatDuration(ZonedDateTime start, ZonedDateTime end) {
if (start == null || end == null) {
return null;
}
long totalSeconds = Math.abs(Duration.between(start, end).getSeconds());
long hours = totalSeconds / 3600;
long minutes = (totalSeconds % 3600) / 60;
long seconds = totalSeconds % 60;
return String.format("%d시간 %d분 %d초", hours, minutes, seconds);
}
public String getStep1Duration() {
return formatDuration(this.step1StrtDttm, this.step1EndDttm);
}
public String getStep2Duration() {
return formatDuration(this.step2StrtDttm, this.step2EndDttm);
}
public String getPackingDuration() {
return formatDuration(this.packingStrtDttm, this.packingEndDttm);
}
}
@Getter
@Builder
@AllArgsConstructor
public static class ModelProgressStepDto {
private int step;
private String status;
@JsonFormatDttm private ZonedDateTime startTime;
@JsonFormatDttm private ZonedDateTime endTime;
private boolean isError;
}
} }

View File

@@ -1,18 +1,21 @@
package com.kamco.cd.training.model.service; package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService; import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.util.ArrayList; import java.util.ArrayList;
@@ -70,11 +73,11 @@ public class ModelTrainDetailService {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid); Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
// config 정보 조회 // config 정보 조회
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid); ModelConfigDto.TransferBasic configInfo = mngCoreService.findModelTransferConfigByModelId(uuid);
// 하이파라미터 정보 조회 // 하이파라미터 정보 조회
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid); TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>(); List<SelectTransferDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq(); DatasetReq datasetReq = new DatasetReq();
List<Long> datasetIds = new ArrayList<>(); List<Long> datasetIds = new ArrayList<>();
@@ -87,12 +90,37 @@ public class ModelTrainDetailService {
datasetReq.setIds(datasetIds); datasetReq.setIds(datasetIds);
datasetReq.setModelNo(modelInfo.getModelNo()); datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals("G1")) { if (modelInfo.getModelNo().equals(ModelType.G1.getId())) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq); dataSets = mngCoreService.getDatasetTransferSelectG1List(modelInfo.getId());
} else { } else {
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq); dataSets =
mngCoreService.getDatasetTransferSelectG2G3List(
modelInfo.getId(), modelInfo.getModelNo());
} }
// DatasetReq beforeDatasetReq = new DatasetReq();
// List<Long> beforeDatasetIds = new ArrayList<>();
// List<SelectDataSet> beforeDataSets = new ArrayList<>();
//
// Long beforeModelId = modelInfo.getBeforeModelId();
// if (beforeModelId != null) {
// Basic beforeInfo = modelTrainDetailCoreService.findByModelBeforeId(beforeModelId);
// List<MappingDataset> beforeDatasets =
// modelTrainDetailCoreService.getByModelMappingDataset(beforeInfo.getUuid());
//
// for (MappingDataset before : beforeDatasets) {
// beforeDatasetIds.add(before.getDatasetId());
// }
// beforeDatasetReq.setIds(beforeDatasetIds);
// beforeDatasetReq.setModelNo(modelInfo.getModelNo());
//
// if (beforeInfo.getModelNo().equals(ModelType.G1.getId())) {
// beforeDataSets = mngCoreService.getDatasetSelectG1List(beforeDatasetReq);
// } else {
// beforeDataSets = mngCoreService.getDatasetSelectG2G3List(beforeDatasetReq);
// }
// }
TransferDetailDto transferDetailDto = new TransferDetailDto(); TransferDetailDto transferDetailDto = new TransferDetailDto();
transferDetailDto.setEtcConfig(configInfo); transferDetailDto.setEtcConfig(configInfo);
transferDetailDto.setModelTrainHyper(hyperSummary); transferDetailDto.setModelTrainHyper(hyperSummary);
@@ -116,4 +144,12 @@ public class ModelTrainDetailService {
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) { public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
return modelTrainDetailCoreService.getModelTrainBestEpoch(uuid); return modelTrainDetailCoreService.getModelTrainBestEpoch(uuid);
} }
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
return modelTrainDetailCoreService.getModelTrainFileInfo(uuid);
}
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
return modelTrainDetailCoreService.findModelTrainProgressInfo(uuid);
}
} }

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType; import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.enums.TrainType; import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
@@ -12,7 +13,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq; import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService; import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.io.IOException; import com.kamco.cd.training.train.service.TrainJobService;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -30,7 +31,7 @@ public class ModelTrainMngService {
private final ModelTrainMngCoreService modelTrainMngCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService;
private final HyperParamCoreService hyperParamCoreService; private final HyperParamCoreService hyperParamCoreService;
private final TmpDatasetService tmpDatasetService; private final TrainJobService trainJobService;
/** /**
* 모델학습 조회 * 모델학습 조회
@@ -38,7 +39,7 @@ public class ModelTrainMngService {
* @param searchReq 검색 조건 * @param searchReq 검색 조건
* @return 페이징 처리된 모델 목록 * @return 페이징 처리된 모델 목록
*/ */
public Page<ModelTrainMngDto.Basic> getModelList(SearchReq searchReq) { public Page<ModelTrainMngDto.ListDto> getModelList(SearchReq searchReq) {
return modelTrainMngCoreService.findByModelList(searchReq); return modelTrainMngCoreService.findByModelList(searchReq);
} }
@@ -93,22 +94,8 @@ public class ModelTrainMngService {
// 모델 config 저장 // 모델 config 저장
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig()); modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
UUID tmpUuid = UUID.randomUUID(); // 데이터셋 임시파일 생성
String raw = tmpUuid.toString().toUpperCase().replace("-", ""); trainJobService.createTmpFile(modelUuid);
List<String> uids =
modelTrainMngCoreService.findDatasetUid(req.getTrainingDataset().getDatasetList());
try {
// 데이터셋 심볼링크 생성
String tmpUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(tmpUid);
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
throw new RuntimeException(e);
}
return modelUuid; return modelUuid;
} }
@@ -129,7 +116,7 @@ public class ModelTrainMngService {
* @return * @return
*/ */
public List<SelectDataSet> getDatasetSelectList(DatasetReq req) { public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (req.getModelNo().equals("G1")) { if (req.getModelNo().equals(ModelType.G1.getId())) {
return modelTrainMngCoreService.getDatasetSelectG1List(req); return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else { } else {
return modelTrainMngCoreService.getDatasetSelectG2G3List(req); return modelTrainMngCoreService.getDatasetSelectG2G3List(req);

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.service.BaseCoreService; import com.kamco.cd.training.common.service.BaseCoreService;
import com.kamco.cd.training.log.dto.AuditLogDto; import com.kamco.cd.training.log.dto.AuditLogDto;
import com.kamco.cd.training.log.dto.AuditLogDto.DownloadReq;
import com.kamco.cd.training.postgres.repository.log.AuditLogRepository; import com.kamco.cd.training.postgres.repository.log.AuditLogRepository;
import java.time.LocalDate; import java.time.LocalDate;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -45,6 +46,11 @@ public class AuditLogCoreService
return auditLogRepository.findLogByAccount(searchRange, searchValue); return auditLogRepository.findLogByAccount(searchRange, searchValue);
} }
public Page<AuditLogDto.DownloadRes> findLogByAccount(
AuditLogDto.searchReq searchReq, DownloadReq downloadReq) {
return auditLogRepository.findDownloadLog(searchReq, downloadReq);
}
public Page<AuditLogDto.DailyDetail> getLogByDailyResult( public Page<AuditLogDto.DailyDetail> getLogByDailyResult(
AuditLogDto.searchReq searchRange, LocalDate logDate) { AuditLogDto.searchReq searchRange, LocalDate logDate) {
return auditLogRepository.findLogByDailyResult(searchRange, logDate); return auditLogRepository.findLogByDailyResult(searchRange, logDate);

View File

@@ -1,10 +1,12 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.dto.HyperParam; import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.exception.CustomApiException; import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil; import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.Basic; import com.kamco.cd.training.hyperparam.dto.HyperParamDto.Basic;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository; import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
@@ -17,6 +19,7 @@ import org.springframework.stereotype.Service;
@Service @Service
@RequiredArgsConstructor @RequiredArgsConstructor
public class HyperParamCoreService { public class HyperParamCoreService {
private final HyperParamRepository hyperParamRepository; private final HyperParamRepository hyperParamRepository;
private final UserUtil userUtil; private final UserUtil userUtil;
@@ -27,11 +30,10 @@ public class HyperParamCoreService {
* @return 등록된 버전명 * @return 등록된 버전명
*/ */
public Basic createHyperParam(HyperParam createReq) { public Basic createHyperParam(HyperParam createReq) {
String firstVersion = getFirstHyperParamVersion(); String firstVersion = getFirstHyperParamVersion(createReq.getModel());
ModelHyperParamEntity entity = new ModelHyperParamEntity(); ModelHyperParamEntity entity = new ModelHyperParamEntity();
entity.setHyperVer(firstVersion); entity.setHyperVer(firstVersion);
applyHyperParam(entity, createReq); applyHyperParam(entity, createReq);
// user // user
@@ -57,7 +59,7 @@ public class HyperParamCoreService {
.findHyperParamByUuid(uuid) .findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getHyperVer().equals("HPs_0001")) { if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY); throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
} }
applyHyperParam(entity, createReq); applyHyperParam(entity, createReq);
@@ -69,11 +71,112 @@ public class HyperParamCoreService {
return entity.getHyperVer(); return entity.getHyperVer();
} }
/**
* 하이퍼파라미터 삭제
*
* @param uuid
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// if (entity.getHyperVer().equals("HPs_0001")) {
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
// }
// 디폴트면 삭제불가
if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
}
entity.setDelYn(true);
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
}
/**
* 하이퍼파라미터 최적화 설정값 조회
*
* @return
*/
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
ModelHyperParamEntity entity =
hyperParamRepository.getHyperParamByType(model).stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst()
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 상세 조회
*
* @return
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 목록 조회
*
* @param model
* @param req
* @return
*/
public Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req) {
return hyperParamRepository.findByHyperVerList(model, req);
}
/**
* 하이퍼파라미터 버전 조회
*
* @param model 모델 타입
* @return ver
*/
public String getFirstHyperParamVersion(ModelType model) {
return hyperParamRepository
.findHyperParamVerByModelType(model)
.map(ModelHyperParamEntity::getHyperVer)
.map(ver -> increase(ver, model))
.orElse(model.name() + "_000001");
}
/**
* 하이퍼 파라미터의 버전을 증가시킨다.
*
* @param hyperVer 현재 버전
* @param modelType 모델 타입
* @return 증가된 버전
*/
private String increase(String hyperVer, ModelType modelType) {
String prefix = modelType.name() + "_";
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
return prefix + String.format("%06d", num + 1);
}
private void applyHyperParam(ModelHyperParamEntity entity, HyperParam src) { private void applyHyperParam(ModelHyperParamEntity entity, HyperParam src) {
ModelType model = src.getModel();
// 하드코딩 모델별로 다른경우 250212 bbn 하드코딩
if (model == ModelType.G3) {
entity.setCropSize("512,512");
} else {
entity.setCropSize("256,256");
}
entity.setCropSize(src.getCropSize());
// Important // Important
entity.setModelType(model); // 20250212 modeltype추가
entity.setBackbone(src.getBackbone()); entity.setBackbone(src.getBackbone());
entity.setInputSize(src.getInputSize()); entity.setInputSize(src.getInputSize());
entity.setCropSize(src.getCropSize());
entity.setBatchSize(src.getBatchSize()); entity.setBatchSize(src.getBatchSize());
// Data // Data
@@ -110,79 +213,4 @@ public class HyperParamCoreService {
// memo // memo
entity.setMemo(src.getMemo()); entity.setMemo(src.getMemo());
} }
/**
* 하이퍼파라미터 삭제
*
* @param uuid
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getHyperVer().equals("HPs_0001")) {
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
}
entity.setDelYn(true);
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
}
/**
* 하이퍼파라미터 최적화 설정값 조회
*
* @return
*/
public HyperParamDto.Basic getInitHyperParam() {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByHyperVer("HPs_0001")
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 상세 조회
*
* @return
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 목록 조회
*
* @param req
* @return
*/
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
return hyperParamRepository.findByHyperVerList(req);
}
/**
* 하이퍼파라미터 버전 조회
*
* @return ver
*/
public String getFirstHyperParamVersion() {
return hyperParamRepository
.findHyperParamVer()
.map(ModelHyperParamEntity::getHyperVer)
.map(this::increase)
.orElse("HPs_0001");
}
private String increase(String hyperVer) {
String prefix = "HPs_";
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
return prefix + String.format("%04d", num + 1);
}
} }

View File

@@ -1,7 +1,10 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.time.ZonedDateTime;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -26,4 +29,22 @@ public class ModelTestMetricsJobCoreService {
public void insertModelMetricsTest(List<Object[]> batchArgs) { public void insertModelMetricsTest(List<Object[]> batchArgs) {
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs); modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
} }
public ModelMetricJsonDto getTestMetricPackingInfo(Long modelId) {
return modelTestMetricsJobRepository.getTestMetricPackingInfo(modelId);
}
public ModelTestFileName findModelTestFileNames(Long modelId) {
return modelTestMetricsJobRepository.findModelTestFileNames(modelId);
}
@Transactional
public void updatePackingStart(Long modelId, ZonedDateTime now) {
modelTestMetricsJobRepository.updatePackingStart(modelId, now);
}
@Transactional
public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) {
modelTestMetricsJobRepository.updatePackingEnd(modelId, now, failSuccState);
}
} }

View File

@@ -8,11 +8,13 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository; import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDetailRepository; import com.kamco.cd.training.postgres.repository.model.ModelDetailRepository;
@@ -55,6 +57,12 @@ public class ModelTrainDetailCoreService {
return modelDetailRepository.getModelDetailSummary(uuid); return modelDetailRepository.getModelDetailSummary(uuid);
} }
/**
* 하이퍼 파리미터 요약정보
*
* @param uuid 모델마스터 uuid
* @return
*/
public HyperSummary getByModelHyperParamSummary(UUID uuid) { public HyperSummary getByModelHyperParamSummary(UUID uuid) {
return modelDetailRepository.getByModelHyperParamSummary(uuid); return modelDetailRepository.getByModelHyperParamSummary(uuid);
} }
@@ -97,4 +105,17 @@ public class ModelTrainDetailCoreService {
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) { public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
return modelDetailRepository.getModelTrainBestEpoch(uuid); return modelDetailRepository.getModelTrainBestEpoch(uuid);
} }
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
return modelDetailRepository.getModelTrainFileInfo(uuid);
}
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
return modelDetailRepository.findModelTrainProgressInfo(uuid);
}
public Basic findByModelBeforeId(Long beforeModelId) {
ModelMasterEntity entity = modelDetailRepository.findByModelBeforeId(beforeModelId);
return entity.toDto();
}
} }

View File

@@ -1,15 +1,22 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto; import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@Log4j2
@Service @Service
@RequiredArgsConstructor @RequiredArgsConstructor
@Transactional(readOnly = true) @Transactional(readOnly = true)
@@ -42,17 +49,23 @@ public class ModelTrainJobCoreService {
job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now()); job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now());
modelTrainJobRepository.save(job); modelTrainJobRepository.save(job);
modelTrainJobRepository.flush();
return job.getId(); return job.getId();
} }
/** 실행 시작 처리 */ /** 실행 시작 처리 */
@Transactional @Transactional
public void markRunning( public void markRunning(
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) { Long jobId,
String containerName,
String logPath,
String lockedBy,
Integer totalEpoch,
String jobType) {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("RUNNING"); job.setStatusCd("RUNNING");
job.setContainerName(containerName); job.setContainerName(containerName);
@@ -60,37 +73,73 @@ public class ModelTrainJobCoreService {
job.setStartedDttm(ZonedDateTime.now()); job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now()); job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy); job.setLockedBy(lockedBy);
job.setJobType(jobType);
if (totalEpoch != null) { if (totalEpoch != null) {
job.setTotalEpoch(totalEpoch); job.setTotalEpoch(totalEpoch);
} }
} }
/** 성공 처리 */ /**
* 성공 처리
*
* @param jobId
* @param exitCode
*/
@Transactional @Transactional
public void markSuccess(Long jobId, int exitCode) { public void markSuccess(Long jobId, int exitCode) {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("SUCCESS"); job.setStatusCd("SUCCESS");
job.setExitCode(exitCode); job.setExitCode(exitCode);
job.setFinishedDttm(ZonedDateTime.now()); job.setFinishedDttm(ZonedDateTime.now());
} }
/** 실패 처리 */ /**
* 실패 처리
*
* @param jobId
* @param exitCode
* @param errorMessage
*/
@Transactional @Transactional
public void markFailed(Long jobId, Integer exitCode, String errorMessage) { public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("FAILED"); job.setStatusCd("FAILED");
job.setExitCode(exitCode); job.setExitCode(exitCode);
job.setErrorMessage(errorMessage); job.setErrorMessage(errorMessage);
job.setFinishedDttm(ZonedDateTime.now()); job.setFinishedDttm(ZonedDateTime.now());
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
}
/**
* 중단됨 처리
*
* @param jobId
* @param exitCode
* @param errorMessage
*/
@Transactional
public void markPaused(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("STOPPED");
job.setExitCode(exitCode);
job.setErrorMessage(errorMessage);
job.setFinishedDttm(ZonedDateTime.now());
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
} }
/** 취소 처리 */ /** 취소 처리 */
@@ -99,9 +148,40 @@ public class ModelTrainJobCoreService {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("STOPPED"); job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now()); job.setFinishedDttm(ZonedDateTime.now());
} }
@Transactional
public void updateEpoch(String containerName, Integer epoch) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findByContainerName(containerName)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setCurrentEpoch(epoch);
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
}
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch);
}
/**
* 실행중인 학습이 있는지 조회
*
* @return
*/
public List<ModelTrainJobDto> findRunningJobs() {
List<ModelTrainJobEntity> entity = modelTrainJobRepository.findRunningJobs();
if (entity == null || entity.isEmpty()) {
return Collections.emptyList();
}
return entity.stream().map(ModelTrainJobEntity::toDto).toList();
}
} }

View File

@@ -29,4 +29,9 @@ public class ModelTrainMetricsJobCoreService {
public void insertModelMetricsValidation(List<Object[]> batchArgs) { public void insertModelMetricsValidation(List<Object[]> batchArgs) {
modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs); modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs);
} }
@Transactional
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
modelTrainMetricsJobRepository.updateModelSelectedBestEpoch(modelId, epoch);
}
} }

View File

@@ -8,9 +8,10 @@ import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil; import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto; import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic; import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.TrainingDataset; import com.kamco.cd.training.model.dto.ModelTrainMngDto.TrainingDataset;
import com.kamco.cd.training.postgres.entity.ModelConfigEntity; import com.kamco.cd.training.postgres.entity.ModelConfigEntity;
import com.kamco.cd.training.postgres.entity.ModelDatasetEntity; import com.kamco.cd.training.postgres.entity.ModelDatasetEntity;
@@ -23,6 +24,7 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository; import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository; import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository; import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.ArrayList; import java.util.ArrayList;
@@ -53,9 +55,10 @@ public class ModelTrainMngCoreService {
* @param searchReq 검색 조건 * @param searchReq 검색 조건
* @return 페이징 처리된 모델 목록 * @return 페이징 처리된 모델 목록
*/ */
public Page<Basic> findByModelList(ModelTrainMngDto.SearchReq searchReq) { public Page<ListDto> findByModelList(ModelTrainMngDto.SearchReq searchReq) {
Page<ModelMasterEntity> entityPage = modelMngRepository.findByModels(searchReq); // Page<ModelMasterEntity> entityPage = modelMngRepository.findByModels(searchReq);
return entityPage.map(ModelMasterEntity::toDto); // return entityPage.map(ModelMasterEntity::toDto);
return modelMngRepository.findByModels(searchReq);
} }
/** /**
@@ -83,9 +86,15 @@ public class ModelTrainMngCoreService {
ModelMasterEntity entity = new ModelMasterEntity(); ModelMasterEntity entity = new ModelMasterEntity();
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity(); ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
// 최적화 파라미터는 HPs_0001 사용 // 최적화 파라미터는 모델 type의 디폴트사용
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) { if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null); ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity =
hyperParamRepository.getHyperParamByType(modelType).stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst()
.orElse(null);
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
} else { } else {
hyperParamEntity = hyperParamEntity =
@@ -95,6 +104,12 @@ public class ModelTrainMngCoreService {
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) { if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND); throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND);
} }
// 하이퍼 파라미터 사용 횟수 업데이트
hyperParamEntity.setTotalUseCnt(
hyperParamEntity.getTotalUseCnt() == null ? 1 : hyperParamEntity.getTotalUseCnt() + 1);
// 최근 사용일시 업데이트
hyperParamEntity.setLastUsedDttm(ZonedDateTime.now());
String modelVer = String modelVer =
String.join( String.join(
@@ -105,16 +120,8 @@ public class ModelTrainMngCoreService {
entity.setTrainType(addReq.getTrainType()); // 일반, 전이 entity.setTrainType(addReq.getTrainType()); // 일반, 전이
entity.setBeforeModelId(addReq.getBeforeModelId()); entity.setBeforeModelId(addReq.getBeforeModelId());
if (addReq.getIsStart()) { entity.setStatusCd(TrainStatusType.READY.getId());
entity.setModelStep((short) 1); entity.setStep1State(TrainStatusType.READY.getId());
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStrtDttm(ZonedDateTime.now());
entity.setStep1StrtDttm(ZonedDateTime.now());
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
} else {
entity.setStatusCd(TrainStatusType.READY.getId());
entity.setStep1State(TrainStatusType.READY.getId());
}
entity.setCreatedUid(userUtil.getId()); entity.setCreatedUid(userUtil.getId());
ModelMasterEntity resultEntity = modelMngRepository.save(entity); ModelMasterEntity resultEntity = modelMngRepository.save(entity);
@@ -166,13 +173,10 @@ public class ModelTrainMngCoreService {
modelMngRepository modelMngRepository
.findById(modelId) .findById(modelId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// 임시폴더 UID업데이트
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) { if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
entity.setRequestPath(req.getRequestPath()); entity.setRequestPath(req.getRequestPath());
} }
if (req.getResponsePath() != null && !req.getResponsePath().isEmpty()) {
entity.setRequestPath(req.getResponsePath());
}
} }
/** /**
@@ -203,7 +207,10 @@ public class ModelTrainMngCoreService {
ModelConfigEntity entity = new ModelConfigEntity(); ModelConfigEntity entity = new ModelConfigEntity();
modelMasterEntity.setId(modelId); modelMasterEntity.setId(modelId);
entity.setModel(modelMasterEntity); entity.setModel(modelMasterEntity);
entity.setEpochCount(req.getEpochCnt()); entity.setEpochCount(
req.getEpochCnt() < 10
? 10
: req.getEpochCnt()); // 에폭이 10 이하이면 10으로 고정하기. 10 이상 에폭으로 해야 best 에폭 파일이 생성되어 내려옴
entity.setTrainPercent(req.getTrainingCnt()); entity.setTrainPercent(req.getTrainingCnt());
entity.setValidationPercent(req.getValidationCnt()); entity.setValidationPercent(req.getValidationCnt());
entity.setTestPercent(req.getTestCnt()); entity.setTestPercent(req.getTestCnt());
@@ -271,6 +278,13 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} }
public ModelConfigDto.TransferBasic findModelTransferConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository
.findModelTransferConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
/** /**
* 데이터셋 G1 목록 * 데이터셋 G1 목록
* *
@@ -281,6 +295,16 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG1List(req); return datasetRepository.getDatasetSelectG1List(req);
} }
/**
* 전이학습 데이터셋 G1 목록
*
* @param modelId 모델 Id
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
return datasetRepository.getDatasetTransferSelectG1List(modelId);
}
/** /**
* 데이터셋 G2, G3 목록 * 데이터셋 G2, G3 목록
* *
@@ -291,6 +315,18 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG2G3List(req); return datasetRepository.getDatasetSelectG2G3List(req);
} }
/**
* 전이학습 데이터셋 G2, G3 목록
*
* @param modelId 모델 Id
* @param modelNo G2, G3
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
return datasetRepository.getDatasetTransferSelectG2G3List(modelId, modelNo);
}
/** /**
* 모델관리 조회 * 모델관리 조회
* *
@@ -317,6 +353,7 @@ public class ModelTrainMngCoreService {
master.setCurrentAttemptId(jobId); master.setCurrentAttemptId(jobId);
// 필요하면 시작시간도 여기서 찍어줌 // 필요하면 시작시간도 여기서 찍어줌
modelMngRepository.flush();
} }
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */ /** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
@@ -328,6 +365,7 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setLastError(null); master.setLastError(null);
modelMngRepository.flush();
} }
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */ /** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
@@ -352,7 +390,12 @@ public class ModelTrainMngCoreService {
master.setStatusCd(TrainStatusType.COMPLETED.getId()); master.setStatusCd(TrainStatusType.COMPLETED.getId());
} }
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */ /**
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
@Transactional @Transactional
public void markError(Long modelId, String errorMessage) { public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =
@@ -367,7 +410,12 @@ public class ModelTrainMngCoreService {
master.setUpdatedDttm(ZonedDateTime.now()); master.setUpdatedDttm(ZonedDateTime.now());
} }
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */ /**
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
@Transactional @Transactional
public void markStep2Error(Long modelId, String errorMessage) { public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =
@@ -420,6 +468,7 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId)); .orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId()); entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStrtDttm(ZonedDateTime.now());
entity.setStep1StrtDttm(ZonedDateTime.now()); entity.setStep1StrtDttm(ZonedDateTime.now());
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId()); entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId); entity.setCurrentAttemptId(jobId);
@@ -517,4 +566,34 @@ public class ModelTrainMngCoreService {
public Long findModelStep1InProgressCnt() { public Long findModelStep1InProgressCnt() {
return modelMngRepository.findModelStep1InProgressCnt(); return modelMngRepository.findModelStep1InProgressCnt();
} }
/**
* train 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
return modelDatasetMapRepository.findDatasetTrainPath(modelId);
}
/**
* validation 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
return modelDatasetMapRepository.findDatasetValPath(modelId);
}
/**
* test 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
return modelDatasetMapRepository.findDatasetTestPath(modelId);
}
} }

View File

@@ -5,6 +5,7 @@ import com.kamco.cd.training.log.dto.EventStatus;
import com.kamco.cd.training.log.dto.EventType; import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.postgres.CommonCreateEntity; import com.kamco.cd.training.postgres.CommonCreateEntity;
import jakarta.persistence.*; import jakarta.persistence.*;
import java.util.UUID;
import lombok.AccessLevel; import lombok.AccessLevel;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
@@ -14,6 +15,7 @@ import lombok.NoArgsConstructor;
@NoArgsConstructor(access = AccessLevel.PROTECTED) @NoArgsConstructor(access = AccessLevel.PROTECTED)
@Table(name = "tb_audit_log") @Table(name = "tb_audit_log")
public class AuditLogEntity extends CommonCreateEntity { public class AuditLogEntity extends CommonCreateEntity {
@Id @Id
@GeneratedValue(strategy = GenerationType.IDENTITY) @GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "audit_log_uid", nullable = false) @Column(name = "audit_log_uid", nullable = false)
@@ -43,6 +45,12 @@ public class AuditLogEntity extends CommonCreateEntity {
@Column(name = "error_log_uid") @Column(name = "error_log_uid")
private Long errorLogUid; private Long errorLogUid;
@Column(name = "download_uuid")
private UUID downloadUuid;
@Column(name = "login_attempt_id")
private String loginAttemptId;
public AuditLogEntity( public AuditLogEntity(
Long userUid, Long userUid,
EventType eventType, EventType eventType,
@@ -51,7 +59,9 @@ public class AuditLogEntity extends CommonCreateEntity {
String ipAddress, String ipAddress,
String requestUri, String requestUri,
String requestBody, String requestBody,
Long errorLogUid) { Long errorLogUid,
UUID downloadUuid,
String loginAttemptId) {
this.userUid = userUid; this.userUid = userUid;
this.eventType = eventType; this.eventType = eventType;
this.eventStatus = eventStatus; this.eventStatus = eventStatus;
@@ -60,6 +70,31 @@ public class AuditLogEntity extends CommonCreateEntity {
this.requestUri = requestUri; this.requestUri = requestUri;
this.requestBody = requestBody; this.requestBody = requestBody;
this.errorLogUid = errorLogUid; this.errorLogUid = errorLogUid;
this.downloadUuid = downloadUuid;
this.loginAttemptId = loginAttemptId;
}
/** 파일 다운로드 이력 생성 */
public static AuditLogEntity forFileDownload(
Long userId,
String requestUri,
String menuUid,
String ip,
int httpStatus,
UUID downloadUuid) {
return new AuditLogEntity(
userId,
EventType.DOWNLOAD, // 이벤트 타입 고정
httpStatus < 400 ? EventStatus.SUCCESS : EventStatus.FAILED, // 성공 여부
menuUid,
ip,
requestUri,
null, // requestBody 없음
null, // errorLogUid 없음
downloadUuid,
null // loginAttemptId 없음
);
} }
public AuditLogDto.Basic toDto() { public AuditLogDto.Basic toDto() {

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.entity; package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import jakarta.persistence.*; import jakarta.persistence.*;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
@@ -191,10 +192,10 @@ public class ModelHyperParamEntity {
@Column(name = "save_best_rule", nullable = false, length = 10) @Column(name = "save_best_rule", nullable = false, length = 10)
private String saveBestRule = "greater"; private String saveBestRule = "greater";
/** Default: 10 */ /** Default: 1 */
@NotNull @NotNull
@Column(name = "val_interval", nullable = false) @Column(name = "val_interval", nullable = false)
private Integer valInterval = 10; private Integer valInterval = 1;
/** Default: 400 */ /** Default: 400 */
@NotNull @NotNull
@@ -302,20 +303,24 @@ public class ModelHyperParamEntity {
@Column(name = "last_used_dttm") @Column(name = "last_used_dttm")
private ZonedDateTime lastUsedDttm; private ZonedDateTime lastUsedDttm;
@Column(name = "m1_use_cnt") @Column(name = "model_type")
private Long m1UseCnt = 0L; @Enumerated(EnumType.STRING)
private ModelType modelType;
@Column(name = "m2_use_cnt") @Column(name = "default_param")
private Long m2UseCnt = 0L; private Boolean isDefault = false;
@Column(name = "m3_use_cnt") @Column(name = "total_use_cnt")
private Long m3UseCnt = 0L; private Integer totalUseCnt = 0;
public HyperParamDto.Basic toDto() { public HyperParamDto.Basic toDto() {
return new HyperParamDto.Basic( return new HyperParamDto.Basic(
this.modelType,
this.uuid, this.uuid,
this.hyperVer, this.hyperVer,
this.createdDttm, this.createdDttm,
this.lastUsedDttm,
this.totalUseCnt,
// ------------------------- // -------------------------
// Important // Important
// ------------------------- // -------------------------
@@ -385,6 +390,7 @@ public class ModelHyperParamEntity {
// ------------------------- // -------------------------
this.gpuCnt, this.gpuCnt,
this.gpuIds, this.gpuIds,
this.masterPort); this.masterPort,
this.isDefault);
} }
} }

View File

@@ -112,6 +112,15 @@ public class ModelMasterEntity {
@Column(name = "response_path") @Column(name = "response_path")
private String responsePath; private String responsePath;
@Column(name = "packing_state")
private String packingState;
@Column(name = "packing_strt_dttm")
private ZonedDateTime packingStrtDttm;
@Column(name = "packing_end_dttm")
private ZonedDateTime packingEndDttm;
public ModelTrainMngDto.Basic toDto() { public ModelTrainMngDto.Basic toDto() {
return new ModelTrainMngDto.Basic( return new ModelTrainMngDto.Basic(
this.id, this.id,
@@ -127,6 +136,11 @@ public class ModelMasterEntity {
this.statusCd, this.statusCd,
this.trainType, this.trainType,
this.modelNo, this.modelNo,
this.currentAttemptId); this.currentAttemptId,
this.requestPath,
this.packingState,
this.packingStrtDttm,
this.packingEndDttm,
this.beforeModelId);
} }
} }

View File

@@ -0,0 +1,42 @@
package com.kamco.cd.training.postgres.entity;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import java.time.OffsetDateTime;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_test_training_run")
public class ModelTestTrainingRunEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "tsr_id", nullable = false)
private Long id;
@NotNull
@Column(name = "model_id", nullable = false)
private Long modelId;
@Column(name = "attempt_no")
private Integer attemptNo;
@Column(name = "job_id")
private Long jobId;
@Column(name = "epoch")
private Integer epoch;
@ColumnDefault("now()")
@Column(name = "created_dttm")
private OffsetDateTime createdDttm;
}

View File

@@ -57,8 +57,7 @@ public class ModelTrainJobEntity {
@Column(name = "exit_code") @Column(name = "exit_code")
private Integer exitCode; private Integer exitCode;
@Size(max = 2000) @Column(name = "error_message", columnDefinition = "TEXT")
@Column(name = "error_message", length = 2000)
private String errorMessage; private String errorMessage;
@ColumnDefault("now()") @ColumnDefault("now()")
@@ -84,6 +83,9 @@ public class ModelTrainJobEntity {
@Column(name = "current_epoch") @Column(name = "current_epoch")
private Integer currentEpoch; private Integer currentEpoch;
@Column(name = "job_type")
private String jobType;
public ModelTrainJobDto toDto() { public ModelTrainJobDto toDto() {
return new ModelTrainJobDto( return new ModelTrainJobDto(
this.id, this.id,

View File

@@ -4,6 +4,7 @@ import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity; import com.kamco.cd.training.postgres.entity.DatasetEntity;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@@ -17,6 +18,10 @@ public interface DatasetRepositoryCustom {
List<SelectDataSet> getDatasetSelectG1List(DatasetReq req); List<SelectDataSet> getDatasetSelectG1List(DatasetReq req);
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId);
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(Long modelId, String modelNo);
List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req); List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req);
Long getDatasetMaxStage(int compareYyyy, int targetYyyy); Long getDatasetMaxStage(int compareYyyy, int targetYyyy);

View File

@@ -1,14 +1,20 @@
package com.kamco.cd.training.postgres.repository.dataset; package com.kamco.cd.training.postgres.repository.dataset;
import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity; import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq; import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet; import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity; import com.kamco.cd.training.postgres.entity.DatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetEntity; import com.kamco.cd.training.postgres.entity.QDatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.BooleanBuilder; import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.CaseBuilder; import com.querydsl.core.types.dsl.CaseBuilder;
@@ -67,7 +73,11 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
// Count 쿼리 별도 실행 (null safe handling) // Count 쿼리 별도 실행 (null safe handling)
long total = long total =
Optional.ofNullable( Optional.ofNullable(
queryFactory.select(dataset.count()).from(dataset).where(builder).fetchOne()) queryFactory
.select(dataset.count())
.from(dataset)
.where(builder.and(dataset.deleted.isFalse()))
.fetchOne())
.orElse(0L); .orElse(0L);
return new PageImpl<>(content, pageable, total); return new PageImpl<>(content, pageable, total);
@@ -138,6 +148,103 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch(); .fetch();
} }
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum(),
// ===== before (join으로) =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum()))
.from(modelMasterEntity)
// ===== 현재 dataset join =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before 모델 join =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId))
.groupBy(
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override @Override
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) { public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
@@ -201,6 +308,116 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch(); .fetch();
} }
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
// before join용
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
BooleanBuilder builder = new BooleanBuilder();
NumberExpression<Long> wasteCnt =
datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> elseCnt =
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> selectedCnt = ModelType.G2.getId().equals(modelNo) ? wasteCnt : elseCnt;
// before도 동일 로직으로 cnt 계산
NumberExpression<Long> beforeWasteCnt =
beforeObj.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> beforeElseCnt =
new CaseBuilder()
.when(beforeObj.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> beforeSelectedCnt =
ModelType.G2.getId().equals(modelNo) ? beforeWasteCnt : beforeElseCnt;
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo, // modelNo 파라미터 사용 (req.getModelNo() 제거)
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
selectedCnt, // classCount 자리에 들어가는 cnt (Long)
// ===== before =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
beforeSelectedCnt))
.from(modelMasterEntity)
// ===== 현재 dataset =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before dataset =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId).and(builder))
// sum() 때문에 groupBy 필요
.groupBy(
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override @Override
public Long getDatasetMaxStage(int compareYyyy, int targetYyyy) { public Long getDatasetMaxStage(int compareYyyy, int targetYyyy) {
return queryFactory return queryFactory
@@ -239,7 +456,7 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
return queryFactory return queryFactory
.select(dataset.id) .select(dataset.id)
.from(dataset) .from(dataset)
.where(dataset.uid.eq(mngRegDto.getUid())) .where(dataset.uid.eq(mngRegDto.getUid()), dataset.deleted.isFalse())
.fetchOne(); .fetchOne();
} }
@@ -253,7 +470,7 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
return queryFactory return queryFactory
.select(dataset.id.count()) .select(dataset.id.count())
.from(dataset) .from(dataset)
.where(dataset.uid.eq(uid)) .where(dataset.uid.eq(uid), dataset.deleted.isFalse())
.fetchOne(); .fetchOne();
} }
} }

View File

@@ -1,7 +1,10 @@
package com.kamco.cd.training.postgres.repository.hyperparam; package com.kamco.cd.training.postgres.repository.hyperparam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
@@ -13,11 +16,41 @@ public interface HyperParamRepositoryCustom {
* *
* @return * @return
*/ */
@Deprecated
Optional<ModelHyperParamEntity> findHyperParamVer(); Optional<ModelHyperParamEntity> findHyperParamVer();
/**
* 모델 타입별 마지막 버전 조회
*
* @param modelType 모델 타입
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType);
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer); Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
/**
* 하이퍼 파라미터 상세조회
*
* @param uuid
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid); Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req); /**
* 하이퍼 파라미터 목록 조회
*
* @param model
* @param req
* @return
*/
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
/**
* 하이퍼 파라미터 모델타입으로 조회
*
* @param modelType
* @return
*/
List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType);
} }

View File

@@ -2,12 +2,13 @@ package com.kamco.cd.training.postgres.repository.hyperparam;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity; import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType; import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.querydsl.core.BooleanBuilder; import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.NumberExpression;
import com.querydsl.jpa.impl.JPAQuery; import com.querydsl.jpa.impl.JPAQuery;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.ZoneId; import java.time.ZoneId;
@@ -41,6 +42,23 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
.fetchOne()); .fetchOne());
} }
@Override
public Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType) {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(
modelHyperParamEntity
.delYn
.isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType)))
.orderBy(modelHyperParamEntity.hyperVer.desc())
.limit(1)
.fetchOne());
}
@Override @Override
public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) { public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) {
@@ -63,17 +81,22 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
queryFactory queryFactory
.select(modelHyperParamEntity) .select(modelHyperParamEntity)
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid))) .where(modelHyperParamEntity.uuid.eq(uuid))
.fetchOne()); .fetchOne());
} }
@Override @Override
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) { public Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req) {
Pageable pageable = req.toPageable(); Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
builder.and(modelHyperParamEntity.delYn.isFalse()); builder.and(modelHyperParamEntity.delYn.isFalse());
if (model != null) {
builder.and(modelHyperParamEntity.modelType.eq(model));
}
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) { if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
// 버전 // 버전
builder.and(modelHyperParamEntity.hyperVer.contains(req.getHyperVer())); builder.and(modelHyperParamEntity.hyperVer.contains(req.getHyperVer()));
@@ -96,26 +119,18 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
} }
} }
NumberExpression<Long> totalUseCnt =
modelHyperParamEntity
.m1UseCnt
.coalesce(0L)
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
JPAQuery<HyperParamDto.List> query = JPAQuery<HyperParamDto.List> query =
queryFactory queryFactory
.select( .select(
Projections.constructor( Projections.constructor(
HyperParamDto.List.class, HyperParamDto.List.class,
modelHyperParamEntity.uuid, modelHyperParamEntity.uuid,
modelHyperParamEntity.modelType.as("model"),
modelHyperParamEntity.hyperVer, modelHyperParamEntity.hyperVer,
modelHyperParamEntity.createdDttm, modelHyperParamEntity.createdDttm,
modelHyperParamEntity.lastUsedDttm, modelHyperParamEntity.lastUsedDttm,
modelHyperParamEntity.m1UseCnt, modelHyperParamEntity.memo,
modelHyperParamEntity.m2UseCnt, modelHyperParamEntity.totalUseCnt))
modelHyperParamEntity.m3UseCnt,
totalUseCnt.as("totalUseCnt")))
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where(builder); .where(builder);
@@ -140,8 +155,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
asc asc
? modelHyperParamEntity.lastUsedDttm.asc() ? modelHyperParamEntity.lastUsedDttm.asc()
: modelHyperParamEntity.lastUsedDttm.desc()); : modelHyperParamEntity.lastUsedDttm.desc());
case "totalUseCnt" ->
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc()); query.orderBy(
asc
? modelHyperParamEntity.totalUseCnt.asc()
: modelHyperParamEntity.totalUseCnt.desc());
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc()); default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
} }
@@ -161,4 +179,17 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
return new PageImpl<>(content, pageable, totalCount); return new PageImpl<>(content, pageable, totalCount);
} }
@Override
public List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType) {
return queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(
modelHyperParamEntity
.delYn
.isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType)))
.fetch();
}
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.repository.log; package com.kamco.cd.training.postgres.repository.log;
import com.kamco.cd.training.log.dto.AuditLogDto; import com.kamco.cd.training.log.dto.AuditLogDto;
import com.kamco.cd.training.log.dto.AuditLogDto.DownloadReq;
import java.time.LocalDate; import java.time.LocalDate;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
@@ -15,6 +16,9 @@ public interface AuditLogRepositoryCustom {
Page<AuditLogDto.UserAuditList> findLogByAccount( Page<AuditLogDto.UserAuditList> findLogByAccount(
AuditLogDto.searchReq searchReq, String searchValue); AuditLogDto.searchReq searchReq, String searchValue);
Page<AuditLogDto.DownloadRes> findDownloadLog(
AuditLogDto.searchReq searchReq, DownloadReq downloadReq);
Page<AuditLogDto.DailyDetail> findLogByDailyResult( Page<AuditLogDto.DailyDetail> findLogByDailyResult(
AuditLogDto.searchReq searchReq, LocalDate logDate); AuditLogDto.searchReq searchReq, LocalDate logDate);

View File

@@ -6,32 +6,42 @@ import static com.kamco.cd.training.postgres.entity.QMemberEntity.memberEntity;
import static com.kamco.cd.training.postgres.entity.QMenuEntity.menuEntity; import static com.kamco.cd.training.postgres.entity.QMenuEntity.menuEntity;
import com.kamco.cd.training.log.dto.AuditLogDto; import com.kamco.cd.training.log.dto.AuditLogDto;
import com.kamco.cd.training.log.dto.AuditLogDto.DownloadReq;
import com.kamco.cd.training.log.dto.AuditLogDto.searchReq;
import com.kamco.cd.training.log.dto.ErrorLogDto; import com.kamco.cd.training.log.dto.ErrorLogDto;
import com.kamco.cd.training.log.dto.EventStatus; import com.kamco.cd.training.log.dto.EventStatus;
import com.kamco.cd.training.log.dto.EventType; import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.postgres.entity.AuditLogEntity;
import com.kamco.cd.training.postgres.entity.QMenuEntity; import com.kamco.cd.training.postgres.entity.QMenuEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.*; import com.querydsl.core.types.dsl.*;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import io.micrometer.common.util.StringUtils; import io.micrometer.common.util.StringUtils;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.LocalDateTime; import java.time.ZoneId;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
@Repository @Repository
@RequiredArgsConstructor public class AuditLogRepositoryImpl extends QuerydslRepositorySupport
public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom { implements AuditLogRepositoryCustom {
private static final ZoneId ZONE = ZoneId.of("Asia/Seoul");
private final JPAQueryFactory queryFactory; private final JPAQueryFactory queryFactory;
private final StringExpression NULL_STRING = Expressions.stringTemplate("cast(null as text)"); private final StringExpression NULL_STRING = Expressions.stringTemplate("cast(null as text)");
public AuditLogRepositoryImpl(JPAQueryFactory queryFactory) {
super(AuditLogEntity.class);
this.queryFactory = queryFactory;
}
@Override @Override
public Page<AuditLogDto.DailyAuditList> findLogByDaily( public Page<AuditLogDto.DailyAuditList> findLogByDaily(
AuditLogDto.searchReq searchReq, LocalDate startDate, LocalDate endDate) { AuditLogDto.searchReq searchReq, LocalDate startDate, LocalDate endDate) {
@@ -87,7 +97,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
.from(auditLogEntity) .from(auditLogEntity)
.leftJoin(menuEntity) .leftJoin(menuEntity)
.on(auditLogEntity.menuUid.eq(menuEntity.menuUid)) .on(auditLogEntity.menuUid.eq(menuEntity.menuUid))
.where(menuNameEquals(searchValue)) .where(auditLogEntity.menuUid.ne("SYSTEM"), menuNameEquals(searchValue))
.groupBy(auditLogEntity.menuUid) .groupBy(auditLogEntity.menuUid)
.offset(pageable.getOffset()) .offset(pageable.getOffset())
.limit(pageable.getPageSize()) .limit(pageable.getPageSize())
@@ -128,7 +138,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
.from(auditLogEntity) .from(auditLogEntity)
.leftJoin(memberEntity) .leftJoin(memberEntity)
.on(auditLogEntity.userUid.eq(memberEntity.id)) .on(auditLogEntity.userUid.eq(memberEntity.id))
.where(loginIdOrUsernameContains(searchValue)) .where(auditLogEntity.userUid.isNotNull(), loginIdOrUsernameContains(searchValue))
.groupBy(auditLogEntity.userUid, memberEntity.employeeNo, memberEntity.name) .groupBy(auditLogEntity.userUid, memberEntity.employeeNo, memberEntity.name)
.offset(pageable.getOffset()) .offset(pageable.getOffset())
.limit(pageable.getPageSize()) .limit(pageable.getPageSize())
@@ -147,6 +157,62 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
return new PageImpl<>(foundContent, pageable, countQuery); return new PageImpl<>(foundContent, pageable, countQuery);
} }
@Override
public Page<AuditLogDto.DownloadRes> findDownloadLog(
AuditLogDto.searchReq searchReq, DownloadReq req) {
Pageable pageable = searchReq.toPageable();
BooleanBuilder whereBuilder = new BooleanBuilder();
whereBuilder.and(auditLogEntity.eventStatus.ne(EventStatus.valueOf("FAILED")));
whereBuilder.and(auditLogEntity.eventType.eq(EventType.valueOf("DOWNLOAD")));
// if (req.getMenuId() != null && !req.getMenuId().isEmpty()) {
// whereBuilder.and(auditLogEntity.menuUid.eq(req.getMenuId()));
// }
if (req.getUuid() != null) {
whereBuilder.and(auditLogEntity.requestUri.contains(req.getRequestUri()));
whereBuilder.and(auditLogEntity.downloadUuid.eq(req.getUuid()));
}
if (req.getSearchValue() != null && !req.getSearchValue().isEmpty()) {
whereBuilder.and(
memberEntity
.name
.contains(req.getSearchValue())
.or(memberEntity.employeeNo.contains(req.getSearchValue())));
}
List<AuditLogDto.DownloadRes> foundContent =
queryFactory
.select(
Projections.constructor(
AuditLogDto.DownloadRes.class,
memberEntity.name,
memberEntity.employeeNo,
auditLogEntity.createdDate.as("downloadDttm")))
.from(auditLogEntity)
.leftJoin(memberEntity)
.on(auditLogEntity.userUid.eq(memberEntity.id))
.where(whereBuilder, createdDateBetween(req.getStartDate(), req.getEndDate()))
.offset(pageable.getOffset())
.limit(pageable.getPageSize())
.orderBy(auditLogEntity.createdDate.desc())
.fetch();
Long countQuery =
queryFactory
.select(auditLogEntity.userUid.countDistinct())
.from(auditLogEntity)
.leftJoin(memberEntity)
.on(auditLogEntity.userUid.eq(memberEntity.id))
.where(whereBuilder, createdDateBetween(req.getStartDate(), req.getEndDate()))
.fetchOne();
return new PageImpl<>(foundContent, pageable, countQuery);
}
@Override @Override
public Page<AuditLogDto.DailyDetail> findLogByDailyResult( public Page<AuditLogDto.DailyDetail> findLogByDailyResult(
AuditLogDto.searchReq searchReq, LocalDate logDate) { AuditLogDto.searchReq searchReq, LocalDate logDate) {
@@ -176,6 +242,9 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
memberEntity.employeeNo.as("loginId"), memberEntity.employeeNo.as("loginId"),
menuEntity.menuNm.as("menuName"), menuEntity.menuNm.as("menuName"),
auditLogEntity.eventType.as("eventType"), auditLogEntity.eventType.as("eventType"),
Expressions.stringTemplate(
"to_char({0}, 'YYYY-MM-DD HH24:MI')", auditLogEntity.createdDate)
.as("logDateTime"),
Projections.constructor( Projections.constructor(
AuditLogDto.LogDetail.class, AuditLogDto.LogDetail.class,
Expressions.constant("한국자산관리공사"), // serviceName Expressions.constant("한국자산관리공사"), // serviceName
@@ -184,7 +253,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
menuEntity.menuUrl.as("menuUrl"), menuEntity.menuUrl.as("menuUrl"),
menuEntity.description.as("menuDescription"), menuEntity.description.as("menuDescription"),
menuEntity.menuOrder.as("sortOrder"), menuEntity.menuOrder.as("sortOrder"),
menuEntity.isUse.as("used")))) menuEntity.isUse.as("used")))) // TODO
.from(auditLogEntity) .from(auditLogEntity)
.leftJoin(menuEntity) .leftJoin(menuEntity)
.on(auditLogEntity.menuUid.eq(menuEntity.menuUid)) .on(auditLogEntity.menuUid.eq(menuEntity.menuUid))
@@ -238,8 +307,8 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
AuditLogDto.MenuDetail.class, AuditLogDto.MenuDetail.class,
auditLogEntity.id.as("logId"), auditLogEntity.id.as("logId"),
Expressions.stringTemplate( Expressions.stringTemplate(
"to_char({0}, 'YYYY-MM-DD')", auditLogEntity.createdDate) "to_char({0}, 'YYYY-MM-DD HH24:MI')", auditLogEntity.createdDate)
.as("logDateTime"), // ?? .as("logDateTime"),
memberEntity.name.as("userName"), memberEntity.name.as("userName"),
memberEntity.employeeNo.as("loginId"), memberEntity.employeeNo.as("loginId"),
auditLogEntity.eventType.as("eventType"), auditLogEntity.eventType.as("eventType"),
@@ -305,7 +374,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
AuditLogDto.UserDetail.class, AuditLogDto.UserDetail.class,
auditLogEntity.id.as("logId"), auditLogEntity.id.as("logId"),
Expressions.stringTemplate( Expressions.stringTemplate(
"to_char({0}, 'YYYY-MM-DD')", auditLogEntity.createdDate) "to_char({0}, 'YYYY-MM-DD HH24:MI')", auditLogEntity.createdDate)
.as("logDateTime"), .as("logDateTime"),
menuEntity.menuNm.as("menuName"), menuEntity.menuNm.as("menuName"),
auditLogEntity.eventType.as("eventType"), auditLogEntity.eventType.as("eventType"),
@@ -349,12 +418,23 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
if (Objects.isNull(startDate) || Objects.isNull(endDate)) { if (Objects.isNull(startDate) || Objects.isNull(endDate)) {
return null; return null;
} }
LocalDateTime startDateTime = startDate.atStartOfDay(); ZoneId zoneId = ZoneId.of("Asia/Seoul");
LocalDateTime endDateTime = endDate.plusDays(1).atStartOfDay(); ZonedDateTime startDateTime = startDate.atStartOfDay(zoneId);
ZonedDateTime endDateTime = endDate.plusDays(1).atStartOfDay(zoneId);
return auditLogEntity return auditLogEntity
.createdDate .createdDate
.goe(ZonedDateTime.from(startDateTime)) .goe(startDateTime)
.and(auditLogEntity.createdDate.lt(ZonedDateTime.from(endDateTime))); .and(auditLogEntity.createdDate.lt(endDateTime));
}
private BooleanExpression createdDateBetween(LocalDate startDate, LocalDate endDate) {
if (startDate == null || endDate == null) {
return null;
}
ZonedDateTime start = startDate.atStartOfDay(ZONE);
ZonedDateTime endExclusive = endDate.plusDays(1).atStartOfDay(ZONE);
return auditLogEntity.createdDate.goe(start).and(auditLogEntity.createdDate.lt(endExclusive));
} }
private BooleanExpression menuNameEquals(String searchValue) { private BooleanExpression menuNameEquals(String searchValue) {
@@ -393,11 +473,11 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
} }
private BooleanExpression eventEndedAtEqDate(LocalDate logDate) { private BooleanExpression eventEndedAtEqDate(LocalDate logDate) {
StringExpression eventEndedDate = ZoneId zoneId = ZoneId.of("Asia/Seoul");
Expressions.stringTemplate("to_char({0}, 'YYYY-MM-DD')", auditLogEntity.createdDate); ZonedDateTime start = logDate.atStartOfDay(zoneId);
LocalDateTime comparisonDate = logDate.atStartOfDay(); ZonedDateTime end = logDate.plusDays(1).atStartOfDay(zoneId);
return eventEndedDate.eq(comparisonDate.toString()); return auditLogEntity.createdDate.goe(start).and(auditLogEntity.createdDate.lt(end));
} }
private BooleanExpression menuUidEq(String menuUid) { private BooleanExpression menuUidEq(String menuUid) {
@@ -410,7 +490,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
private NumberExpression<Integer> readCount() { private NumberExpression<Integer> readCount() {
return new CaseBuilder() return new CaseBuilder()
.when(auditLogEntity.eventType.eq(EventType.READ)) .when(auditLogEntity.eventType.in(EventType.LIST, EventType.DETAIL))
.then(1) .then(1)
.otherwise(0) .otherwise(0)
.sum(); .sum();
@@ -418,7 +498,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
private NumberExpression<Integer> cudCount() { private NumberExpression<Integer> cudCount() {
return new CaseBuilder() return new CaseBuilder()
.when(auditLogEntity.eventType.in(EventType.CREATE, EventType.UPDATE, EventType.DELETE)) .when(auditLogEntity.eventType.in(EventType.ADDED, EventType.MODIFIED, EventType.REMOVE))
.then(1) .then(1)
.otherwise(0) .otherwise(0)
.sum(); .sum();
@@ -426,7 +506,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
private NumberExpression<Integer> printCount() { private NumberExpression<Integer> printCount() {
return new CaseBuilder() return new CaseBuilder()
.when(auditLogEntity.eventType.eq(EventType.PRINT)) .when(auditLogEntity.eventType.eq(EventType.OTHER))
.then(1) .then(1)
.otherwise(0) .otherwise(0)
.sum(); .sum();

View File

@@ -8,29 +8,35 @@ import static com.kamco.cd.training.postgres.entity.QMenuEntity.menuEntity;
import com.kamco.cd.training.log.dto.ErrorLogDto; import com.kamco.cd.training.log.dto.ErrorLogDto;
import com.kamco.cd.training.log.dto.EventStatus; import com.kamco.cd.training.log.dto.EventStatus;
import com.kamco.cd.training.log.dto.EventType; import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.postgres.entity.AuditLogEntity;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.BooleanExpression; import com.querydsl.core.types.dsl.BooleanExpression;
import com.querydsl.core.types.dsl.Expressions; import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.StringExpression; import com.querydsl.core.types.dsl.StringExpression;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.LocalDate; import java.time.LocalDate;
import java.time.LocalDateTime; import java.time.ZoneId;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
@Repository @Repository
@RequiredArgsConstructor public class ErrorLogRepositoryImpl extends QuerydslRepositorySupport
public class ErrorLogRepositoryImpl implements ErrorLogRepositoryCustom { implements ErrorLogRepositoryCustom {
private final JPAQueryFactory queryFactory; private final JPAQueryFactory queryFactory;
private final StringExpression NULL_STRING = Expressions.stringTemplate("cast(null as text)"); private final StringExpression NULL_STRING = Expressions.stringTemplate("cast(null as text)");
public ErrorLogRepositoryImpl(JPAQueryFactory queryFactory) {
super(AuditLogEntity.class);
this.queryFactory = queryFactory;
}
@Override @Override
public Page<ErrorLogDto.Basic> findLogByError(ErrorLogDto.ErrorSearchReq searchReq) { public Page<ErrorLogDto.Basic> findLogByError(ErrorLogDto.ErrorSearchReq searchReq) {
Pageable pageable = searchReq.toPageable(); Pageable pageable = searchReq.toPageable();
@@ -52,7 +58,7 @@ public class ErrorLogRepositoryImpl implements ErrorLogRepositoryCustom {
errorLogEntity.errorMessage.as("errorMessage"), errorLogEntity.errorMessage.as("errorMessage"),
errorLogEntity.stackTrace.as("errorDetail"), errorLogEntity.stackTrace.as("errorDetail"),
Expressions.stringTemplate( Expressions.stringTemplate(
"to_char({0}, 'YYYY-MM-DD')", errorLogEntity.createdDate))) "to_char({0}, 'YYYY-MM-DD HH24:MI:SS.FF3')", errorLogEntity.createdDate)))
.from(errorLogEntity) .from(errorLogEntity)
.leftJoin(auditLogEntity) .leftJoin(auditLogEntity)
.on(errorLogEntity.id.eq(auditLogEntity.errorLogUid)) .on(errorLogEntity.id.eq(auditLogEntity.errorLogUid))
@@ -94,12 +100,14 @@ public class ErrorLogRepositoryImpl implements ErrorLogRepositoryCustom {
if (Objects.isNull(startDate) || Objects.isNull(endDate)) { if (Objects.isNull(startDate) || Objects.isNull(endDate)) {
return null; return null;
} }
LocalDateTime startDateTime = startDate.atStartOfDay();
LocalDateTime endDateTime = endDate.plusDays(1).atStartOfDay(); ZoneId zoneId = ZoneId.of("Asia/Seoul");
ZonedDateTime startDateTime = startDate.atStartOfDay(zoneId);
ZonedDateTime endDateTime = endDate.plusDays(1).atStartOfDay(zoneId);
return auditLogEntity return auditLogEntity
.createdDate .createdDate
.goe(ZonedDateTime.from(startDateTime)) .goe(startDateTime)
.and(auditLogEntity.createdDate.lt(ZonedDateTime.from(endDateTime))); .and(auditLogEntity.createdDate.lt(endDateTime));
} }
private BooleanExpression eventStatusEqFailed() { private BooleanExpression eventStatusEqFailed() {

View File

@@ -5,4 +5,6 @@ import java.util.Optional;
public interface ModelConfigRepositoryCustom { public interface ModelConfigRepositoryCustom {
Optional<ModelConfigDto.Basic> findModelConfigByModelId(Long modelId); Optional<ModelConfigDto.Basic> findModelConfigByModelId(Long modelId);
Optional<ModelConfigDto.TransferBasic> findModelTransferConfigByModelId(Long modelId);
} }

View File

@@ -1,8 +1,12 @@
package com.kamco.cd.training.postgres.repository.model; package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity; import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.model.dto.ModelConfigDto.Basic; import com.kamco.cd.training.model.dto.ModelConfigDto.Basic;
import com.kamco.cd.training.model.dto.ModelConfigDto.TransferBasic;
import com.kamco.cd.training.postgres.entity.QModelConfigEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.Optional; import java.util.Optional;
@@ -34,4 +38,44 @@ public class ModelConfigRepositoryImpl implements ModelConfigRepositoryCustom {
.where(modelConfigEntity.model.id.eq(modelId)) .where(modelConfigEntity.model.id.eq(modelId))
.fetchOne()); .fetchOne());
} }
@Override
public Optional<TransferBasic> findModelTransferConfigByModelId(Long modelId) {
QModelConfigEntity beforeConfig = new QModelConfigEntity("beforeConfig");
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
return Optional.ofNullable(
queryFactory
.select(
Projections.constructor(
TransferBasic.class,
// ===== 현재 =====
modelConfigEntity.id,
modelConfigEntity.model.id,
modelConfigEntity.epochCount,
modelConfigEntity.trainPercent,
modelConfigEntity.validationPercent,
modelConfigEntity.testPercent,
modelConfigEntity.memo,
// ===== before =====
beforeConfig.id,
beforeConfig.model.id,
beforeConfig.epochCount,
beforeConfig.trainPercent,
beforeConfig.validationPercent,
beforeConfig.testPercent,
beforeConfig.memo))
.from(modelConfigEntity)
.innerJoin(modelConfigEntity.model, modelMasterEntity)
// before 모델 조인
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeConfig)
.on(beforeConfig.model.id.eq(beforeMaster.id))
.where(modelMasterEntity.id.eq(modelId))
.fetchOne());
}
} }

View File

@@ -1,8 +1,15 @@
package com.kamco.cd.training.postgres.repository.model; package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity; import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import java.util.List; import java.util.List;
public interface ModelDatasetMappRepositoryCustom { public interface ModelDatasetMappRepositoryCustom {
List<ModelDatasetMappEntity> findByModelUid(Long modelId); List<ModelDatasetMappEntity> findByModelUid(Long modelId);
List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId);
List<ModelTrainLinkDto> findDatasetValPath(Long modelId);
List<ModelTrainLinkDto> findDatasetTestPath(Long modelId);
} }

View File

@@ -1,8 +1,16 @@
package com.kamco.cd.training.postgres.repository.model; package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QDatasetEntity.datasetEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity; import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity; import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetTestObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetValObjEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -22,4 +30,136 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
.where(modelDatasetMappEntity.modelUid.eq(modelId)) .where(modelDatasetMappEntity.modelUid.eq(modelId))
.fetch(); .fetch();
} }
@Override
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
QDatasetObjEntity datasetObjEntity = QDatasetObjEntity.datasetObjEntity;
return queryFactory
.select(
Projections.constructor(
ModelTrainLinkDto.class,
modelMasterEntity.id,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelDatasetMappEntity.datasetUid,
datasetObjEntity.targetClassCd,
datasetObjEntity.comparePath,
datasetObjEntity.targetPath,
datasetObjEntity.labelPath,
datasetObjEntity.geojsonPath,
datasetEntity.uid))
.from(modelMasterEntity)
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(datasetEntity)
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
.leftJoin(datasetObjEntity)
.on(
datasetObjEntity
.datasetUid
.eq(modelDatasetMappEntity.datasetUid)
.and(
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId))
.fetch();
}
@Override
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
QDatasetValObjEntity datasetValObjEntity = QDatasetValObjEntity.datasetValObjEntity;
return queryFactory
.select(
Projections.constructor(
ModelTrainLinkDto.class,
modelMasterEntity.id,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelDatasetMappEntity.datasetUid,
datasetValObjEntity.targetClassCd,
datasetValObjEntity.comparePath,
datasetValObjEntity.targetPath,
datasetValObjEntity.labelPath,
datasetValObjEntity.geojsonPath,
datasetEntity.uid))
.from(modelMasterEntity)
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(datasetEntity)
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
.leftJoin(datasetValObjEntity)
.on(
datasetValObjEntity
.datasetUid
.eq(modelDatasetMappEntity.datasetUid)
.and(
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetValObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetValObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId))
.fetch();
}
@Override
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
QDatasetTestObjEntity datasetTestObjEntity = QDatasetTestObjEntity.datasetTestObjEntity;
return queryFactory
.select(
Projections.constructor(
ModelTrainLinkDto.class,
modelMasterEntity.id,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelDatasetMappEntity.datasetUid,
datasetTestObjEntity.targetClassCd,
datasetTestObjEntity.comparePath,
datasetTestObjEntity.targetPath,
datasetTestObjEntity.labelPath,
datasetTestObjEntity.geojsonPath,
datasetEntity.uid))
.from(modelMasterEntity)
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(datasetEntity)
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
.leftJoin(datasetTestObjEntity)
.on(
datasetTestObjEntity
.datasetUid
.eq(modelDatasetMappEntity.datasetUid)
.and(
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetTestObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetTestObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId))
.fetch();
}
} }

View File

@@ -4,10 +4,12 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@@ -34,4 +36,10 @@ public interface ModelDetailRepositoryCustom {
List<ModelTestMetrics> getModelTestMetricResult(UUID uuid); List<ModelTestMetrics> getModelTestMetricResult(UUID uuid);
ModelBestEpoch getModelTrainBestEpoch(UUID uuid); ModelBestEpoch getModelTrainBestEpoch(UUID uuid);
ModelFileInfo getModelTrainFileInfo(UUID uuid);
List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid);
ModelMasterEntity findByModelBeforeId(Long beforeModelId);
} }

View File

@@ -9,20 +9,25 @@ import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.mode
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity; import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity; import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary; import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity; import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity; import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.jpa.JPAExpressions; import com.querydsl.jpa.JPAExpressions;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@@ -55,6 +60,13 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
@Override @Override
public DetailSummary getModelDetailSummary(UUID uuid) { public DetailSummary getModelDetailSummary(UUID uuid) {
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
Expression<UUID> beforeModelUuid =
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
.from(beforeModel)
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
return queryFactory return queryFactory
.select( .select(
Projections.constructor( Projections.constructor(
@@ -66,7 +78,8 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
modelMasterEntity.step1StrtDttm, modelMasterEntity.step1StrtDttm,
modelMasterEntity.step2EndDttm, modelMasterEntity.step2EndDttm,
modelMasterEntity.statusCd, modelMasterEntity.statusCd,
modelMasterEntity.trainType)) modelMasterEntity.trainType,
beforeModelUuid))
.from(modelMasterEntity) .from(modelMasterEntity)
.where(modelMasterEntity.uuid.eq(uuid)) .where(modelMasterEntity.uuid.eq(uuid))
.fetchOne(); .fetchOne();
@@ -269,4 +282,94 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
modelMetricsTrainEntity.epoch.eq(modelMasterEntity.getBestEpoch())) modelMetricsTrainEntity.epoch.eq(modelMasterEntity.getBestEpoch()))
.fetchOne(); .fetchOne();
} }
@Override
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
return queryFactory
.select(
Projections.constructor(
ModelFileInfo.class,
modelMasterEntity
.packingState
.eq(TrainStatusType.COMPLETED.getId())
.coalesce(false),
modelMasterEntity.modelVer))
.from(modelMasterEntity)
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
}
@Override
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
ModelMasterEntity entity = findByModelByUUID(uuid);
if (entity == null) {
return List.of();
}
List<ModelProgressStepDto> steps = new ArrayList<>();
// 0단계 : 대기 상태
steps.add(
ModelProgressStepDto.builder()
.step(0)
.status(TrainStatusType.READY.getId())
.startTime(entity.getCreatedDttm())
.endTime(null)
.isError(false)
.build());
// 1단계 : Train/Validation 실행
boolean step1Active =
entity.getStep1StrtDttm() != null
&& !TrainStatusType.READY.getId().equals(entity.getStep1State());
if (step1Active) {
steps.add(
ModelProgressStepDto.builder()
.step(1)
.status(entity.getStep1State())
.startTime(entity.getStep1StrtDttm())
.endTime(entity.getStep1EndDttm())
.isError(TrainStatusType.ERROR.getId().equals(entity.getStep1State()))
.build());
}
// 2단계 : Test 실행
boolean step2Done = entity.getStep2State() != null;
if (step2Done) {
steps.add(
ModelProgressStepDto.builder()
.step(2)
.status(entity.getStep2State())
.startTime(entity.getStep2StrtDttm())
.endTime(entity.getStep2EndDttm())
.isError(TrainStatusType.ERROR.getId().equals(entity.getStep2State()))
.build());
}
// 3단계 : 패키징
boolean step3Done = entity.getPackingState() != null;
if (step3Done) {
steps.add(
ModelProgressStepDto.builder()
.step(3)
.status(entity.getPackingState())
.startTime(entity.getPackingStrtDttm())
.endTime(entity.getPackingEndDttm())
.isError(TrainStatusType.ERROR.getId().equals(entity.getPackingState()))
.build());
}
return steps;
}
@Override
public ModelMasterEntity findByModelBeforeId(Long beforeModelId) {
return queryFactory
.selectFrom(modelMasterEntity)
.where(modelMasterEntity.id.eq(beforeModelId))
.fetchOne();
}
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.repository.model; package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.util.Optional; import java.util.Optional;
@@ -15,7 +16,7 @@ public interface ModelMngRepositoryCustom {
* @param searchReq * @param searchReq
* @return * @return
*/ */
Page<ModelMasterEntity> findByModels(ModelTrainMngDto.SearchReq searchReq); Page<ListDto> findByModels(ModelTrainMngDto.SearchReq searchReq);
Optional<ModelMasterEntity> findByUuid(UUID uuid); Optional<ModelMasterEntity> findByUuid(UUID uuid);

View File

@@ -1,14 +1,18 @@
package com.kamco.cd.training.postgres.repository.model; package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QMemberEntity.memberEntity;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity; import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity; import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity; import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder; import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.Expressions; import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
@@ -34,12 +38,23 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
* @return * @return
*/ */
@Override @Override
public Page<ModelMasterEntity> findByModels(ModelTrainMngDto.SearchReq req) { public Page<ListDto> findByModels(ModelTrainMngDto.SearchReq req) {
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
Expression<UUID> beforeModelUuid =
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
.from(beforeModel)
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
Pageable pageable = req.toPageable(); Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
if (req.getStatus() != null && !req.getStatus().isEmpty()) { if (req.getStatus() != null && !req.getStatus().isEmpty()) {
builder.and(modelMasterEntity.statusCd.eq(req.getStatus())); builder.and(
modelMasterEntity
.step1State
.eq(req.getStatus())
.or(modelMasterEntity.step2State.eq(req.getStatus())));
} }
if (req.getModelNo() != null && !req.getModelNo().isEmpty()) { if (req.getModelNo() != null && !req.getModelNo().isEmpty()) {
@@ -48,9 +63,37 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
builder.and(modelMasterEntity.delYn.isFalse()); builder.and(modelMasterEntity.delYn.isFalse());
List<ModelMasterEntity> content = List<ListDto> content =
queryFactory queryFactory
.selectFrom(modelMasterEntity) .select(
Projections.constructor(
ListDto.class,
modelMasterEntity.id,
modelMasterEntity.uuid,
modelMasterEntity.modelVer,
modelMasterEntity.strtDttm,
modelMasterEntity.step1StrtDttm,
modelMasterEntity.step1EndDttm,
modelMasterEntity.step2StrtDttm,
modelMasterEntity.step2EndDttm,
modelMasterEntity.step1State,
modelMasterEntity.step2State,
modelMasterEntity.statusCd,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelMasterEntity.currentAttemptId,
modelMasterEntity.requestPath,
modelMasterEntity.packingState,
modelMasterEntity.packingStrtDttm,
modelMasterEntity.packingEndDttm,
modelConfigEntity.memo,
memberEntity.name,
beforeModelUuid))
.from(modelMasterEntity)
.innerJoin(modelConfigEntity)
.on(modelMasterEntity.id.eq(modelConfigEntity.model.id))
.leftJoin(memberEntity)
.on(modelMasterEntity.createdUid.eq(memberEntity.id))
.where(builder) .where(builder)
.offset(pageable.getOffset()) .offset(pageable.getOffset())
.limit(pageable.getPageSize()) .limit(pageable.getPageSize())
@@ -62,6 +105,10 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
queryFactory queryFactory
.select(modelMasterEntity.count()) .select(modelMasterEntity.count())
.from(modelMasterEntity) .from(modelMasterEntity)
.innerJoin(modelConfigEntity)
.on(modelMasterEntity.id.eq(modelConfigEntity.model.id))
.leftJoin(memberEntity)
.on(modelMasterEntity.createdUid.eq(memberEntity.id))
.where(builder) .where(builder)
.fetchOne(); .fetchOne();
@@ -154,7 +201,11 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
return queryFactory return queryFactory
.select(modelMasterEntity.id.count()) .select(modelMasterEntity.id.count())
.from(modelMasterEntity) .from(modelMasterEntity)
.where(modelMasterEntity.step1State.eq(TrainStatusType.IN_PROGRESS.getId())) .where(
modelMasterEntity
.step1State
.eq(TrainStatusType.IN_PROGRESS.getId())
.or(modelMasterEntity.step2State.eq(TrainStatusType.IN_PROGRESS.getId())))
.fetchOne(); .fetchOne();
} }
} }

View File

@@ -1,6 +1,9 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.time.ZonedDateTime;
import java.util.List; import java.util.List;
public interface ModelTestMetricsJobRepositoryCustom { public interface ModelTestMetricsJobRepositoryCustom {
@@ -10,4 +13,12 @@ public interface ModelTestMetricsJobRepositoryCustom {
List<ResponsePathDto> getTestMetricSaveNotYetModelIds(); List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
void insertModelMetricsTest(List<Object[]> batchArgs); void insertModelMetricsTest(List<Object[]> batchArgs);
ModelMetricJsonDto getTestMetricPackingInfo(Long modelId);
ModelTestFileName findModelTestFileNames(Long modelId);
void updatePackingStart(Long modelId, ZonedDateTime now);
void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState);
} }

View File

@@ -1,12 +1,18 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity; import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.Properties;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.ZonedDateTime;
import java.util.List; import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport; import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.JdbcTemplate;
@@ -59,15 +65,107 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
@Override @Override
public void insertModelMetricsTest(List<Object[]> batchArgs) { public void insertModelMetricsTest(List<Object[]> batchArgs) {
String sql = // AS-IS
""" // String sql =
insert into tb_model_metrics_test // """
(model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou, // insert into tb_model_metrics_test
detection_count, gt_count // (model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
) // detection_count, gt_count
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) // )
"""; // values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
// """;
//
// jdbcTemplate.batchUpdate(sql, batchArgs);
jdbcTemplate.batchUpdate(sql, batchArgs); // TO-BE: modelId, model(best_fscore_10) 같은 데이터가 있으면 update, 없으면 insert
String updateSql =
"""
UPDATE tb_model_metrics_test
SET tp=?, fp=?, fn=?, precisions=?, recall=?, f1_score=?, accuracy=?, iou=?,
detection_count=?, gt_count=?
WHERE model_id=? AND model=?
""";
String insertSql =
"""
INSERT INTO tb_model_metrics_test
(model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
detection_count, gt_count)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""";
// row 단위 처리 (batch 안에서 upsert)
for (Object[] row : batchArgs) {
// row 순서: (model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
// detection_count, gt_count)
int updated =
jdbcTemplate.update(
updateSql, row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9], row[10],
row[11], row[0], row[1]);
if (updated == 0) {
jdbcTemplate.update(insertSql, row);
}
}
}
@Override
public ModelMetricJsonDto getTestMetricPackingInfo(Long modelId) {
return queryFactory
.select(
Projections.constructor(
ModelMetricJsonDto.class,
modelMasterEntity.modelNo,
modelMasterEntity.modelVer,
Projections.constructor(
Properties.class,
modelMetricsTestEntity.f1Score,
modelMetricsTestEntity.precisions,
modelMetricsTestEntity.recall,
modelMetricsTestEntity.iou,
modelMetricsTrainEntity.loss)))
.from(modelMetricsTestEntity)
.innerJoin(modelMasterEntity)
.on(modelMetricsTestEntity.model.id.eq(modelMasterEntity.id))
.innerJoin(modelMetricsTrainEntity)
.on(
modelMetricsTestEntity.model.eq(modelMetricsTrainEntity.model),
modelMasterEntity.bestEpoch.eq(modelMetricsTrainEntity.epoch))
.where(modelMetricsTestEntity.model.id.eq(modelId))
.orderBy(modelMetricsTestEntity.createdDttm.desc())
.fetchFirst();
}
@Override
public ModelTestFileName findModelTestFileNames(Long modelId) {
return queryFactory
.select(
Projections.constructor(
ModelTestFileName.class, modelMetricsTestEntity.model1, modelMasterEntity.modelVer))
.from(modelMetricsTestEntity)
.innerJoin(modelMasterEntity)
.on(modelMetricsTestEntity.model.id.eq(modelMasterEntity.id))
.where(modelMetricsTestEntity.model.id.eq(modelId))
.fetchOne();
}
@Override
public void updatePackingStart(Long modelId, ZonedDateTime now) {
queryFactory
.update(modelMasterEntity)
.set(modelMasterEntity.packingStrtDttm, ZonedDateTime.now())
.set(modelMasterEntity.packingState, TrainStatusType.READY.getId())
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
@Override
public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) {
queryFactory
.update(modelMasterEntity)
.set(modelMasterEntity.packingEndDttm, ZonedDateTime.now())
.set(modelMasterEntity.packingState, failSuccState)
.where(modelMasterEntity.id.eq(modelId))
.execute();
} }
} }

View File

@@ -1,10 +1,17 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import java.util.List;
import java.util.Optional; import java.util.Optional;
public interface ModelTrainJobRepositoryCustom { public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId); int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId); Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch);
List<ModelTrainJobEntity> findRunningJobs();
} }

View File

@@ -1,9 +1,15 @@
package com.kamco.cd.training.postgres.repository.train; package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity;
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
import com.kamco.cd.training.common.enums.JobStatusType;
import com.kamco.cd.training.common.enums.JobType;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity; import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import jakarta.persistence.EntityManager; import jakarta.persistence.EntityManager;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
@@ -19,7 +25,7 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
/** modelId의 attempt_no 최대값. (없으면 0) */ /** modelId의 attempt_no 최대값. (없으면 0) */
@Override @Override
public int findMaxAttemptNo(Long modelId) { public int findMaxAttemptNo(Long modelId) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity; QModelTrainJobEntity j = modelTrainJobEntity;
Integer max = Integer max =
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne(); queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
@@ -33,11 +39,61 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
*/ */
@Override @Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) { public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity; QModelTrainJobEntity j = modelTrainJobEntity;
ModelTrainJobEntity job = ModelTrainJobEntity job =
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst(); queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
return Optional.ofNullable(job); return Optional.ofNullable(job);
} }
@Override
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
QModelTrainJobEntity j = modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory
.selectFrom(j)
.where(j.containerName.eq(containerName))
.orderBy(j.id.desc())
.fetchFirst();
return Optional.ofNullable(job);
}
@Override
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
Integer maxAttemptNo =
queryFactory
.select(modelTestTrainingRunEntity.attemptNo.max().coalesce(0))
.from(modelTestTrainingRunEntity)
.where(modelTestTrainingRunEntity.modelId.eq(modelId))
.fetchOne();
int nextAttemptNo = (maxAttemptNo == null ? 1 : maxAttemptNo + 1);
queryFactory
.insert(modelTestTrainingRunEntity)
.columns(
modelTestTrainingRunEntity.modelId,
modelTestTrainingRunEntity.attemptNo,
modelTestTrainingRunEntity.jobId,
modelTestTrainingRunEntity.epoch)
.values(modelId, nextAttemptNo, jobId, epoch)
.execute();
}
@Override
public List<ModelTrainJobEntity> findRunningJobs() {
return queryFactory
.select(modelTrainJobEntity)
.from(modelTrainJobEntity)
.where(
modelTrainJobEntity
.statusCd
.eq(JobStatusType.RUNNING.getId())
.and(modelTrainJobEntity.jobType.eq(JobType.TRAIN.getId())))
.orderBy(modelTrainJobEntity.id.desc())
.fetch();
}
} }

View File

@@ -12,4 +12,6 @@ public interface ModelTrainMetricsJobRepositoryCustom {
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo); void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
void insertModelMetricsValidation(List<Object[]> batchArgs); void insertModelMetricsValidation(List<Object[]> batchArgs);
void updateModelSelectedBestEpoch(Long modelId, Integer epoch);
} }

View File

@@ -82,4 +82,13 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
jdbcTemplate.batchUpdate(sql, batchArgs); jdbcTemplate.batchUpdate(sql, batchArgs);
} }
@Override
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
queryFactory
.update(modelMasterEntity)
.set(modelMasterEntity.bestEpoch, epoch)
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
} }

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.train; package com.kamco.cd.training.train;
import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.train.service.DataSetCountersService;
import com.kamco.cd.training.train.service.TestJobService; import com.kamco.cd.training.train.service.TestJobService;
import com.kamco.cd.training.train.service.TrainJobService; import com.kamco.cd.training.train.service.TrainJobService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
@@ -12,6 +13,8 @@ import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
@@ -25,6 +28,7 @@ public class TrainApiController {
private final TrainJobService trainJobService; private final TrainJobService trainJobService;
private final TestJobService testJobService; private final TestJobService testJobService;
private final DataSetCountersService dataSetCountersService;
@Operation(summary = "학습 실행", description = "학습 실행 API") @Operation(summary = "학습 실행", description = "학습 실행 API")
@ApiResponses( @ApiResponses(
@@ -46,6 +50,7 @@ public class TrainApiController {
UUID uuid) { UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid); Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.enqueue(modelId); trainJobService.enqueue(modelId);
return ApiResponseDto.ok("ok"); return ApiResponseDto.ok("ok");
} }
@@ -180,10 +185,32 @@ public class TrainApiController {
}) })
@PostMapping("/create-tmp/{uuid}") @PostMapping("/create-tmp/{uuid}")
public ApiResponseDto<UUID> createTmpFile( public ApiResponseDto<UUID> createTmpFile(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d") @Parameter(description = "model uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable @PathVariable
UUID uuid) { UUID uuid) {
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid)); return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
} }
@Operation(summary = "getCount", description = "getCount 서버 로그확인")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "데이터셋 tmp 파일생성 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping(path = "/counts/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE)
public ApiResponseDto<String> getCount(
@Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
return ApiResponseDto.ok(dataSetCountersService.getCount(modelId));
}
} }

View File

@@ -13,4 +13,10 @@ public class EvalRunRequest {
private String uuid; private String uuid;
private int epoch; // best_changed_fscore_epoch_1.pth private int epoch; // best_changed_fscore_epoch_1.pth
private Integer timeoutSeconds; private Integer timeoutSeconds;
private String datasetFolder;
private String outputFolder;
public String getOutputFolder() {
return this.outputFolder.toString();
}
} }

View File

@@ -0,0 +1,24 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class ModelTrainLinkDto {
private Long modelId;
private String trainType;
private String modelNo;
private Long datasetId;
private String targetClassCd;
private String comparePath;
private String targetPath;
private String labelPath;
private String geoJsonPath;
private String datasetUid;
}

View File

@@ -1,6 +1,8 @@
package com.kamco.cd.training.train.dto; package com.kamco.cd.training.train.dto;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.util.Properties;
import java.util.UUID; import java.util.UUID;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
@@ -20,4 +22,38 @@ public class ModelTrainMetricsDto {
private String responsePath; private String responsePath;
private UUID uuid; private UUID uuid;
} }
@Getter
@AllArgsConstructor
public static class ModelMetricJsonDto {
@JsonProperty("cd_model_type")
private String cdModelType;
@JsonProperty("model_version")
private String modelVersion;
private Properties properties;
}
@Getter
@AllArgsConstructor
public static class Properties {
@JsonProperty("f1_score")
private Float f1Score;
private Float precision;
private Float recall;
private Float loss;
private Double iou;
}
@Getter
@AllArgsConstructor
public static class ModelTestFileName {
private String bestEpochFileName;
private String modelVersion;
}
} }

View File

@@ -0,0 +1,228 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
/** 학습실행 파일 하드링크 */
@Service
@Log4j2
@RequiredArgsConstructor
public class DataSetCountersService {
private final ModelTrainMngCoreService modelTrainMngCoreService;
@Value("${train.docker.requestDir}")
private String requestDir;
@Value("${train.docker.basePath}")
private String trainBaseDir;
public String getCount(Long modelId) {
ModelTrainMngDto.Basic basic = modelTrainMngCoreService.findModelById(modelId);
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
StringBuilder allLogs = new StringBuilder();
try {
// request 폴더
for (String uid : uids) {
Path path = Path.of(requestDir, uid);
DatasetCounters counters = countTmpAfterBuild(path);
allLogs.append(counters.prints(uid, "REQUEST")).append(System.lineSeparator());
}
// tmp
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
// 차이나는거
diffMergedRequestsVsTmp(uids, tmpPath);
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
allLogs
.append(counters2.prints(basic.getRequestPath(), "TMP"))
.append(System.lineSeparator());
} catch (IOException e) {
log.error(e.getMessage());
}
return allLogs.toString();
}
private int countTif(Path dir) throws IOException {
if (!Files.isDirectory(dir)) return 0;
try (var stream = Files.walk(dir)) {
return (int)
stream.filter(Files::isRegularFile).filter(p -> p.toString().endsWith(".tif")).count();
}
/*
대소문자 및 geojson 필요시
* try (var stream = Files.walk(dir)) {
return (int)
stream
.filter(Files::isRegularFile)
.filter(p -> {
String name = p.getFileName().toString().toLowerCase();
return name.endsWith(".tif") || name.endsWith(".geojson");
})
.count();
}
* */
}
public DatasetCounters countTmpAfterBuild(Path path) throws IOException {
// input1
int in1Train = countTif(path.resolve("train/input1"));
int in1Val = countTif(path.resolve("val/input1"));
int in1Test = countTif(path.resolve("test/input1"));
// input2
int in2Train = countTif(path.resolve("train/input2"));
int in2Val = countTif(path.resolve("val/input2"));
int in2Test = countTif(path.resolve("test/input2"));
List<DatasetCounter> input1List = new ArrayList<>();
List<DatasetCounter> input2List = new ArrayList<>();
input1List.add(new DatasetCounter(in1Train, in1Test, in1Val));
input2List.add(new DatasetCounter(in2Train, in2Test, in2Val));
return new DatasetCounters(input1List, input2List);
}
@Getter
public static class DatasetCounter {
private int inputTrain = 0;
private int inputTest = 0;
private int inputVal = 0;
public DatasetCounter(int inputTrain, int inputTest, int inputVal) {
this.inputTrain = inputTrain;
this.inputTest = inputTest;
this.inputVal = inputVal;
}
}
@Getter
public static class DatasetCounters {
private List<DatasetCounter> input1 = new ArrayList<>();
private List<DatasetCounter> input2 = new ArrayList<>();
public DatasetCounters(List<DatasetCounter> input1, List<DatasetCounter> input2) {
this.input1 = input1;
this.input2 = input2;
}
public String prints(String uuid, String type) {
int train = 0, test = 0, val = 0;
int train2 = 0, test2 = 0, val2 = 0;
for (DatasetCounter datasetCounter : input1) {
train += datasetCounter.inputTrain;
test += datasetCounter.inputTest;
val += datasetCounter.inputVal;
}
for (DatasetCounter datasetCounter : input2) {
train2 += datasetCounter.inputTrain;
test2 += datasetCounter.inputTest;
val2 += datasetCounter.inputVal;
}
log.info("======== UUID FOLDER COUNT {} : {}", type, uuid);
log.info("input 1 = train : {} | val : {} | test : {} ", train, val, test);
log.info("input 2 = train : {} | val : {} | test : {} ", train2, val2, test2);
log.info(
"*total* = train : {} | val : {} | test : {} ",
train + train2,
val + val2,
test + test2);
return String.format(
"======== UUID FOLDER COUNT %s : %s%n"
+ "input 1 = train : %s | val : %s | test : %s%n"
+ "input 2 = train : %s | val : %s | test : %s%n"
+ "*total* = train : %s | val : %s | test : %s",
type,
uuid,
train,
val,
test,
train2,
val2,
test2,
train + train2,
val + val2,
test + test2);
}
}
private Set<String> listTifRelative(Path root) throws IOException {
if (!Files.isDirectory(root)) return Set.of();
try (var stream = Files.walk(root)) {
return stream
.filter(Files::isRegularFile)
.filter(p -> p.getFileName().toString().toLowerCase().endsWith(".tif"))
.map(p -> root.relativize(p).toString().replace("\\", "/"))
.collect(Collectors.toSet());
}
}
private Set<String> listTifFileNameOnly(Path root) throws IOException {
if (!Files.isDirectory(root)) return Set.of();
try (var stream = Files.walk(root)) {
return stream
.filter(Files::isRegularFile)
.filter(p -> p.getFileName().toString().toLowerCase().endsWith(".tif"))
.map(p -> p.getFileName().toString()) // 파일명만
.collect(Collectors.toSet());
}
}
public void diffMergedRequestsVsTmp(List<String> uids, Path tmpRoot) throws IOException {
// 1) 요청 uids 전체를 합친 tif "파일명" 집합
Set<String> reqAll = new HashSet<>();
for (String uid : uids) {
Path reqRoot = Path.of(requestDir, uid);
// ★합본 tmp는 보통 폴더 구조가 바뀌므로 "상대경로" 비교보다 파일명 비교가 먼저 유용합니다.
reqAll.addAll(listTifFileNameOnly(reqRoot));
}
// 2) tmp tif 파일명 집합
Set<String> tmpAll = listTifFileNameOnly(tmpRoot);
Set<String> missing = new HashSet<>(reqAll);
missing.removeAll(tmpAll);
Set<String> extra = new HashSet<>(tmpAll);
extra.removeAll(reqAll);
log.info("==== MERGED DIFF (filename-based) ====");
log.info("request(all uids) tif = {}", reqAll.size());
log.info("tmp tif = {}", tmpAll.size());
log.info("missing = {}", missing.size());
log.info("extra = {}", extra.size());
missing.stream().sorted().limit(50).forEach(f -> log.warn("[MISSING] {}", f));
extra.stream().sorted().limit(50).forEach(f -> log.warn("[EXTRA] {}", f));
}
}

View File

@@ -1,22 +1,29 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.train.dto.EvalRunRequest; import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult; import com.kamco.cd.training.train.dto.TrainRunResult;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2; import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@Log4j2 @Log4j2
@Service @Service
@RequiredArgsConstructor
public class DockerTrainService { public class DockerTrainService {
// 실행할 Docker 이미지명 // 실행할 Docker 이미지명
@@ -43,7 +50,16 @@ public class DockerTrainService {
@Value("${train.docker.ipcHost:true}") @Value("${train.docker.ipcHost:true}")
private boolean ipcHost; private boolean ipcHost;
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */ private final ModelTrainJobCoreService modelTrainJobCoreService;
/**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param req
* @param containerName
* @return
* @throws Exception
*/
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception { public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerRunCommand(containerName, req); List<String> cmd = buildDockerRunCommand(containerName, req);
@@ -56,12 +72,42 @@ public class DockerTrainService {
ProcessBuilder pb = new ProcessBuilder(cmd); ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectErrorStream(true); pb.redirectErrorStream(true);
Process p = pb.start();
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게) // 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
StringBuilder logBuilder = new StringBuilder(); StringBuilder logBuilder = new StringBuilder();
Process p = pb.start();
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b"); log.info("[TRAIN-BOOT] docker run started. container={}", containerName);
try {
log.info("[TRAIN-BOOT] pid={}", p.pid()); // Java 9+
} catch (Throwable ignore) {
}
try {
// 바로 죽었는지 100ms만 체크
if (p.waitFor(100, TimeUnit.MILLISECONDS)) {
int exit = p.exitValue();
String earlyLogs;
synchronized (logBuilder) {
earlyLogs = logBuilder.toString();
}
log.error(
"[TRAIN-BOOT] docker run exited immediately. container={} exit={}",
containerName,
exit);
log.error("[TRAIN-BOOT] early logs:\n{}", earlyLogs);
} else {
log.info("[TRAIN-BOOT] docker run is still running. container={}", containerName);
}
} catch (Exception e) {
log.warn("[TRAIN-BOOT] early-exit check failed: {}", e.toString(), e);
}
Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]");
// 너무 잦은 업데이트 방지용
AtomicInteger lastEpoch = new AtomicInteger(0);
AtomicInteger lastIter = new AtomicInteger(0);
Thread logThread = Thread logThread =
new Thread( new Thread(
@@ -73,23 +119,40 @@ public class DockerTrainService {
String line; String line;
while ((line = br.readLine()) != null) { while ((line = br.readLine()) != null) {
// 1) 로그 누적
synchronized (logBuilder) { synchronized (logBuilder) {
logBuilder.append(line).append('\n'); logBuilder.append(line).append('\n');
} }
// 2) epoch 감지 + DB 업데이트
Matcher m = epochPattern.matcher(line); Matcher m = epochPattern.matcher(line);
if (m.find()) { if (m.find()) {
int currentEpoch = Integer.parseInt(m.group(1)); int epoch = Integer.parseInt(m.group(1));
int totalEpoch = Integer.parseInt(m.group(2)); int iter = Integer.parseInt(m.group(2));
int totalIter = Integer.parseInt(m.group(3));
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch); // (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용
int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0;
// TODO 실행중인 에폭 저장 필요하면 만들어야함 // 쓰로틀링: 에폭 끝 or 10 iter마다
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..? boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0);
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
// currentEpoch, totalEpoch); // 중복 방지
if (shouldUpdate) {
int prevEpoch = lastEpoch.get();
int prevIter = lastIter.get();
if (epoch != prevEpoch || iter != prevIter) {
lastEpoch.set(epoch);
lastIter.set(iter);
log.info(
"[TRAIN] container={} epoch={} iter={}/{}",
containerName,
epoch,
iter,
totalIter);
modelTrainJobCoreService.updateEpoch(containerName, epoch);
}
}
} }
} }
} catch (Exception e) { } catch (Exception e) {
@@ -97,21 +160,6 @@ public class DockerTrainService {
} }
}, },
"train-log-" + containerName); "train-log-" + containerName);
// new Thread(
// () -> {
// try (BufferedReader br =
// new BufferedReader(
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
// String line;
// while ((line = br.readLine()) != null) {
// synchronized (log) {
// log.append(line).append('\n');
// }
// }
// } catch (Exception ignored) {
// }
// },
// "train-log-" + containerName);
logThread.setDaemon(true); logThread.setDaemon(true);
logThread.start(); logThread.start();
@@ -206,7 +254,7 @@ public class DockerTrainService {
// 요청/결과 디렉토리 볼륨 마운트 // 요청/결과 디렉토리 볼륨 마운트
c.add("-v"); c.add("-v");
c.add(requestDir + "/tmp:/data"); c.add("/home/kcomu/data" + "/tmp:/data");
c.add("-v"); c.add("-v");
c.add(responseDir + ":/checkpoints"); c.add(responseDir + ":/checkpoints");
@@ -226,8 +274,7 @@ public class DockerTrainService {
addArg(c, "--input-size", req.getInputSize()); addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize()); addArg(c, "--crop-size", req.getCropSize());
addArg(c, "--batch-size", req.getBatchSize()); addArg(c, "--batch-size", req.getBatchSize());
addArg(c, "--gpu-ids", req.getGpuIds()); addArg(c, "--gpu-ids", req.getGpuIds()); // null
// addArg(c, "--gpus", req.getGpus());
addArg(c, "--lr", req.getLearningRate()); addArg(c, "--lr", req.getLearningRate());
addArg(c, "--backbone", req.getBackbone()); addArg(c, "--backbone", req.getBackbone());
addArg(c, "--epochs", req.getEpochs()); addArg(c, "--epochs", req.getEpochs());
@@ -264,15 +311,20 @@ public class DockerTrainService {
// ===== Augmentation ===== // ===== Augmentation =====
addArg(c, "--rot-prob", req.getRotProb()); addArg(c, "--rot-prob", req.getRotProb());
// addArg(c, "--rot-degree", req.getRotDegree()); // TODO AI 수정되면 주석 해제 addArg(c, "--rot-degree", req.getRotDegree());
addArg(c, "--flip-prob", req.getFlipProb()); addArg(c, "--flip-prob", req.getFlipProb());
addArg(c, "--exchange-prob", req.getExchangeProb()); addArg(c, "--exchange-prob", req.getExchangeProb());
addArg(c, "--brightness-delta", req.getBrightnessDelta()); addArg(c, "--brightness-delta", req.getBrightnessDelta());
// addArg(c, "--contrast-range", req.getContrastRange()); // TODO AI 수정되면 주석 해제 addArg(c, "--contrast-range", req.getContrastRange());
// addArg(c, "--saturation-range", req.getSaturationRange()); // TODO AI 수정되면 주석 해제 addArg(c, "--saturation-range", req.getSaturationRange());
addArg(c, "--hue-delta", req.getHueDelta()); addArg(c, "--hue-delta", req.getHueDelta());
addArg(c, "--resume-from", req.getResumeFrom()); if (req.getResumeFrom() != null && !req.getResumeFrom().isBlank()) {
c.add("--resume");
addArg(c, "--load-from", req.getResumeFrom());
}
addArg(c, "--save-interval", 1);
return c; return c;
} }
@@ -296,7 +348,15 @@ public class DockerTrainService {
} }
} }
public TrainRunResult runEvalSync(EvalRunRequest req, String containerName) throws Exception { /**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param containerName
* @param req
* @return
* @throws Exception
*/
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
List<String> cmd = buildDockerEvalCommand(containerName, req); List<String> cmd = buildDockerEvalCommand(containerName, req);
@@ -341,7 +401,6 @@ public class DockerTrainService {
synchronized (log) { synchronized (log) {
logs = log.toString(); logs = log.toString();
} }
return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs); return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs);
} }
@@ -370,37 +429,58 @@ public class DockerTrainService {
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required"); if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0"); if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
String modelFile = "best_changed_fscore_epoch_" + epoch + ".pth"; Path epochPath = Paths.get(responseDir, req.getOutputFolder());
// 결과 폴더에 파라미터로 받은 베스트 epoch이 best_changed_fscore_epoch_ 로 시작하는 파일이 있는지 확인 후 pth 파일명 반환
String modelFile = findCheckpoint(epochPath, epoch);
List<String> c = new ArrayList<>(); List<String> c = new ArrayList<>();
c.add("docker"); c.add("docker");
c.add("run"); c.add("run");
c.add("--name");
c.add(containerName);
c.add("--rm"); c.add("--rm");
c.add("--gpus"); c.add("--gpus");
c.add("all"); c.add("all");
if (ipcHost) c.add("--ipc=host"); c.add("--ipc=host");
c.add("--shm-size=" + shmSize); c.add("--shm-size=" + shmSize);
c.add("-v"); c.add("-v");
c.add(requestDir + ":/data"); c.add("/home/kcomu/data" + "/tmp:/data");
c.add("-v"); c.add("-v");
c.add(responseDir + ":/checkpoints"); c.add(responseDir + ":/checkpoints");
c.add(image); c.add("kamco-cd-train:latest");
c.add("python"); c.add("python");
c.add("/workspace/change-detection-code/run_evaluation_pipeline.py"); c.add("/workspace/change-detection-code/run_evaluation_pipeline.py");
c.add("--dataset_dir"); addArg(c, "--dataset-folder", req.getDatasetFolder());
c.add("/data/" + uuid); addArg(c, "--output-folder", req.getOutputFolder());
c.add("--model"); c.add("--epoch");
c.add("/checkpoints/" + uuid + "/" + modelFile); c.add(modelFile);
return c; return c;
} }
public String findCheckpoint(Path dir, int epoch) {
String bestFileName = String.format("best_changed_fscore_epoch_%d.pth", epoch);
String normalFileName = String.format("epoch_%d.pth", epoch);
Path bestPath = dir.resolve(bestFileName);
Path normalPath = dir.resolve(normalFileName);
// 1. best 파일이 존재하면 그거 사용
if (Files.isRegularFile(bestPath)) {
return bestFileName;
}
// 2. 없으면 일반 epoch 파일 사용
if (Files.isRegularFile(normalPath)) {
return normalFileName;
}
throw new IllegalStateException("Checkpoint 파일이 없습니다. epoch=" + epoch);
}
} }

View File

@@ -0,0 +1,449 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
/**
* 서버 재기동 시 "RUNNING 상태로 남아있는 학습 Job"을 복구(정리)하기 위한 서비스.
*
* <p>상황 예시: - 서버가 강제 재기동/장애로 내려감 - DB 상에서는 job_state가 RUNNING(진행중)으로 남아있음 - 실제 docker 컨테이너는: 1) 아직
* 살아있거나(running=true) 2) 종료되었거나(exited) 3) --rm 옵션으로 인해 컨테이너가 이미 삭제되어 존재하지 않을 수 있음
*
* <p>이 클래스는 ApplicationReadyEvent(스프링 부팅 완료) 시점에 실행되어, DB의 RUNNING 잡들을 조회한 뒤 컨테이너 상태를 점검하고,
* SUCCESS/FAILED 처리를 수행합니다.
*/
@Profile("!local")
@Component
@RequiredArgsConstructor
@Log4j2
public class JobRecoveryOnStartupService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
/**
* Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response
*
* <p>컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다.
*/
@Value("${train.docker.responseDir}")
private String responseDir;
/**
* 스프링 부팅 완료 시점(빈 생성/초기화 모두 끝난 뒤)에 복구 로직 실행.
*
* <p>@Transactional: - recover() 메서드 전체가 하나의 트랜잭션으로 감싸집니다. - Job 하나씩 처리하다가 예외가 발생하면 전체 롤백이 될 수
* 있으므로 "잡 단위로 확실히 커밋"이 필요하면 (권장) 잡 단위로 분리 트랜잭션(REQUIRES_NEW) 고려하세요.
*/
// @EventListener(ApplicationReadyEvent.class)
@Transactional
public void recover() {
// 1) DB에서 "RUNNING(진행중) 상태"로 남아있는 job 목록을 조회
List<ModelTrainJobDto> runningJobs = modelTrainJobCoreService.findRunningJobs();
// 실행중 job이 없으면 할 일 없음
if (runningJobs == null || runningJobs.isEmpty()) {
return;
}
// 2) 각 job에 대해 docker 컨테이너 상태를 확인하고, 상태에 따라 조치
for (ModelTrainJobDto job : runningJobs) {
String containerName = job.getContainerName();
try {
// 2-1) docker inspect로 컨테이너 상태 조회
DockerInspectState state = inspectContainer(containerName);
// 3) 컨테이너가 "없음"
// - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음
// - 즉 "컨테이너 없음"이 무조건 실패는 아님
if (!state.exists()) {
log.warn(
"[RECOVERY] container missing. try file-based reconcile. container={}",
containerName);
// 3-1) 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정"
OutputResult out = probeOutputs(job);
// 3-2) 산출물이 충분하면 성공 처리
if (out.completed()) {
log.info("[RECOVERY] outputs look completed. mark SUCCESS. jobId={}", job.getId());
modelTrainJobCoreService.markSuccess(job.getId(), 0);
markStepSuccessByJobType(job);
} else {
// 3-3) 산출물이 부족하면 실패 처리(운영 정책에 따라 "유예"도 가능)
log.warn(
"[RECOVERY] outputs incomplete. mark FAILED. jobId={} reason={}",
job.getId(),
out.reason());
modelTrainJobCoreService.markFailed(
job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
markStepErrorByJobType(job, out.reason());
}
continue;
}
// 4) 컨테이너는 존재하고, 아직 running=true
// - 서버만 재기동됐고 컨테이너는 그대로 살아있는 케이스
// - 이 경우 DB를 건드리면 오히려 꼬일 수 있으니 RUNNING 유지
if (state.running()) {
log.info("[RECOVERY] container still running. container={}", containerName);
try {
ProcessBuilder pb = new ProcessBuilder("docker", "stop", "-t", "20", containerName);
pb.redirectErrorStream(true);
Process p = pb.start();
boolean finished = p.waitFor(30, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
throw new IOException("docker stop timeout");
}
int code = p.exitValue();
if (code != 0) {
throw new IOException("docker stop failed. exit=" + code);
}
log.info(
"[RECOVERY] container stopped (will be auto removed by --rm). container={}",
containerName);
// 여기서 상태를 PAUSED로 바꿔도 되고
modelTrainJobCoreService.markPaused(job.getId(), -1, "AUTO_STOP_FAILED_ON_RESTART");
} catch (Exception e) {
log.error("[RECOVERY] docker stop failed. container={}", containerName, e);
modelTrainJobCoreService.markFailed(job.getId(), -1, "AUTO_STOP_FAILED_ON_RESTART");
}
continue;
}
// 5) 컨테이너는 존재하지만 running=false
// - exited / dead 등의 상태
Integer exitCode = state.exitCode();
String status = state.status();
// 5-1) exitCode=0이면 정상 종료로 간주 → SUCCESS 처리
if (exitCode != null && exitCode == 0) {
log.info("[RECOVERY] container exited(0). mark SUCCESS. container={}", containerName);
modelTrainJobCoreService.markSuccess(job.getId(), 0);
markStepSuccessByJobType(job);
} else {
// 5-2) exitCode != 0 이거나 null이면 실패로 간주 → FAILED 처리
log.warn(
"[RECOVERY] container exited non-zero. mark FAILED. container={} status={} exitCode={}",
containerName,
status,
exitCode);
modelTrainJobCoreService.markFailed(
job.getId(), exitCode, "SERVER_RESTART_CONTAINER_EXIT_NONZERO");
markStepErrorByJobType(job, "exit=" + exitCode + " status=" + status);
}
} catch (Exception e) {
// 6) docker inspect 자체가 실패한 경우
// - docker 데몬 문제/권한 문제/일시적 오류 가능
// - 운영 정책에 따라 "바로 실패" 대신 "유예" 처리도 고려 가능
log.error("[RECOVERY] container inspect failed. container={}", containerName, e);
modelTrainJobCoreService.markFailed(
job.getId(), -1, "SERVER_RESTART_CONTAINER_INSPECT_ERROR");
markStepErrorByJobType(job, "inspect-error");
}
}
}
/**
* jobType에 따라 학습 관리 테이블의 "성공 단계"를 업데이트.
*
* <p>예: - jobType == "EVAL" → step2(평가 단계) 성공 - 그 외 → step1(학습 단계) 성공
*/
private void markStepSuccessByJobType(ModelTrainJobDto job) {
Map<String, Object> params = job.getParamsJson();
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
if (isEval) {
modelTrainMngCoreService.markStep2Success(job.getModelId());
} else {
modelTrainMngCoreService.markStep1Success(job.getModelId());
}
}
/**
* jobType에 따라 학습 관리 테이블의 "에러 단계"를 업데이트.
*
* <p>예: - jobType == "EVAL" → step2(평가 단계) 에러 - 그 외 → step1 혹은 전체 에러
*/
private void markStepErrorByJobType(ModelTrainJobDto job, String msg) {
Map<String, Object> params = job.getParamsJson();
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
if (isEval) {
modelTrainMngCoreService.markStep2Error(job.getModelId(), msg);
} else {
modelTrainMngCoreService.markError(job.getModelId(), msg);
}
}
/**
* docker inspect를 사용해서 컨테이너 상태를 조회합니다.
*
* <p>사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}}
*
* <p>예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137"
*
* <p>주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초
* 타임아웃을 둠
*/
private DockerInspectState inspectContainer(String containerName)
throws IOException, InterruptedException {
ProcessBuilder pb =
new ProcessBuilder(
"docker",
"inspect",
"-f",
"{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}",
containerName);
// stderr를 stdout으로 합쳐서 한 스트림으로 읽기(에러 메시지도 함께 받음)
pb.redirectErrorStream(true);
Process p = pb.start();
// inspect 출력은 1줄이면 충분하므로 readLine()만 수행
String output;
try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
output = br.readLine();
}
// 무한대기 방지: 5초 내에 종료되지 않으면 강제 종료
boolean finished = p.waitFor(5, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
throw new IOException("docker inspect timeout");
}
// docker inspect 자체의 프로세스 exit code
int code = p.exitValue();
// 실패(코드 !=0) 또는 출력이 없으면 "컨테이너 없음"으로 간주
if (code != 0 || output == null || output.isBlank()) {
return DockerInspectState.missing();
}
// "status running exitCode" 형태로 split
String[] parts = output.trim().split("\\s+");
// status: running/exited/dead 등
String status = parts.length > 0 ? parts[0] : "unknown";
// running: true/false
boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]);
// exitCode: 정수 파싱(파싱 실패하면 null)
Integer exitCode = null;
if (parts.length > 2) {
try {
exitCode = Integer.parseInt(parts[2]);
} catch (Exception ignore) {
// ignore
}
}
return new DockerInspectState(true, running, exitCode, status);
}
/**
* docker inspect 결과를 담는 레코드.
*
* <p>exists: - true : docker inspect 성공 (컨테이너 존재) - false : 컨테이너 없음(또는 inspect 실패를 missing으로 간주)
*/
private record DockerInspectState(
boolean exists, boolean running, Integer exitCode, String status) {
static DockerInspectState missing() {
return new DockerInspectState(false, false, null, "missing");
}
}
// ============================================================================================
// 컨테이너가 "없을 때" 파일 기반으로 완료/미완료를 판정하는 로직
// ============================================================================================
/**
* 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다.
*
* <p>판정 규칙(보수적으로 설계): 1) total_epoch가 paramsJson에 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수
* >= total_epoch 이어야 함 3) *.pth 파일이 total_epoch 이상 존재하거나, best*.pth(또는 *best*.pth)가 존재해야 함
*
* <p>왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를
* 실패로 오판할 수 있음.
*/
private OutputResult probeOutputs(ModelTrainJobDto job) {
try {
Path outDir = resolveOutputDir(job);
if (outDir == null || !Files.isDirectory(outDir)) {
return new OutputResult(false, "output-dir-missing");
}
Integer totalEpoch = extractTotalEpoch(job).orElse(null);
if (totalEpoch == null || totalEpoch <= 0) {
return new OutputResult(false, "total-epoch-missing");
}
Path valCsv = outDir.resolve("val.csv");
if (!Files.exists(valCsv)) {
return new OutputResult(false, "val.csv-missing");
}
long lines = countNonHeaderLines(valCsv);
// “같아야 완료” 정책
if (lines == totalEpoch) {
return new OutputResult(true, "ok");
}
return new OutputResult(
false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch);
} catch (Exception e) {
log.error("[RECOVERY] probeOutputs error. jobId={}", job.getId(), e);
return new OutputResult(false, "probe-error");
}
}
/**
* responseDir 아래에서 job 산출물 디렉토리를 찾습니다.
*
* <p>가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다.
*
* <p>현재 기본 탐색 순서: 1) {responseDir}/{jobId} 2) {responseDir}/{modelId} 3)
* {responseDir}/{containerName} 4) 마지막 fallback: responseDir 자체
*
* <p>추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId)
* 형태를 1순위로 두세요.
*/
private Path resolveOutputDir(ModelTrainJobDto job) {
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId());
Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics");
return Files.isDirectory(base) ? base : null;
}
/**
* paramsJson에서 total_epoch 값을 추출합니다.
*
* <p>키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase)
*
* <p>예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...}
*/
private Optional<Integer> extractTotalEpoch(ModelTrainJobDto job) {
Map<String, Object> params = job.getParamsJson();
if (params == null) return Optional.empty();
Object v = params.get("total_epoch");
if (v == null) v = params.get("totalEpoch");
if (v == null) return Optional.empty();
try {
return Optional.of(Integer.parseInt(String.valueOf(v)));
} catch (Exception ignore) {
return Optional.empty();
}
}
/**
* CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다.
*
* <p>가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록)
*
* <p>주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외
*/
private long countNonHeaderLines(Path csv) throws IOException {
try (Stream<String> lines = Files.lines(csv, StandardCharsets.UTF_8)) {
return lines.skip(1).filter(s -> s != null && !s.isBlank()).count();
}
}
/**
* 디렉토리에서 glob 패턴에 맞는 파일 수를 셉니다.
*
* <p>예: - "*.pth" - "best*.pth"
*/
private long countFilesByGlob(Path dir, String glob) throws IOException {
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
long cnt = 0;
for (Path p : ds) {
if (Files.isRegularFile(p)) cnt++;
}
return cnt;
}
}
/** 디렉토리에서 glob 패턴에 맞는 파일이 "하나라도" 존재하는지 체크합니다. */
private boolean existsByGlob(Path dir, String glob) throws IOException {
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
return ds.iterator().hasNext();
}
}
// ============================================================================================
// probeOutputs() 결과 객체
// ============================================================================================
/**
* 컨테이너가 없을 때(responseDir 기반) 완료 여부 판정 결과.
*
* <p>completed: - true : 산출물이 완료로 보임(성공 처리 가능) - false : 산출물이 부족/불명확(실패 또는 유예 판단)
*
* <p>reason: - 실패/미완료 사유(로그/DB 메시지로 남기기 용도)
*/
private static final class OutputResult {
private final boolean completed;
private final String reason;
private OutputResult(boolean completed, String reason) {
this.completed = completed;
this.reason = reason;
}
boolean completed() {
return completed;
}
String reason() {
return reason;
}
}
}

View File

@@ -1,14 +1,28 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto; import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.time.ZonedDateTime;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVFormat;
@@ -31,20 +45,11 @@ public class ModelTestMetricsJobService {
@Value("${train.docker.responseDir}") @Value("${train.docker.responseDir}")
private String responseDir; private String responseDir;
/** @Value("${file.pt-path}")
* 실행중인 profile private String ptPathDir;
*
* @return
*/
private boolean isLocalProfile() {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *") /** 결과 csv 파일 정보 등록 */
public void findTestValidMetricCsvFiles() { public void findTestValidMetricCsvFiles() {
// if (isLocalProfile()) {
// return;
// }
List<ResponsePathDto> modelIds = List<ResponsePathDto> modelIds =
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds(); modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
@@ -96,11 +101,116 @@ public class ModelTestMetricsJobService {
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs); modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
// test.csv 파일 읽어서 저장한 여부로만 사용하기
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step2");
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2"); // 패키징할 파일 만들기
modelTestMetricsJobCoreService.updatePackingStart(
modelInfo.getModelId(), ZonedDateTime.now());
ModelMetricJsonDto jsonDto =
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
try {
writeJsonFile(
jsonDto,
Paths.get(
responseDir
+ "/"
+ modelInfo.getUuid()
+ "/"
+ jsonDto.getModelVersion()
+ ".json"));
} catch (IOException e) {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
ModelTestFileName fileInfo =
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
Path zipPath =
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
Set<String> targetNames =
Set.of(
"model_config.py",
fileInfo.getBestEpochFileName() + ".pth",
fileInfo.getModelVersion() + ".json");
List<Path> files = new ArrayList<>();
try (Stream<Path> s = Files.list(responsePath)) {
files.addAll(
s.filter(Files::isRegularFile)
.filter(p -> targetNames.contains(p.getFileName().toString()))
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
files.addAll(
s.filter(Files::isRegularFile)
.limit(1) // yolov8_6th-6m.pt 파일 1개만
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try {
zipFiles(files, zipPath);
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
} catch (IOException e) {
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
throw new RuntimeException(e);
}
}
}
private void writeJsonFile(Object data, Path outputPath) throws IOException {
Path parent = outputPath.getParent();
if (parent != null) {
Files.createDirectories(parent);
}
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
try (OutputStream os =
Files.newOutputStream(
outputPath, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING)) {
objectMapper.writeValue(os, data);
}
}
private void zipFiles(List<Path> files, Path zipPath) throws IOException {
Path parent = zipPath.getParent();
if (parent != null) {
Files.createDirectories(parent);
}
try (ZipOutputStream zos = new ZipOutputStream(Files.newOutputStream(zipPath))) {
for (Path file : files) {
ZipEntry entry = new ZipEntry(file.getFileName().toString());
zos.putNextEntry(entry);
Files.copy(file, zos);
zos.closeEntry();
}
} }
} }
} }

View File

@@ -6,9 +6,14 @@ import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVFormat;
@@ -31,20 +36,8 @@ public class ModelTrainMetricsJobService {
@Value("${train.docker.responseDir}") @Value("${train.docker.responseDir}")
private String responseDir; private String responseDir;
/** /** 결과 csv 파일 정보 등록 */
* 실행중인 profile
*
* @return
*/
private boolean isLocalProfile() {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
public void findTrainValidMetricCsvFiles() { public void findTrainValidMetricCsvFiles() {
// if (isLocalProfile()) {
// return;
// }
List<ResponsePathDto> modelIds = List<ResponsePathDto> modelIds =
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds(); modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
@@ -65,7 +58,7 @@ public class ModelTrainMetricsJobService {
for (CSVRecord record : parser) { for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch")) + 1; // TODO : 나중에 AI 개발 완료되면 -1 하기 int epoch = Integer.parseInt(record.get("Epoch"));
long iteration = Long.parseLong(record.get("Iteration")); long iteration = Long.parseLong(record.get("Iteration"));
double Loss = Double.parseDouble(record.get("Loss")); double Loss = Double.parseDouble(record.get("Loss"));
double LR = Double.parseDouble(record.get("LR")); double LR = Double.parseDouble(record.get("LR"));
@@ -129,6 +122,33 @@ public class ModelTrainMetricsJobService {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
Integer epoch = null;
boolean exists;
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
try (Stream<Path> s = Files.list(responsePath)) {
epoch =
s.filter(Files::isRegularFile)
.map(
p -> {
Matcher matcher = pattern.matcher(p.getFileName().toString());
if (matcher.matches()) {
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
}
return null;
})
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
} catch (IOException e) {
throw new RuntimeException(e);
}
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch);
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn( modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step1"); modelInfo.getModelId(), "step1");
} }

View File

@@ -1,10 +1,10 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
@@ -20,22 +20,38 @@ public class TestJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService; private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService; private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher; private final ApplicationEventPublisher eventPublisher;
private final DataSetCountersService dataSetCounters;
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @param uuid
* @param epoch
* @return
*/
@Transactional @Transactional
public Long enqueue(Long modelId, UUID uuid, int epoch) { public Long enqueue(Long modelId, UUID uuid, int epoch) {
// 마스터 확인 // 마스터 확인
modelTrainMngCoreService.findModelById(modelId); modelTrainMngCoreService.findModelById(modelId);
// 폴더 카운트
dataSetCounters.getCount(modelId);
// best epoch 업데이트 // best epoch 업데이트
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch); modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
// 파라미터 조회
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
Map<String, Object> params = new java.util.LinkedHashMap<>(); Map<String, Object> params = new java.util.LinkedHashMap<>();
params.put("jobType", "EVAL"); params.put("jobType", "EVAL");
params.put("uuid", String.valueOf(uuid)); params.put("uuid", String.valueOf(uuid));
params.put("epoch", epoch); params.put("epoch", epoch);
params.put("datasetFolder", trainRunRequest.getDatasetFolder());
params.put("outputFolder", trainRunRequest.getOutputFolder());
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
@@ -43,13 +59,18 @@ public class TestJobService {
modelTrainJobCoreService.createQueuedJob( modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, params, ZonedDateTime.now()); modelId, nextAttemptNo, params, ZonedDateTime.now());
// step2 시작으로 마킹 // test training run 테이블에 적재하기
modelTrainMngCoreService.markStep2InProgress(modelId, jobId); modelTrainJobCoreService.insertModelTestTrainingRun(modelId, jobId, epoch);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId)); eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId; return jobId;
} }
/**
* 취소
*
* @param modelId
*/
@Transactional @Transactional
public void cancel(Long modelId) { public void cancel(Long modelId) {

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.model.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import java.io.IOException; import java.io.IOException;
import java.nio.file.*; import java.nio.file.*;
import java.util.List; import java.util.List;
@@ -16,6 +17,98 @@ public class TmpDatasetService {
@Value("${train.docker.requestDir}") @Value("${train.docker.requestDir}")
private String requestDir; private String requestDir;
@Value("${train.docker.basePath}")
private String trainBaseDir;
/**
* train, val, test 폴더별로 link
*
* @param uid 임시폴더 uuid
* @param type train, val, test
* @param links tif pull path
* @return
* @throws IOException
*/
public void buildTmpDatasetHardlink(String uid, String type, List<ModelTrainLinkDto> links)
throws IOException {
if (links == null || links.isEmpty()) {
throw new IOException("links is empty");
}
Path tmp = Path.of(trainBaseDir, "tmp", uid);
long hardlinksMade = 0;
for (ModelTrainLinkDto dto : links) {
if (type == null) {
log.warn("SKIP - trainType null: {}", dto);
continue;
}
// type별 디렉토리 생성
Files.createDirectories(tmp.resolve(type).resolve("input1"));
Files.createDirectories(tmp.resolve(type).resolve("input2"));
Files.createDirectories(tmp.resolve(type).resolve("label"));
Files.createDirectories(tmp.resolve(type).resolve("label-json"));
// comparePath input1
hardlinksMade += link(tmp, type, "input1", dto.getComparePath());
// targetPath input2
hardlinksMade += link(tmp, type, "input2", dto.getTargetPath());
// labelPath label
hardlinksMade += link(tmp, type, "label", dto.getLabelPath());
// geoJsonPath -> label-json
hardlinksMade += link(tmp, type, "label-json", dto.getGeoJsonPath());
}
if (hardlinksMade == 0) {
throw new IOException("No hardlinks created.");
}
log.info("tmp dataset created: {}, hardlinksMade={}", tmp, hardlinksMade);
}
private long link(Path tmp, String type, String part, String fullPath) throws IOException {
if (fullPath == null || fullPath.isBlank()) return 0;
Path src = Path.of(fullPath);
if (!Files.isRegularFile(src)) {
log.warn("SKIP (not file): {}", src);
return 0;
}
String fileName = src.getFileName().toString();
Path dst = tmp.resolve(type).resolve(part).resolve(fileName);
// 충돌 덮어쓰기
if (Files.exists(dst)) {
Files.delete(dst);
}
Files.createLink(dst, src);
return 1;
}
private String safe(String s) {
return (s == null || s.isBlank()) ? null : s.trim();
}
/**
* request 전체 폴더 link
*
* @param uid
* @param datasetUids
* @return
* @throws IOException
*/
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException { public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
log.info("========== buildTmpDatasetHardlink START =========="); log.info("========== buildTmpDatasetHardlink START ==========");
@@ -24,7 +117,7 @@ public class TmpDatasetService {
log.info("requestDir(raw)={}", requestDir); log.info("requestDir(raw)={}", requestDir);
Path BASE = toPath(requestDir); Path BASE = toPath(requestDir);
Path tmp = BASE.resolve("tmp").resolve(uid); Path tmp = Path.of(trainBaseDir, "tmp", uid);
log.info("BASE={}", BASE); log.info("BASE={}", BASE);
log.info("BASE exists? {}", Files.isDirectory(BASE)); log.info("BASE exists? {}", Files.isDirectory(BASE));
@@ -33,8 +126,8 @@ public class TmpDatasetService {
long noDir = 0, scannedDirs = 0, regularFiles = 0, hardlinksMade = 0; long noDir = 0, scannedDirs = 0, regularFiles = 0, hardlinksMade = 0;
// tmp 디렉토리 준비 // tmp 디렉토리 준비
for (String type : List.of("train", "val")) { for (String type : List.of("train", "val", "test")) {
for (String part : List.of("input1", "input2", "label")) { for (String part : List.of("input1", "input2", "label", "label-json")) {
Path dir = tmp.resolve(type).resolve(part); Path dir = tmp.resolve(type).resolve(part);
Files.createDirectories(dir); Files.createDirectories(dir);
log.info("createDirectories: {}", dir); log.info("createDirectories: {}", dir);
@@ -66,11 +159,10 @@ public class TmpDatasetService {
Path srcRoot = BASE.resolve(id); Path srcRoot = BASE.resolve(id);
log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot)); log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot));
for (String type : List.of("train", "val")) { for (String type : List.of("train", "val", "test")) {
for (String part : List.of("input1", "input2", "label")) { for (String part : List.of("input1", "input2", "label", "label-json")) {
Path srcDir = srcRoot.resolve(type).resolve(part); Path srcDir = srcRoot.resolve(type).resolve(part);
if (!Files.isDirectory(srcDir)) { if (!Files.isDirectory(srcDir)) {
log.warn("SKIP (not directory): {}", srcDir); log.warn("SKIP (not directory): {}", srcDir);
noDir++; noDir++;

View File

@@ -1,12 +1,13 @@
package com.kamco.cd.training.train.service; package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.TrainStatusType; import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.model.dto.ModelTrainMngDto; import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.service.TmpDatasetService;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent; import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.TrainRunRequest; import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
@@ -17,12 +18,15 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@Service @Service
@Log4j2
@RequiredArgsConstructor @RequiredArgsConstructor
@Transactional(readOnly = true) @Transactional(readOnly = true)
public class TrainJobService { public class TrainJobService {
@@ -33,6 +37,7 @@ public class TrainJobService {
private final ObjectMapper objectMapper; private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher; private final ApplicationEventPublisher eventPublisher;
private final TmpDatasetService tmpDatasetService; private final TmpDatasetService tmpDatasetService;
private final DataSetCountersService dataSetCounters;
// 학습 결과가 저장될 호스트 디렉토리 // 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}") @Value("${train.docker.responseDir}")
@@ -42,10 +47,18 @@ public class TrainJobService {
return modelTrainMngCoreService.findModelIdByUuid(uuid); return modelTrainMngCoreService.findModelIdByUuid(uuid);
} }
/** 실행 예약 (QUEUE 등록) */ /**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @return
*/
@Transactional @Transactional
public Long enqueue(Long modelId) { public Long enqueue(Long modelId) {
// 폴더 카운트
dataSetCounters.getCount(modelId);
// 마스터 존재 확인(없으면 예외) // 마스터 존재 확인(없으면 예외)
modelTrainMngCoreService.findModelById(modelId); modelTrainMngCoreService.findModelById(modelId);
@@ -131,15 +144,18 @@ public class TrainJobService {
modelTrainMngCoreService.markStopped(modelId); modelTrainMngCoreService.markStopped(modelId);
} }
/**
* 학습 이어하기
*
* @param modelId 모델 id
* @param mode NONE 새로 시작, REQUIRE 이어하기
* @return
*/
private Long createNextAttempt(Long modelId, ResumeMode mode) { private Long createNextAttempt(Long modelId, ResumeMode mode) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId); ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
if (TrainStatusType.IN_PROGRESS.getId().equals(master.getStatusCd())) { ModelTrainJobDto lastJob =
throw new IllegalStateException("이미 진행중입니다.");
}
var lastJob =
modelTrainJobCoreService modelTrainJobCoreService
.findLatestByModelId(modelId) .findLatestByModelId(modelId)
.orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다.")); .orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다."));
@@ -163,7 +179,7 @@ public class TrainJobService {
// 체크포인트 탐지해서 resumeFrom 세팅 // 체크포인트 탐지해서 resumeFrom 세팅
String resumeFrom = findResumeFromOrNull(nextParams); String resumeFrom = findResumeFromOrNull(nextParams);
if (resumeFrom == null) { if (resumeFrom == null) {
throw new IllegalStateException("이어하기 체크포인트가 없습니다."); throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "이어하기 체크포인트가 없습니다.");
} }
nextParams.put("resumeFrom", resumeFrom); nextParams.put("resumeFrom", resumeFrom);
nextParams.put("resume", true); nextParams.put("resume", true);
@@ -185,55 +201,106 @@ public class TrainJobService {
REQUIRE // 이어하기 REQUIRE // 이어하기
} }
/**
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
*
* @param paramsJson
* @return
*/
public String findResumeFromOrNull(Map<String, Object> paramsJson) { public String findResumeFromOrNull(Map<String, Object> paramsJson) {
if (paramsJson == null) return null; if (paramsJson == null) return null;
Object out = paramsJson.get("outputFolder"); Object out = paramsJson.get("outputFolder");
if (out == null) return null; if (out == null) return null;
String outputFolder = String.valueOf(out).trim(); // uuid String outputFolder = String.valueOf(out).trim();
if (outputFolder.isEmpty()) return null; if (outputFolder.isEmpty()) return null;
// 호스트 기준 경로
Path outDir = Paths.get(responseDir, outputFolder); Path outDir = Paths.get(responseDir, outputFolder);
log.info("resume outDir response path: {}", outDir);
Path last = outDir.resolve("last_checkpoint"); Path last = outDir.resolve("last_checkpoint");
log.info("resume last response path: {}", last);
if (!Files.isRegularFile(last)) return null; if (!Files.isRegularFile(last)) return null;
try { try {
String ckptFile = Files.readString(last).trim(); // epoch_10.pth // last_checkpoint 내용 그대로 읽기
if (ckptFile.isEmpty()) return null; String containerPath = Files.readString(last).trim();
log.info("resume containerPath: {}", containerPath);
Path ckptHost = outDir.resolve(ckptFile); if (containerPath.isEmpty()) return null;
if (!Files.isRegularFile(ckptHost)) return null;
// 컨테이너 경로 반환 // 호스트 경로로 변환해서 실제 파일 존재 확인
return "/checkpoints/" + outputFolder + "/" + ckptFile; String hostPathStr = containerPath.replace("/checkpoints", responseDir);
Path hostPath = Paths.get(hostPathStr);
log.info("resume hostPath: {}", hostPath);
if (!Files.isRegularFile(hostPath)) return null;
// 3컨테이너 경로 그대로 반환
return containerPath;
} catch (Exception e) { } catch (Exception e) {
log.error("resume error", e);
return null; return null;
} }
} }
/**
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
*
* @param modelUuid
* @return
*/
@Transactional
public UUID createTmpFile(UUID modelUuid) { public UUID createTmpFile(UUID modelUuid) {
UUID tmpUuid = UUID.randomUUID(); UUID tmpUuid = UUID.randomUUID();
String raw = tmpUuid.toString().toUpperCase().replace("-", ""); String raw = tmpUuid.toString().toUpperCase().replace("-", "");
// model id 가져오기
Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid); Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
// model 에 연결된 dataset id 가져오기
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId); List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds); List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
try { try {
// 데이터셋 심볼링크 생성 // 데이터셋 심볼링크 생성
String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids); // String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
// train path
List<ModelTrainLinkDto> trainList = modelTrainMngCoreService.findDatasetTrainPath(modelId);
// validation path
List<ModelTrainLinkDto> valList = modelTrainMngCoreService.findDatasetValPath(modelId);
// test path
List<ModelTrainLinkDto> testList = modelTrainMngCoreService.findDatasetTestPath(modelId);
// train 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "train", trainList);
// val 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "val", valList);
// test 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "test", testList);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq(); ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(pathUid); updateReq.setRequestPath(raw);
// 학습모델을 수정한다.
modelTrainMngCoreService.updateModelMaster(modelId, updateReq); modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); log.error(
} "createTmpFile failed. modelUuid={}, modelId={}, tmpRaw={}, datasetIdsSize={}, uidsSize={}",
modelUuid,
modelId,
raw,
(datasetIds == null ? null : datasetIds.size()),
(uids == null ? null : uids.size()),
e);
// 런타임 예외로 래핑하되, 메시지에 핵심 정보 포함
throw new CustomApiException(
"INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "임시 데이터셋 생성에 실패했습니다.");
}
return modelUuid; return modelUuid;
} }
} }

View File

@@ -11,24 +11,31 @@ import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult; import com.kamco.cd.training.train.dto.TrainRunResult;
import java.util.Map; import java.util.Map;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionPhase; import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener; import org.springframework.transaction.event.TransactionalEventListener;
/** job 실행 */
@Log4j2
@Component @Component
@RequiredArgsConstructor @RequiredArgsConstructor
public class TrainJobWorker { public class TrainJobWorker {
private final ModelTrainJobCoreService modelTrainJobCoreService; private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService; private final ModelTrainMngCoreService modelTrainMngCoreService;
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
private final ModelTestMetricsJobService modelTestMetricsJobService;
private final DockerTrainService dockerTrainService; private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper; private final ObjectMapper objectMapper;
@Async @Async("trainJobExecutor")
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT) @TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
public void handle(ModelTrainJobQueuedEvent event) { public void handle(ModelTrainJobQueuedEvent event) {
log.info("[JOB] thread={}, jobId={}", Thread.currentThread().getName(), event.getJobId());
Long jobId = event.getJobId(); Long jobId = event.getJobId();
ModelTrainJobDto job = ModelTrainJobDto job =
@@ -50,29 +57,45 @@ public class TrainJobWorker {
String containerName = String containerName =
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8); (isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
String type = isEval ? "TEST" : "TRAIN";
Integer totalEpoch = null; Integer totalEpoch = null;
if (params.containsKey("totalEpoch")) { if (params.containsKey("totalEpoch")) {
if (params.get("totalEpoch") != null) { if (params.get("totalEpoch") != null) {
totalEpoch = Integer.parseInt(params.get("totalEpoch").toString()); totalEpoch = Integer.parseInt(params.get("totalEpoch").toString());
} }
} }
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch); // 실행 시작 처리
modelTrainJobCoreService.markRunning(
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
log.info("[JOB] markRunning done jobId={}", jobId);
try { try {
TrainRunResult result; TrainRunResult result;
if (isEval) { if (isEval) {
// step2 진행중 처리
modelTrainMngCoreService.markStep2InProgress(modelId, jobId); modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
String uuid = String.valueOf(params.get("uuid")); String uuid = String.valueOf(params.get("uuid"));
int epoch = (int) params.get("epoch"); int epoch = (int) params.get("epoch");
String datasetFolder = String.valueOf(params.get("datasetFolder"));
String outputFolder = String.valueOf(params.get("outputFolder"));
EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null); EvalRunRequest evalReq = new EvalRunRequest();
result = dockerTrainService.runEvalSync(evalReq, containerName); evalReq.setUuid(uuid);
evalReq.setEpoch(epoch);
evalReq.setTimeoutSeconds(null);
evalReq.setDatasetFolder(datasetFolder);
evalReq.setOutputFolder(outputFolder);
log.info("[JOB] selected test epoch={}", epoch);
// 도커 실행 후 로그 수집
result = dockerTrainService.runEvalSync(containerName, evalReq);
} else { } else {
// step1 진행중 처리
modelTrainMngCoreService.markStep1InProgress(modelId, jobId); modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
TrainRunRequest trainReq = toTrainRunRequest(params); TrainRunRequest trainReq = toTrainRunRequest(params);
// 도커 실행 후 로그 수집
result = dockerTrainService.runTrainSync(trainReq, containerName); result = dockerTrainService.runTrainSync(trainReq, containerName);
} }
@@ -86,31 +109,43 @@ public class TrainJobWorker {
} }
if (result.getExitCode() == 0) { if (result.getExitCode() == 0) {
// 성공 처리
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode()); modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
if (isEval) { if (isEval) {
// step2 완료처리
modelTrainMngCoreService.markStep2Success(modelId); modelTrainMngCoreService.markStep2Success(modelId);
// 결과 csv 파일 정보 등록
modelTestMetricsJobService.findTestValidMetricCsvFiles();
} else { } else {
modelTrainMngCoreService.markStep1Success(modelId); modelTrainMngCoreService.markStep1Success(modelId);
// 결과 csv 파일 정보 등록
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
} }
} else { } else {
String failMsg = result.getStatus() + "\n" + result.getLogs();
// 실패 처리
modelTrainJobCoreService.markFailed( modelTrainJobCoreService.markFailed(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs()); jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
if (isEval) { if (isEval) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode()); modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
} else { } else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode()); modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
} }
} }
} catch (Exception e) { } catch (Exception e) {
modelTrainJobCoreService.markFailed(jobId, null, e.toString()); modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
if ("EVAL".equals(params.get("jobType"))) { if ("EVAL".equals(params.get("jobType"))) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage()); modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
} else { } else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, e.getMessage()); modelTrainMngCoreService.markError(modelId, e.getMessage());
} }
} }

View File

@@ -58,11 +58,15 @@ file:
dataset-dir: /home/kcomu/data/request/ dataset-dir: /home/kcomu/data/request/
dataset-tmp-dir: ${file.dataset-dir}tmp/ dataset-tmp-dir: ${file.dataset-dir}tmp/
pt-path: /home/kcomu/data/response/v6-cls-checkpoints/
pt-FileName: yolov8_6th-6m.pt
train: train:
docker: docker:
image: "kamco-cd-train:love_latest" image: kamco-cd-train:latest
requestDir: "/home/kcomu/data/request" requestDir: /home/kcomu/data/request
responseDir: "/home/kcomu/data/response" responseDir: /home/kcomu/data/response
containerPrefix: "kamco-cd-train" basePath: /home/kcomu/data
shmSize: "16g" containerPrefix: kamco-cd-train
shmSize: 16g
ipcHost: true ipcHost: true

View File

@@ -4,7 +4,7 @@ spring:
on-profile: prod on-profile: prod
jpa: jpa:
show-sql: false show-sql: true
hibernate: hibernate:
ddl-auto: validate ddl-auto: validate
properties: properties:
@@ -12,23 +12,36 @@ spring:
default_batch_fetch_size: 100 # ✅ 성능 - N+1 쿼리 방지 default_batch_fetch_size: 100 # ✅ 성능 - N+1 쿼리 방지
order_updates: true # ✅ 성능 - 업데이트 순서 정렬로 데드락 방지 order_updates: true # ✅ 성능 - 업데이트 순서 정렬로 데드락 방지
use_sql_comments: true # ⚠️ 선택 - SQL에 주석 추가 (디버깅용) use_sql_comments: true # ⚠️ 선택 - SQL에 주석 추가 (디버깅용)
format_sql: true # ⚠️ 선택 - SQL 포맷팅 (가독성)
datasource: datasource:
url: jdbc:postgresql://10.100.0.10:25432/temp url: jdbc:postgresql://127.0.01:15432/kamco_training_db
username: temp # url: jdbc:postgresql://localhost:15432/kamco_training_db
password: temp123! username: kamco_training_user
password: kamco_training_user_2025_!@#
hikari: hikari:
minimum-idle: 10 minimum-idle: 10
maximum-pool-size: 20 maximum-pool-size: 20
connection-timeout: 60000 # 60초 연결 타임아웃
idle-timeout: 300000 # 5분 유휴 타임아웃
max-lifetime: 1800000 # 30분 최대 수명
leak-detection-threshold: 60000 # 연결 누수 감지
transaction:
default-timeout: 300 # 5분 트랜잭션 타임아웃
jwt: jwt:
secret: "kamco_token_prod_dfc6446d-68fc-4eba-a2ff-c80a14a0bf3a" secret: "kamco_token_dev_dfc6446d-68fc-4eba-a2ff-c80a14a0bf3a"
access-token-validity-in-ms: 86400000 # 1일 access-token-validity-in-ms: 86400000 # 1일
refresh-token-validity-in-ms: 604800000 # 7일 refresh-token-validity-in-ms: 604800000 # 7일
token: token:
refresh-cookie-name: kamco # 개발용 쿠키 이름 refresh-cookie-name: kamco # 개발용 쿠키 이름
refresh-cookie-secure: true # 로컬 http 테스트면 false refresh-cookie-secure: false # 로컬 http 테스트면 false
springdoc:
swagger-ui:
persist-authorization: true # 스웨거 새로고침해도 토큰 유지, 로컬스토리지에 저장
member: member:
init_password: kamco1234! init_password: kamco1234!
@@ -44,11 +57,16 @@ file:
dataset-dir: /home/kcomu/data/request/ dataset-dir: /home/kcomu/data/request/
dataset-tmp-dir: ${file.dataset-dir}tmp/ dataset-tmp-dir: ${file.dataset-dir}tmp/
pt-path: /home/kcomu/data/response/v6-cls-checkpoints/
pt-FileName: yolov8_6th-6m.pt
train: train:
docker: docker:
image: "kamco-cd-train:love_latest" image: kamco-cd-train:latest
requestDir: "/home/kcomu/data/request" requestDir: /home/kcomu/data/request
responseDir: "/home/kcomu/data/response" responseDir: /home/kcomu/data/response
containerPrefix: "kamco-cd-train" basePath: /home/kcomu/data
shmSize: "16g" containerPrefix: kamco-cd-train
shmSize: 16g
ipcHost: true ipcHost: true

View File

@@ -51,8 +51,8 @@ logging:
level: level:
org: org:
springframework: springframework:
web: DEBUG web: INFO
security: DEBUG security: INFO
root: INFO root: INFO
# actuator # actuator
management: management:

View File

@@ -30396,6 +30396,8 @@ ALTER TABLE ONLY public.tb_menu
ADD CONSTRAINT fksw914diut87r7lfykekc7xm2a FOREIGN KEY (parent_menu_uid) REFERENCES public.tb_menu(menu_uid); ADD CONSTRAINT fksw914diut87r7lfykekc7xm2a FOREIGN KEY (parent_menu_uid) REFERENCES public.tb_menu(menu_uid);
-- Completed on 2025-12-26 16:11:11 KST -- Completed on 2025-12-26 16:11:11 KST
-- --
@@ -30404,3 +30406,5 @@ ALTER TABLE ONLY public.tb_menu
\unrestrict IYrUYfSgA4Fo2gubHcb84jDXfbBZEIiOZnyLtZgnMi641GaRQa5QDogarpTr7IG \unrestrict IYrUYfSgA4Fo2gubHcb84jDXfbBZEIiOZnyLtZgnMi641GaRQa5QDogarpTr7IG

View File

@@ -0,0 +1,87 @@
<!doctype html>
<html lang="ko">
<head>
<meta charset="utf-8" />
<title>학습데이터 ZIP 다운로드</title>
</head>
<body>
<h3>학습데이터 ZIP 다운로드</h3>
UUID:
<input id="uuid" value="95cb116c-380a-41c0-98d8-4d1142f15bbf" />
<br><br>
modelVer:
<input id="moderVer" value="G2.HPs_0001.95cb116c-380a-41c0-98d8-4d1142f15bbf" />
<br><br>
JWT Token:
<input id="token" style="width:600px;" placeholder="Bearer 토큰 붙여넣기" />
<br><br>
<button onclick="download()">다운로드</button>
<br><br>
<progress id="bar" value="0" max="100" style="width:400px;"></progress>
<div id="status"></div>
<script>
async function download() {
const uuid = document.getElementById("uuid").value.trim();
const moderVer = document.getElementById("moderVer").value.trim();
const token = document.getElementById("token").value.trim();
if (!uuid) {
alert("UUID 입력하세요");
return;
}
if (!token) {
alert("토큰 입력하세요");
return;
}
const url = `/api/models/download/${uuid}`;
const res = await fetch(url, {
headers: {
"Authorization": token.startsWith("Bearer ")
? token
: `Bearer ${token}`,
"kamco-download-uuid": uuid
}
});
if (!res.ok) {
document.getElementById("status").innerText =
"실패: " + res.status;
return;
}
const total = parseInt(res.headers.get("Content-Length") || "0", 10);
const reader = res.body.getReader();
const chunks = [];
let received = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
received += value.length;
if (total) {
document.getElementById("bar").value =
(received / total) * 100;
}
}
const blob = new Blob(chunks);
const a = document.createElement("a");
a.href = URL.createObjectURL(blob);
a.download = moderVer + ".zip";
a.click();
document.getElementById("status").innerText = "완료 ✅";
}
</script>
</body>
</html>