208 Commits

Author SHA1 Message Date
265813e6f7 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-03-10 14:19:40 +09:00
8190a6e9c8 unzip 2026-03-10 14:19:22 +09:00
e9f8bb37fa spotless 적용 2026-03-03 23:06:45 +09:00
f3c822587f spotless 적용 2026-03-03 23:02:53 +09:00
335f0dbb9b spotless 적용 2026-03-03 23:01:22 +09:00
69eaba1a83 하드링크 수정 2026-03-03 22:51:10 +09:00
365ad81cad 리커버리 삭제 2026-02-28 01:24:34 +09:00
9dfa54fbf9 리커버리 추가 2026-02-28 01:01:38 +09:00
12f6bb7154 하드링크 수정 2026-02-27 23:31:04 +09:00
aa3af4e9d0 하드링크 로그 추가 2026-02-27 23:12:00 +09:00
7ca37bf1e4 하드링크 로그 추가 2026-02-27 22:51:27 +09:00
901dde066d 하이파라미터 상세조회 수정 2026-02-24 16:54:58 +09:00
cb0a38274a val_interval 기본값 1로 수정 2026-02-24 16:24:16 +09:00
b8194df9ae 학습 실패여부 확인 기능 추가 2026-02-24 15:43:35 +09:00
7c5f07683e 학습 실패여부 확인 기능 추가 2026-02-24 15:10:48 +09:00
159fb281d4 최근 사용일시 업데이트 2026-02-23 15:40:24 +09:00
97192ff811 최근 사용일시 업데이트 2026-02-23 15:39:25 +09:00
4f3fb675be 하이퍼 파라미터 사용회수 카운트 기능 추가 및 조회 수정 2026-02-23 15:37:28 +09:00
e6caea05b3 하이퍼 파라미터 사용회수 카운트 기능 추가 및 조회 수정 2026-02-23 15:19:40 +09:00
fd63824edc 하이퍼 파라미터 미사용 컬럼제거, 사용횟수 컬럼 추가 2026-02-23 14:30:29 +09:00
8a44df26b8 하이퍼 파라미터 상세조회 삭제여부 조건 제거 2026-02-23 14:17:34 +09:00
cb97c5e59e 하이퍼 파라미터 dto 주석 수정 2026-02-23 14:13:25 +09:00
8f75b16dc6 학습실행 주석 추가 2026-02-23 12:30:54 +09:00
c2978e41c2 전이학습 상세 수정 2026-02-20 18:34:32 +09:00
07429dbe8e 전이학습 상세 수정 2026-02-20 18:22:19 +09:00
83859bb9fe 전이학습 상세 - before dataset 추가 2026-02-20 16:05:29 +09:00
564a99448c best epoch 파일 선택 수정 2026-02-20 15:41:34 +09:00
38ae6e5575 best epoch 파일 선택 수정 2026-02-20 15:31:33 +09:00
40fe98ae0c best epoch 파일 선택 수정 2026-02-20 15:15:12 +09:00
255ff10a56 Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-20 14:30:42 +09:00
f674f73330 중복 수정 제거 2026-02-20 14:30:34 +09:00
db2bc32e7d test metrics insert 로직 수정 2026-02-20 14:24:23 +09:00
37786a1e44 선택한 테스트 에폭 로그 추가 2026-02-20 14:13:49 +09:00
901ea83fb7 test 선택한 에폭 log 확인 추가 2026-02-20 14:13:22 +09:00
832e1b5681 tmp 하드링크 수정 2026-02-20 13:36:48 +09:00
4f16355cda tmp 하드링크 수정 2026-02-20 12:29:57 +09:00
df46a8f79f Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-20 12:21:50 +09:00
fcd48831c5 tmp 하드링크 수정 2026-02-20 12:21:41 +09:00
62c9d73b94 test json 수정 2026-02-20 12:20:15 +09:00
68c0e634c5 ing-cnt 로직에 step2도 추가, transactional 2026-02-20 12:05:20 +09:00
ad421e3c74 비밀번호 변경 security 로직 수정 2026-02-20 11:36:21 +09:00
46db1512a6 test 실행 시 회차별 데이터 적재하기 2026-02-19 18:18:12 +09:00
2034a8fcb2 LogErrorLevel -> CodeExpose 추가 2026-02-19 17:35:15 +09:00
bf212842d8 모델학습관리 > 모델별 진행 상황 API 추가 2026-02-19 17:17:11 +09:00
d2ca94ea55 모델학습관리 > 목록 API 메모,작성자 추가로 인한 수정 2026-02-19 15:34:18 +09:00
5ddf6dfeeb 모델학습 2단계 패키징 시작,종료일시,상태 로직 추가 2026-02-19 14:43:14 +09:00
5e13c0b396 공통코드 common-code 로 prefix 변경 2026-02-19 11:38:56 +09:00
435f60dcac 로그관리 로직 커밋 2026-02-19 11:13:40 +09:00
5f5eabca19 압축해제 시, 동일 폴더가 있으면 삭제 후 재업로드 2026-02-18 16:36:46 +09:00
413631840f 학습데이터 다운로드 security 제외하기 2026-02-18 16:29:28 +09:00
c7f63d1ad1 압축 해제한 폴더의 갯수 맞는지 log 찍기 + 갯수 맞지 않으면 exception 리턴 2026-02-18 16:22:32 +09:00
7529d23488 업로드 시 uid로 중복체크 -> 삭제인 row는 제외하기 2026-02-18 15:40:14 +09:00
cb3e51d712 업로드 시 exception 메세지 처리, 에폭 10 이상으로 실행되게 수정 2026-02-18 15:28:29 +09:00
99a4597b5f train 결과 +1 했던 거 제거하기 2026-02-18 14:50:53 +09:00
d9da0d4610 1단계 실행 시, 시작시간 update 추가 2026-02-18 13:05:59 +09:00
22c481556c 하이퍼 파라미터 수정 2026-02-13 15:00:30 +09:00
0798b352c7 하이퍼 파라미터 수정 2026-02-13 14:46:59 +09:00
5b074bdb81 하이퍼 파라미터 수정 2026-02-13 14:42:39 +09:00
28919345c2 하이퍼 파라미터 수정 2026-02-13 14:37:44 +09:00
aa0552aaa7 하이퍼 파라미터 수정 2026-02-13 14:30:45 +09:00
5d0aca14a6 사용가능 용량 API 수정 2026-02-13 14:19:09 +09:00
af8d59ddfa 이어하기 수정 2026-02-13 14:08:35 +09:00
4f24e09c57 이어하기 수정 2026-02-13 14:04:33 +09:00
4da477706f 이어하기 수정 2026-02-13 13:57:01 +09:00
a070566048 이어하기 로그 수정 2026-02-13 13:26:54 +09:00
a5b3ae613f 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 13:23:53 +09:00
979af088be 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 13:17:27 +09:00
e5a1cab36b Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-13 12:53:04 +09:00
bb67996742 text 2026-02-13 12:53:00 +09:00
1981d6d1ce Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-13 12:53:00 +09:00
47f4ffd4db 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 12:52:11 +09:00
195856b846 flush 추가해보기 2026-02-13 12:46:54 +09:00
124da48e51 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 12:38:55 +09:00
02724e9508 주석한거 원복 2026-02-13 12:24:50 +09:00
7ed91ccab9 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-13 12:23:12 +09:00
a7c13b985d responsePath 셋팅 삭제 2026-02-13 12:23:01 +09:00
352a28b87f 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 12:20:48 +09:00
bf8515163c 주석 처리 2026-02-13 12:10:08 +09:00
2691f6ce16 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 12:00:42 +09:00
7e5aa5e713 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 11:57:48 +09:00
060a815e1c 트랜젝션처리 임시폴더 uid업데이트 2026-02-13 11:55:35 +09:00
1eb4d04779 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-13 10:50:35 +09:00
f30c0c6d45 다운로드 시 Access-Control-Expose-Headers 추가 2026-02-13 10:50:28 +09:00
12994aab60 파일 count 기능 추가 2026-02-13 10:44:49 +09:00
11d3afe295 파일 count 기능 추가 2026-02-13 10:38:24 +09:00
1e62a8b097 학습실행 step1 할 때 best epoch 업데이트 2026-02-13 10:15:04 +09:00
26a4623aa8 학습데이터 목록 파일 단위 MB 나오게 하기 2026-02-13 09:42:36 +09:00
ce6e4f5aea tmp 파일 링크 수정 2026-02-13 09:10:45 +09:00
c2215836c0 tmp 파일 링크 수정 2026-02-13 08:44:17 +09:00
8c19c996f7 tmp 파일 링크 수정 2026-02-13 08:33:36 +09:00
b5ce3ab1fb 이어하기 수정 2026-02-12 23:01:56 +09:00
e1ceb769dd 학습데이터 다운로드 파일 정보 API 추가 2026-02-12 22:47:13 +09:00
4219b88fb3 학습데이터 다운로드 API 추가 2026-02-12 22:25:55 +09:00
4f94c99b64 이어하기 수정 2026-02-12 22:09:54 +09:00
d42e1afbd4 스케줄러 api 수동 호출 2026-02-12 21:51:53 +09:00
b3b8016673 csv 결과 받아오는 것 변경 2026-02-12 21:45:38 +09:00
79e8259f28 파라미터 변경 2026-02-12 21:30:03 +09:00
032c82c2f0 file 경로 넣기 2026-02-12 21:17:10 +09:00
6204a6e5fa Merge remote-tracking branch 'origin/develop' into feat/training_260202
# Conflicts:
#	src/main/resources/application-prod.yml
2026-02-12 21:15:21 +09:00
4d9c9a86b4 패키징 zip파일 만들기 커밋 2026-02-12 21:09:40 +09:00
83204abfe9 Merge pull request '성공시 csv 파일 테이블에 저장 연결' (#74) from feat/training_260202 into develop
Reviewed-on: #74
2026-02-12 21:01:48 +09:00
5b682c1386 성공시 csv 파일 테이블에 저장 연결 2026-02-12 21:01:24 +09:00
452494d44d Merge pull request '테스트 실행 경로 수정' (#73) from feat/training_260202 into develop
Reviewed-on: #73
2026-02-12 20:49:30 +09:00
8ada26448b 테스트 실행 경로 수정 2026-02-12 20:49:14 +09:00
e442f105bc Merge pull request '도커명 변경' (#72) from feat/training_260202 into develop
Reviewed-on: #72
2026-02-12 20:37:00 +09:00
5e0a771848 도커명 변경 2026-02-12 20:36:38 +09:00
b4c2685059 Merge pull request '도커 설정 추가' (#71) from feat/training_260202 into develop
Reviewed-on: #71
2026-02-12 20:25:16 +09:00
e238f3ca88 도커 설정 추가 2026-02-12 20:24:57 +09:00
97b06eb3b3 Merge pull request '임시파일생성 경로 수정' (#70) from feat/training_260202 into develop
Reviewed-on: #70
2026-02-12 20:03:32 +09:00
ad32ca18ca 임시파일생성 경로 수정 2026-02-12 20:03:03 +09:00
98a1283ebe Merge pull request '임시파일생성 경로 수정' (#69) from feat/training_260202 into develop
Reviewed-on: #69
2026-02-12 19:36:06 +09:00
a10fccaae3 임시파일생성 경로 수정 2026-02-12 19:35:47 +09:00
c3c9191d9d Merge pull request 'hyperparam_with_modeltype' (#68) from feat/dean/hyperparam_with_modelType-bug into develop
Reviewed-on: #68
2026-02-12 19:30:29 +09:00
9fd5a15a72 hyperparam_with_modeltype 2026-02-12 19:30:08 +09:00
12f9de7367 hyperparam_with_modeltype 2026-02-12 19:16:24 +09:00
5455da1e96 hyperparam_with_modeltype 2026-02-12 19:16:13 +09:00
9e803661cd Merge pull request 'feat/training_260202' (#67) from feat/training_260202 into develop
Reviewed-on: #67
2026-02-12 19:14:39 +09:00
b0cf9e77ec Merge branch 'develop' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into develop 2026-02-12 19:14:10 +09:00
c92426aefc Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 19:14:09 +09:00
d5b2b8ecec hyperparam_with_modeltype 2026-02-12 19:14:01 +09:00
6185a18a7c 모델목록 검색 조건 상태값 변경 2026-02-12 19:13:56 +09:00
49d3e37458 Merge pull request '임시파일생성 경로 수정' (#66) from feat/training_260202 into develop
Reviewed-on: #66
2026-02-12 19:12:37 +09:00
1fb10830b9 임시파일생성 경로 수정 2026-02-12 19:11:51 +09:00
d7766edd24 Merge pull request 'return 형식 수정' (#65) from feat/training_260202 into develop
Reviewed-on: #65
2026-02-12 18:59:37 +09:00
0bc4453c9c hyperparam_with_modeltype 2026-02-12 18:56:32 +09:00
ae0d30e5da return 형식 수정 2026-02-12 18:55:42 +09:00
37d776dd2c Merge pull request 'hyperparam_with_modeltype' (#64) from feat/dean/hyperparam_with_modelType into develop
Reviewed-on: #64
2026-02-12 18:50:32 +09:00
0c34ea7dcb hyperparam_with_modeltype 2026-02-12 18:48:14 +09:00
3106d36431 Merge pull request '업로드 시 같은 uid로 업로드하지 못하게 조건 추가' (#63) from feat/training_260202 into develop
Reviewed-on: #63
2026-02-12 18:44:49 +09:00
ed48f697a4 업로드 시 같은 uid로 업로드하지 못하게 조건 추가 2026-02-12 18:44:04 +09:00
da92b28d97 Merge pull request '임시파일생성 소프트링크에서 하드링크로 변경' (#62) from feat/training_260202 into develop
Reviewed-on: #62
2026-02-12 18:20:30 +09:00
6c865d26fd 임시파일생성 소프트링크에서 하드링크로 변경 2026-02-12 18:18:44 +09:00
e3f00876f1 Merge pull request '문제되는 하이퍼파라미터 주석처리' (#61) from feat/training_260202 into develop
Reviewed-on: #61
2026-02-12 17:53:11 +09:00
16e156b5b4 문제되는 하이퍼파라미터 주석처리 2026-02-12 17:52:42 +09:00
60962bbc75 Merge pull request '학습실행 mount 경로 수정' (#60) from feat/training_260202 into develop
Reviewed-on: #60
2026-02-12 17:44:15 +09:00
6a939118ff 임시폴더생성 api 추가 2026-02-12 17:43:41 +09:00
64d37dcc08 Merge pull request '임시폴더생성 api 추가' (#59) from feat/training_260202 into develop
Reviewed-on: #59
2026-02-12 17:23:53 +09:00
0c0ae16c2b 임시폴더생성 api 추가 2026-02-12 17:23:34 +09:00
a2490f30e6 Merge pull request '임시폴더생성 api 수정' (#58) from feat/training_260202 into develop
Reviewed-on: #58
2026-02-12 17:14:52 +09:00
953f95aed6 임시폴더생성 api 추가 2026-02-12 17:14:26 +09:00
bd04e1f4e8 Merge pull request '임시폴더생성 api 추가' (#57) from feat/training_260202 into develop
Reviewed-on: #57
2026-02-12 17:03:39 +09:00
85633c8bab 임시폴더생성 api 추가 2026-02-12 17:03:21 +09:00
5fc15937c0 Merge pull request 'feat/training_260202' (#56) from feat/training_260202 into develop
Reviewed-on: #56
2026-02-12 17:00:08 +09:00
8b3940b446 Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-12 16:59:44 +09:00
201cfefb6b 임시폴더생성 api 추가 2026-02-12 16:59:39 +09:00
9958b0999a csv 읽는 경로 수정하기, 변수명 수정 2026-02-12 16:58:28 +09:00
3547c28361 Merge pull request 'feat/training_260202' (#55) from feat/training_260202 into develop
Reviewed-on: #55
2026-02-12 16:56:23 +09:00
6c70bfed18 Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202
# Conflicts:
#	src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java
2026-02-12 16:55:52 +09:00
95a75e63f4 임시폴더생성 api 추가 2026-02-12 16:55:10 +09:00
2a1dbee290 Merge pull request '모델학습 1단계 실행중인 것이 있는지 count API' (#54) from feat/training_260202 into develop
Reviewed-on: #54
2026-02-12 16:51:09 +09:00
384a321bf3 모델학습 1단계 실행중인 것이 있는지 count API 2026-02-12 16:50:40 +09:00
f4e97d389b Merge pull request 'file 확인 API 수정' (#53) from feat/training_260202 into develop
Reviewed-on: #53
2026-02-12 16:42:20 +09:00
590810ff0a file 확인 API 수정 2026-02-12 16:41:40 +09:00
a01c872982 Merge pull request 'feat/training_260202' (#52) from feat/training_260202 into develop
Reviewed-on: #52
2026-02-12 16:15:11 +09:00
905a245070 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 16:14:45 +09:00
860ce35a8f docker mount 경로 추가 2026-02-12 16:14:19 +09:00
7f3f5dca40 Merge pull request 'feat/training_260202' (#51) from feat/training_260202 into develop
Reviewed-on: #51
2026-02-12 16:13:19 +09:00
4a0a4e35ed 학습 실행 수정 2026-02-12 16:12:58 +09:00
ae055dca1e 모델등록 수정 2026-02-12 16:01:14 +09:00
26e8e1492f Merge pull request 'feat/training_260202' (#50) from feat/training_260202 into develop
Reviewed-on: #50
2026-02-12 15:52:09 +09:00
8fa722011c 모델등록 수정 2026-02-12 15:51:54 +09:00
17d47d6200 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 15:47:10 +09:00
e178f58fe2 chunk save log 추가 2026-02-12 15:47:06 +09:00
cd0cf5726d Merge pull request 'feat/training_260202' (#49) from feat/training_260202 into develop
Reviewed-on: #49
2026-02-12 15:44:11 +09:00
8e4bea53da 모델등록 수정 2026-02-12 15:43:52 +09:00
7a22d8ba73 containerName 생성 변경 2026-02-12 15:39:12 +09:00
2df4a7a80b csv 파일 읽는 경로 읽어서 수정, train은 epoch + 1 해서 저장 2026-02-12 15:24:30 +09:00
b451f697bc 모델 마스터 테이블 request,response 경로 추가 2026-02-12 14:59:35 +09:00
7e9c867f34 Merge pull request '모델 등록할 때 step1State를 READY로 업데이트' (#48) from feat/training_260202 into develop
Reviewed-on: #48
2026-02-12 14:35:52 +09:00
130e85f8a1 모델 등록할 때 step1State를 READY로 업데이트 2026-02-12 14:35:17 +09:00
9e713cb49d Merge pull request '업로드 로직 재수정' (#47) from feat/training_260202 into develop
Reviewed-on: #47
2026-02-12 14:21:57 +09:00
51dfa97900 업로드 로직 재수정 2026-02-12 14:21:08 +09:00
87c6b599b4 Merge pull request 'feat/training_260202' (#46) from feat/training_260202 into develop
Reviewed-on: #46
2026-02-12 12:10:04 +09:00
f50855a822 Merge branch 'feat/training_260202' of https://kamco.git.gs.dabeeo.com/MVPTeam/kamco-train-api into feat/training_260202 2026-02-12 12:08:04 +09:00
8d416317a8 베스트 에폭 API, 2단계 실행 시 best epoch 업데이트 2026-02-12 12:07:44 +09:00
22aa071476 Merge pull request 'feat/training_260202' (#45) from feat/training_260202 into develop
Reviewed-on: #45
2026-02-12 12:06:04 +09:00
a83bd09f8f containerName 생성 변경 2026-02-12 12:05:30 +09:00
96035f864a containerName 생성 변경 2026-02-12 11:42:38 +09:00
fd7dfd7e7f containerName 생성 변경 2026-02-12 11:10:28 +09:00
190b93bee8 실행 오류 수정 2026-02-12 10:58:51 +09:00
c5f19cc961 실행 오류 수정 2026-02-12 10:58:32 +09:00
c56c0ca605 실행 오류 수정 2026-02-12 10:58:26 +09:00
c6e721aa37 실행 오류 수정 2026-02-12 10:58:12 +09:00
6572e17f00 실행 오류 수정 2026-02-12 10:51:15 +09:00
be6365807c Merge pull request '실행 오류 수정' (#43) from feat/training_260202 into develop
Reviewed-on: #43
2026-02-12 10:20:05 +09:00
d2fff7dfde 실행 오류 수정 2026-02-12 10:19:44 +09:00
f66bc22c95 Merge pull request '실행 오류 수정' (#42) from feat/training_260202 into develop
Reviewed-on: #42
2026-02-12 10:14:54 +09:00
3367d0e7be 실행 오류 수정 2026-02-12 10:14:32 +09:00
352ec6ccb0 Merge pull request 'feat/training_260202' (#41) from feat/training_260202 into develop
Reviewed-on: #41
2026-02-12 09:53:02 +09:00
6a989255a3 모델별 데이터셋 목록 - G2,G3 dataTypeName 추가 2026-02-12 09:52:24 +09:00
878b21573f 테스트 실행 추가 2026-02-11 22:00:35 +09:00
0602db1436 Merge pull request '테스트 실행 추가' (#40) from feat/training_260202 into develop
Reviewed-on: #40
2026-02-11 21:58:58 +09:00
2f8bd1f98c 테스트 실행 추가 2026-02-11 21:58:25 +09:00
75231ccbba Merge pull request '추론 실행 추가' (#39) from feat/training_260202 into develop
Reviewed-on: #39
2026-02-11 20:22:01 +09:00
1249a80da5 추론 실행 추가 2026-02-11 20:21:25 +09:00
00c78eb42f Merge pull request '성능정보 그래프 데이터 API 추가' (#38) from feat/training_260202 into develop
Reviewed-on: #38
2026-02-11 19:52:23 +09:00
35767adba1 성능정보 그래프 데이터 API 추가 2026-02-11 19:52:00 +09:00
47a2a159ef Merge pull request 'test metrics 스케줄 추가' (#37) from feat/training_260202 into develop
Reviewed-on: #37
2026-02-11 19:10:37 +09:00
95548223cd test metrics 스케줄 추가 2026-02-11 19:09:58 +09:00
2debdc5312 Merge pull request 'feat/training_260202' (#36) from feat/training_260202 into develop
Reviewed-on: #36
2026-02-11 18:51:01 +09:00
207cc47f1b 스케줄 주석 2026-02-11 18:50:43 +09:00
b6338bce8e 테이블 구조 변경 2026-02-11 18:49:59 +09:00
2cfa2adcf5 tb_model_master 컬럼 추가 2026-02-11 17:21:48 +09:00
d7e19abfc9 uploadRate 로직 수정 2026-02-11 17:06:02 +09:00
c843703ee7 Merge pull request 'file 가져오기 86 호출하는 거로 추가' (#35) from feat/training_260202 into develop
Reviewed-on: #35
2026-02-11 16:53:25 +09:00
133ea6b1ba file 가져오기 86 호출하는 거로 추가 2026-02-11 16:49:48 +09:00
0df977ae81 Merge pull request '업로드 로직 86으로 수행하기 수정' (#34) from feat/training_260202 into develop
Reviewed-on: #34
2026-02-11 16:33:03 +09:00
3e39006822 업로드 로직 86으로 수행하기 수정 2026-02-11 16:32:40 +09:00
105 changed files with 7336 additions and 297 deletions

View File

@@ -1,6 +1,11 @@
# Stage 1: Build stage (gradle build는 Jenkins에서 이미 수행)
FROM eclipse-temurin:21-jre-jammy
# docker CLI 설치 (컨테이너에서 호스트 Docker 제어용) 260212 추가
RUN apt-get update && \
apt-get install -y --no-install-recommends docker.io ca-certificates && \
rm -rf /var/lib/apt/lists/*
# 작업 디렉토리 설정
WORKDIR /app

View File

@@ -3,6 +3,7 @@ plugins {
id 'org.springframework.boot' version '3.5.7'
id 'io.spring.dependency-management' version '1.1.7'
id 'com.diffplug.spotless' version '6.25.0'
id 'idea'
}
group = 'com.kamco.cd'
@@ -21,11 +22,23 @@ configurations {
}
}
// QueryDSL 생성된 소스 디렉토리 정의
def generatedSourcesDir = file("$buildDir/generated/sources/annotationProcessor/java/main")
repositories {
mavenCentral()
maven { url "https://repo.osgeo.org/repository/release/" }
}
// Gradle이 생성된 소스를 컴파일 경로에 포함하도록 설정
sourceSets {
main {
java {
srcDirs += generatedSourcesDir
}
}
}
dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
implementation 'org.springframework.boot:spring-boot-starter-web'
@@ -84,7 +97,22 @@ dependencies {
implementation 'org.reflections:reflections:0.10.2'
implementation 'com.jcraft:jsch:0.1.55'
implementation 'org.apache.commons:commons-csv:1.10.0'
}
// IntelliJ가 생성된 소스를 인식하도록 설정
idea {
module {
// 소스 디렉토리로 인식
sourceDirs += generatedSourcesDir
// Generated Sources Root로 마킹 (IntelliJ에서 특별 처리)
generatedSourceDirs += generatedSourcesDir
// 소스 및 Javadoc 다운로드
downloadJavadoc = true
downloadSources = true
}
}
configurations.configureEach {
@@ -95,6 +123,21 @@ tasks.named('test') {
useJUnitPlatform()
}
// 컴파일 전 생성된 소스 디렉토리 생성 보장
tasks.named('compileJava') {
doFirst {
generatedSourcesDir.mkdirs()
}
}
// 생성된 소스 정리 태스크
tasks.register('cleanGeneratedSources', Delete) {
delete generatedSourcesDir
}
tasks.named('clean') {
dependsOn 'cleanGeneratedSources'
}
bootJar {
archiveFileName = 'ROOT.jar'

View File

@@ -14,6 +14,8 @@ services:
- /mnt/nfs_share/images:/app/original-images
- /mnt/nfs_share/model_output:/app/model-outputs
- /mnt/nfs_share/train_dataset:/app/train-dataset
- /home/kcomu/data:/home/kcomu/data
- /var/run/docker.sock:/var/run/docker.sock
networks:
- kamco-cds
restart: unless-stopped

View File

@@ -23,7 +23,8 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
private final UserDetailsService userDetailsService;
private static final AntPathMatcher PATH_MATCHER = new AntPathMatcher();
private static final String[] EXCLUDE_PATHS = {
"/api/auth/signin", "/api/auth/refresh", "/api/auth/logout", "/api/members/*/password"
// "/api/auth/signin", "/api/auth/refresh", "/api/auth/logout", "/api/members/*/password"
"/api/auth/signin", "/api/auth/refresh", "/api/auth/logout"
};
@Override

View File

@@ -20,7 +20,7 @@ import org.springframework.web.bind.annotation.*;
@Tag(name = "공통코드 관리", description = "공통코드 관리 API")
@RestController
@RequiredArgsConstructor
@RequestMapping("/api/code")
@RequestMapping("/api/common-code")
public class CommonCodeApiController {
private final CommonCodeService commonCodeService;

View File

@@ -0,0 +1,48 @@
package com.kamco.cd.training.common.download;
import com.kamco.cd.training.common.download.dto.DownloadSpec;
import com.kamco.cd.training.common.utils.UserUtil;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
@Service
@RequiredArgsConstructor
public class DownloadExecutor {
private final UserUtil userUtil;
public ResponseEntity<StreamingResponseBody> stream(DownloadSpec spec) throws IOException {
if (!Files.isReadable(spec.filePath())) {
return ResponseEntity.notFound().build();
}
StreamingResponseBody body =
os -> {
try (InputStream in = Files.newInputStream(spec.filePath())) {
in.transferTo(os);
os.flush();
} catch (Exception e) {
// 고용량은 중간 끊김 흔하니까 throw 금지
}
};
String fileName =
spec.downloadName() != null
? spec.downloadName()
: spec.filePath().getFileName().toString();
return ResponseEntity.ok()
.contentType(
spec.contentType() != null ? spec.contentType() : MediaType.APPLICATION_OCTET_STREAM)
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"")
.body(body);
}
}

View File

@@ -0,0 +1,19 @@
package com.kamco.cd.training.common.download;
import org.springframework.util.AntPathMatcher;
public final class DownloadPaths {
private DownloadPaths() {}
public static final String[] PATTERNS = {
"/api/inference/download/**", "/api/training-data/stage/download/**"
};
public static boolean matches(String uri) {
AntPathMatcher m = new AntPathMatcher();
for (String p : PATTERNS) {
if (m.match(p, uri)) return true;
}
return false;
}
}

View File

@@ -0,0 +1,81 @@
package com.kamco.cd.training.common.download;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.ResourceRegion;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRange;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
@Component
public class RangeDownloadResponder {
public ResponseEntity<?> buildZipResponse(
Path filePath, String downloadFileName, HttpServletRequest request) throws IOException {
if (!Files.isRegularFile(filePath)) {
return ResponseEntity.notFound().build();
}
long totalSize = Files.size(filePath);
Resource resource = new FileSystemResource(filePath);
String disposition = "attachment; filename=\"" + downloadFileName + "\"";
String rangeHeader = request.getHeader(HttpHeaders.RANGE);
// 🔥 공통 헤더 (여기 고정)
ResponseEntity.BodyBuilder base =
ResponseEntity.ok()
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.header(HttpHeaders.CONTENT_DISPOSITION, disposition)
.header(HttpHeaders.ACCEPT_RANGES, "bytes")
.header("Access-Control-Expose-Headers", "Content-Disposition")
.header("X-Accel-Buffering", "no");
if (rangeHeader == null || rangeHeader.isBlank()) {
return base.contentLength(totalSize).body(resource);
}
List<HttpRange> ranges;
try {
ranges = HttpRange.parseRanges(rangeHeader);
} catch (IllegalArgumentException ex) {
return ResponseEntity.status(416)
.header(HttpHeaders.CONTENT_RANGE, "bytes */" + totalSize)
.header("X-Accel-Buffering", "no")
.build();
}
HttpRange range = ranges.get(0);
long start = range.getRangeStart(totalSize);
long end = range.getRangeEnd(totalSize);
if (start >= totalSize) {
return ResponseEntity.status(416)
.header(HttpHeaders.CONTENT_RANGE, "bytes */" + totalSize)
.header("X-Accel-Buffering", "no")
.build();
}
long regionLength = end - start + 1;
ResourceRegion region = new ResourceRegion(resource, start, regionLength);
return ResponseEntity.status(206)
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.header(HttpHeaders.CONTENT_DISPOSITION, disposition)
.header(HttpHeaders.ACCEPT_RANGES, "bytes")
.header("Access-Control-Expose-Headers", "Content-Disposition")
.header("X-Accel-Buffering", "no")
.header(HttpHeaders.CONTENT_RANGE, "bytes " + start + "-" + end + "/" + totalSize)
.contentLength(regionLength)
.body(region);
}
}

View File

@@ -0,0 +1,12 @@
package com.kamco.cd.training.common.download.dto;
import java.nio.file.Path;
import java.util.UUID;
import org.springframework.http.MediaType;
public record DownloadSpec(
UUID uuid, // 다운로드 식별(로그/정책용)
Path filePath, // 실제 파일 경로
String downloadName, // 사용자에게 보일 파일명
MediaType contentType // 보통 OCTET_STREAM
) {}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.common.dto;
import com.kamco.cd.training.common.enums.ModelType;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
@@ -11,9 +12,14 @@ import lombok.Setter;
@AllArgsConstructor
@NoArgsConstructor
public class HyperParam {
@Schema(description = "모델", example = "G1")
private ModelType model; // G1, G2, G3
// -------------------------
// Important
// -------------------------
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone

View File

@@ -0,0 +1,27 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public enum JobStatusType implements EnumType {
QUEUED("대기중"),
RUNNING("실행중"),
SUCCESS("성공"),
FAILED("실패"),
CANCELED("취소");
private final String desc;
@Override
public String getId() {
return name();
}
@Override
public String getText() {
return desc;
}
}

View File

@@ -0,0 +1,24 @@
package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public enum JobType implements EnumType {
TRAIN("학습"),
TEST("테스트");
private final String desc;
@Override
public String getId() {
return name();
}
@Override
public String getText() {
return desc;
}
}

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.common.enums;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import java.util.Arrays;
import lombok.AllArgsConstructor;
import lombok.Getter;
@@ -15,6 +16,13 @@ public enum ModelType implements EnumType {
private String desc;
public static ModelType getValueData(String modelNo) {
return Arrays.stream(ModelType.values())
.filter(m -> m.getId().equals(modelNo))
.findFirst()
.orElse(G1);
}
@Override
public String getId() {
return name();

View File

@@ -2,9 +2,11 @@ package com.kamco.cd.training.common.utils;
import static java.lang.String.CASE_INSENSITIVE_ORDER;
import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.config.api.ApiResponseDto.ApiResponseCode;
import io.swagger.v3.oas.annotations.media.Schema;
import java.io.BufferedReader;
import java.io.File;
@@ -14,6 +16,7 @@ import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
@@ -38,6 +41,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.gce.geotiff.GeoTiffReader;
import org.springframework.http.HttpStatus;
import org.springframework.util.FileSystemUtils;
import org.springframework.web.multipart.MultipartFile;
@@ -505,11 +509,15 @@ public class FIleChecker {
try {
File dir = new File(targetPath);
log.info("targetPath={}", targetPath);
log.info("absolute targetPath={}", dir.getAbsolutePath());
if (!dir.exists()) {
dir.mkdirs();
}
File dest = new File(dir, String.valueOf(chunkIndex));
log.info("real save path = {}", dest.getAbsolutePath());
log.info("chunkIndex={}, uploadSize={}", chunkIndex, mfile.getSize());
log.info("savedSize={}", dest.length());
@@ -521,6 +529,9 @@ public class FIleChecker {
log.info("after delete={}", dest.length());
mfile.transferTo(dest);
log.info("after transfer size={}", dest.length());
log.info("after transfer exists={}", dest.exists());
return true;
} catch (IOException e) {
log.error("chunk save error", e);
@@ -706,12 +717,30 @@ public class FIleChecker {
}
public static void unzip(String fileName, String destDirectory) throws IOException {
File destDir = new File(destDirectory);
if (!destDir.exists()) {
destDir.mkdirs(); // 대상 폴더가 없으면 생성
String zipFilePath = destDirectory + File.separator + fileName;
log.info("fileName : {}", fileName);
log.info("destDirectory : {}", destDirectory);
log.info("zipFilePath : {}", zipFilePath);
// zip 이름으로 폴더 생성 (확장자 제거)
String folderName =
fileName.endsWith(".zip") ? fileName.substring(0, fileName.length() - 4) : fileName;
log.info("folderName : {}", folderName);
File destDir = new File(destDirectory, folderName);
log.info("destDir : {}", destDir);
// 동일 폴더가 이미 있으면 삭제
log.info("111 destDir.exists() : {}", destDir.exists());
if (destDir.exists()) {
deleteDirectoryRecursively(destDir.toPath());
}
String zipFilePath = destDirectory + "/" + fileName;
log.info("222 destDir.exists() : {}", destDir.exists());
if (!destDir.exists()) {
log.info("mkdirs : {}", destDir.exists());
destDir.mkdirs();
}
try (ZipInputStream zis = new ZipInputStream(new FileInputStream(zipFilePath))) {
ZipEntry zipEntry = zis.getNextEntry();
@@ -744,6 +773,11 @@ public class FIleChecker {
zipEntry = zis.getNextEntry();
}
zis.closeEntry();
} catch (IOException e) {
throw new CustomApiException(
ApiResponseCode.INTERNAL_SERVER_ERROR.getId(),
HttpStatus.INTERNAL_SERVER_ERROR,
"압축 해제 중 오류가 발생했습니다: " + e.getMessage());
}
}
@@ -760,22 +794,21 @@ public class FIleChecker {
return destFile;
}
public static void uploadTo86(Path localFile) {
public static List<String> execCommandAndReadLines(String command) {
List<String> result = new ArrayList<>();
String host = "192.168.2.86";
int port = 22;
String username = "kcomu";
String user = "kcomu";
String password = "Kamco2025!";
String remoteDir = "/home/kcomu/data/request";
Session session = null;
ChannelSftp channel = null;
ChannelExec channel = null;
try {
JSch jsch = new JSch();
session = jsch.getSession(username, host, port);
session = jsch.getSession(user, host, 22);
session.setPassword(password);
Properties config = new Properties();
@@ -784,20 +817,46 @@ public class FIleChecker {
session.connect(10_000);
channel = (ChannelSftp) session.openChannel("sftp");
channel.connect(10_000);
channel = (ChannelExec) session.openChannel("exec");
channel.setCommand(command);
channel.setInputStream(null);
// 목적지 디렉토리 이동
channel.cd(remoteDir);
InputStream in = channel.getInputStream();
channel.connect();
// 업로드
channel.put(localFile.toString(), localFile.getFileName().toString());
try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
String line;
while ((line = br.readLine()) != null) {
result.add(line);
}
}
return result;
} catch (Exception e) {
throw new RuntimeException("SFTP upload failed", e);
throw new RuntimeException("remote command failed : " + command, e);
} finally {
if (channel != null) channel.disconnect();
if (session != null) session.disconnect();
}
}
/** ✅ 폴더 재귀 삭제 */
private static void deleteDirectoryRecursively(Path path) throws IOException {
if (!Files.exists(path)) return;
// 하위부터 지워야 하므로 reverse order
Files.walk(path)
.sorted(Comparator.reverseOrder())
.forEach(
p -> {
try {
Files.deleteIfExists(p);
} catch (IOException e) {
// 여기서 바로 RuntimeException으로 올려서 상위 catch(IOException)로 잡히게 함
throw new UncheckedIOException("폴더 삭제 실패: " + p.toAbsolutePath(), e);
}
});
}
}

View File

@@ -0,0 +1,23 @@
package com.kamco.cd.training.common.utils;
import jakarta.servlet.http.HttpServletRequest;
public final class HeaderUtil {
private HeaderUtil() {}
/** 특정 Header 값 조회 */
public static String get(HttpServletRequest request, String headerName) {
if (request == null || headerName == null) {
return null;
}
String value = request.getHeader(headerName);
return (value != null && !value.isBlank()) ? value : null;
}
/** 필수 Header 조회 (없으면 null) */
public static String getRequired(HttpServletRequest request, String headerName) {
return get(request, headerName);
}
}

View File

@@ -0,0 +1,23 @@
package com.kamco.cd.training.config;
import java.util.concurrent.Executor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
@Configuration
@EnableAsync
public class AsyncConfig {
@Bean(name = "trainJobExecutor")
public Executor trainJobExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(4); // 동시에 4개 실행
executor.setMaxPoolSize(8); // 최대 8개
executor.setQueueCapacity(200); // 대기 큐
executor.setThreadNamePrefix("train-job-");
executor.initialize();
return executor;
}
}

View File

@@ -76,11 +76,13 @@ public class SecurityConfig {
"/api/auth/logout",
"/swagger-ui/**",
"/v3/api-docs/**",
"/api/members/*/password",
"/api/upload/chunk-upload-dataset",
"/api/upload/chunk-upload-complete")
"/api/upload/chunk-upload-complete",
"/download_progress_test.html",
"/api/models/download/**")
.permitAll()
.requestMatchers("/api/members/*/password")
.authenticated()
// default
.anyRequest()
.authenticated())

View File

@@ -5,11 +5,14 @@ import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.menu.dto.MenuDto;
import jakarta.servlet.http.HttpServletRequest;
import java.io.UnsupportedEncodingException;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.util.ContentCachingRequestWrapper;
@Slf4j
public class ApiLogFunction {
// 클라이언트 IP 추출
@@ -34,6 +37,14 @@ public class ApiLogFunction {
return ip;
}
public static String getXFowardedForIp(HttpServletRequest request) {
String ip = request.getHeader("X-Forwarded-For");
if (ip != null) {
ip = ip.split(",")[0].trim();
}
return ip;
}
// 사용자 ID 추출 예시 (Spring Security 기준)
public static String getUserId(HttpServletRequest request) {
try {
@@ -47,20 +58,20 @@ public class ApiLogFunction {
String method = request.getMethod().toUpperCase();
String uri = request.getRequestURI().toLowerCase();
// URL 기반 DOWNLOAD/PRINT 분류
// URL 기반 DOWNLOAD/PRINT 분류 -> /download는 FileDownloadInterceptor로 옮김
if (uri.contains("/download") || uri.contains("/export")) {
return EventType.DOWNLOAD;
}
if (uri.contains("/print")) {
return EventType.PRINT;
return EventType.OTHER;
}
// 일반 CRUD
return switch (method) {
case "POST" -> EventType.CREATE;
case "GET" -> EventType.READ;
case "DELETE" -> EventType.DELETE;
case "PUT", "PATCH" -> EventType.UPDATE;
case "POST" -> EventType.ADDED;
case "GET" -> EventType.LIST;
case "DELETE" -> EventType.REMOVE;
case "PUT", "PATCH" -> EventType.MODIFIED;
default -> EventType.OTHER;
};
}
@@ -121,12 +132,22 @@ public class ApiLogFunction {
public static String getUriMenuInfo(List<MenuDto.Basic> menuList, String uri) {
MenuDto.Basic m =
String normalizedUri = uri.replace("/api", "");
MenuDto.Basic basic =
menuList.stream()
.filter(menu -> menu.getMenuApiUrl() != null && uri.contains(menu.getMenuApiUrl()))
.findFirst()
.filter(
menu -> menu.getMenuUrl() != null && normalizedUri.startsWith(menu.getMenuUrl()))
.max(Comparator.comparingInt(m -> m.getMenuUrl().length()))
.orElse(null);
return m != null ? m.getMenuUid() : "SYSTEM";
return basic != null ? basic.getMenuUid() : "SYSTEM";
}
public static String cutRequestBody(String value) {
int MAX_LEN = 255;
if (value == null) {
return null;
}
return value.length() <= MAX_LEN ? value : value.substring(0, MAX_LEN);
}
}

View File

@@ -2,10 +2,17 @@ package com.kamco.cd.training.config.api;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.auth.CustomUserDetails;
import com.kamco.cd.training.common.utils.HeaderUtil;
import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.menu.dto.MenuDto;
import com.kamco.cd.training.menu.service.MenuService;
import com.kamco.cd.training.postgres.entity.AuditLogEntity;
import com.kamco.cd.training.postgres.repository.log.AuditLogRepository;
import jakarta.servlet.http.HttpServletRequest;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Optional;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.MethodParameter;
import org.springframework.http.MediaType;
@@ -23,6 +30,7 @@ import org.springframework.web.util.ContentCachingRequestWrapper;
*
* <p>createOK() → 201 CREATED ok() → 200 OK deleteOk() → 204 NO_CONTENT
*/
@Slf4j
@RestControllerAdvice
public class ApiResponseAdvice implements ResponseBodyAdvice<Object> {
@@ -61,12 +69,27 @@ public class ApiResponseAdvice implements ResponseBodyAdvice<Object> {
if (body instanceof ApiResponseDto<?> apiResponse) {
response.setStatusCode(apiResponse.getHttpStatus());
String ip = ApiLogFunction.getClientIp(servletRequest);
Long userid = null;
String actionType = HeaderUtil.get(servletRequest, "kamco-action-type");
// actionType 이 없으면 로그 저장하지 않기 || download 는 FileDownloadInterceptor 에서 하기
// (file down URL prefix 추가는 WebConfig.java 에 하기)
if (actionType == null || actionType.equalsIgnoreCase("download")) {
return body;
}
if (servletRequest.getUserPrincipal() instanceof UsernamePasswordAuthenticationToken auth
&& auth.getPrincipal() instanceof CustomUserDetails customUserDetails) {
userid = customUserDetails.getMember().getId();
String ip =
Optional.ofNullable(HeaderUtil.get(servletRequest, "kamco-user-ip"))
.orElseGet(() -> ApiLogFunction.getXFowardedForIp(servletRequest));
Long userid = null;
String loginAttemptId = null;
// 로그인 시도할 때
if (servletRequest.getRequestURI().contains("/api/auth/signin")) {
loginAttemptId = HeaderUtil.get(servletRequest, "kamco-login-attempt-id");
} else {
if (servletRequest.getUserPrincipal() instanceof UsernamePasswordAuthenticationToken auth
&& auth.getPrincipal() instanceof CustomUserDetails customUserDetails) {
userid = customUserDetails.getMember().getId();
}
}
String requestBody;
@@ -84,17 +107,33 @@ public class ApiResponseAdvice implements ResponseBodyAdvice<Object> {
requestBody = maskSensitiveFields(requestBody);
}
List<?> list = menuService.getFindAll();
List<MenuDto.Basic> result =
list.stream()
.map(
item -> {
if (item instanceof LinkedHashMap<?, ?> map) {
return objectMapper.convertValue(map, MenuDto.Basic.class);
} else if (item instanceof MenuDto.Basic dto) {
return dto;
} else {
throw new IllegalStateException("Unsupported cache type: " + item.getClass());
}
})
.toList();
AuditLogEntity log =
new AuditLogEntity(
userid,
ApiLogFunction.getEventType(servletRequest),
EventType.fromName(actionType),
ApiLogFunction.isSuccessFail(apiResponse),
ApiLogFunction.getUriMenuInfo(
menuService.getFindAll(), servletRequest.getRequestURI()),
ApiLogFunction.getUriMenuInfo(result, servletRequest.getRequestURI()),
ip,
servletRequest.getRequestURI(),
requestBody,
apiResponse.getErrorLogUid());
ApiLogFunction.cutRequestBody(requestBody),
apiResponse.getErrorLogUid(),
null,
loginAttemptId);
auditLogRepository.save(log);
}

View File

@@ -14,15 +14,15 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import java.io.IOException;
import java.nio.file.FileStore;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.Resource;
import org.springframework.core.io.UrlResource;
import org.springframework.data.domain.Page;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
@@ -212,8 +212,15 @@ public class DatasetApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/usable-bytes")
public ApiResponseDto<DatasetStorage> getUsableBytes() {
return ApiResponseDto.ok(datasetService.getUsableBytes());
public ApiResponseDto<DatasetStorage> getUsableBytes() throws IOException {
FileStore store = Files.getFileStore(Path.of("."));
long usable = store.getUsableSpace();
DatasetStorage storage = new DatasetStorage();
storage.setUsableBytes(String.valueOf(usable));
// datasetService.getUsableBytes();
return ApiResponseDto.ok(storage);
}
@Operation(summary = "학습데이터 zip파일 등록", description = "학습데이터 zip파일 등록 합니다.")
@@ -221,7 +228,7 @@ public class DatasetApiController {
public ApiResponseDto<ApiResponseDto.ResponseObj> insertDataset(
@RequestBody @Valid DatasetDto.AddReq addReq) {
return ApiResponseDto.ok(datasetService.insertDataset(addReq));
return ApiResponseDto.okObject(datasetService.insertDataset(addReq));
}
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")
@@ -230,10 +237,15 @@ public class DatasetApiController {
throws Exception {
String path = datasetService.getFilePathByUUIDPathType(uuid, pathType);
Path filePath = Paths.get(path);
return datasetService.getFilePathByFile(path);
}
Resource resource = new UrlResource(filePath.toUri());
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")
@GetMapping("/files-to86")
public ResponseEntity<Resource> getFileTo86(
@RequestParam UUID uuid, @RequestParam String pathType) throws Exception {
return ResponseEntity.ok().contentType(MediaType.APPLICATION_OCTET_STREAM).body(resource);
String path = datasetService.getFilePathByUUIDPathType(uuid, pathType);
return datasetService.getFilePathByFile(path);
}
}

View File

@@ -1,7 +1,6 @@
package com.kamco.cd.training.dataset.dto;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.kamco.cd.training.common.enums.LearnDataRegister;
import com.kamco.cd.training.common.enums.LearnDataType;
import com.kamco.cd.training.common.enums.ModelType;
@@ -77,9 +76,16 @@ public class DatasetDto {
}
public String getTotalSize(Long totalSize) {
if (totalSize == null) return "0G";
if (totalSize == null || totalSize <= 0) return "0M";
double giga = totalSize / (1024.0 * 1024 * 1024);
return String.format("%.2fG", giga);
if (giga >= 1) {
return String.format("%.2fG", giga);
} else {
double mega = totalSize / (1024.0 * 1024);
return String.format("%.2fM", mega);
}
}
public String getStatus(String status) {
@@ -227,7 +233,6 @@ public class DatasetDto {
@Getter
@Setter
@NoArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class SelectDataSet {
private String modelNo; // G1, G2, G3 모델 타입
@@ -261,6 +266,7 @@ public class DatasetDto {
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
@@ -309,6 +315,183 @@ public class DatasetDto {
}
}
@Schema(name = "SelectTransferDataSet", description = "전이학습 데이터셋 선택 리스트")
@Getter
@Setter
@NoArgsConstructor
public static class SelectTransferDataSet {
private String modelNo; // G1, G2, G3 모델 타입
private Long datasetId;
private UUID uuid;
private String dataType;
private String title;
private Long roundNo;
private Integer compareYyyy;
private Integer targetYyyy;
private String memo;
@JsonIgnore private Long classCount;
private Integer buildingCnt;
private Integer containerCnt;
private String dataTypeName;
private Long wasteCnt;
private Long landCoverCnt;
private String beforeModelNo; // G1, G2, G3 모델 타입
private Long beforeDatasetId;
private UUID beforeUuid;
private String beforeDataType;
private String beforeTitle;
private Long beforeRoundNo;
private Integer beforeCompareYyyy;
private Integer beforeTargetYyyy;
private String beforeMemo;
@JsonIgnore private Long beforeClassCount;
private Integer beforeBuildingCnt;
private Integer beforeContainerCnt;
private String beforeDataTypeName;
private Long beforeWasteCnt;
private Long beforeLandCoverCnt;
public SelectTransferDataSet(
// 현재
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Long classCount,
// 이전(before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Long beforeClassCount) {
// 현재
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.classCount = classCount;
if (modelNo != null && modelNo.equals(ModelType.G2.getId())) {
this.wasteCnt = classCount;
} else if (modelNo != null && modelNo.equals(ModelType.G3.getId())) {
this.landCoverCnt = classCount;
}
// 이전(before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeClassCount = beforeClassCount;
if (beforeModelNo != null && beforeModelNo.equals(ModelType.G2.getId())) {
this.beforeWasteCnt = beforeClassCount;
} else if (beforeModelNo != null && beforeModelNo.equals(ModelType.G3.getId())) {
this.beforeLandCoverCnt = beforeClassCount;
}
}
public SelectTransferDataSet(
// 현재
String modelNo,
Long datasetId,
UUID uuid,
String dataType,
String title,
Long roundNo,
Integer compareYyyy,
Integer targetYyyy,
String memo,
Integer buildingCnt,
Integer containerCnt,
// 이전(before)
String beforeModelNo,
Long beforeDatasetId,
UUID beforeUuid,
String beforeDataType,
String beforeTitle,
Long beforeRoundNo,
Integer beforeCompareYyyy,
Integer beforeTargetYyyy,
String beforeMemo,
Integer beforeBuildingCnt,
Integer beforeContainerCnt) {
// 현재
this.modelNo = modelNo;
this.datasetId = datasetId;
this.uuid = uuid;
this.dataType = dataType;
this.dataTypeName = getDataTypeName(dataType);
this.title = title;
this.roundNo = roundNo;
this.compareYyyy = compareYyyy;
this.targetYyyy = targetYyyy;
this.memo = memo;
this.buildingCnt = buildingCnt;
this.containerCnt = containerCnt;
// 이전(before)
this.beforeModelNo = beforeModelNo;
this.beforeDatasetId = beforeDatasetId;
this.beforeUuid = beforeUuid;
this.beforeDataType = beforeDataType;
this.beforeDataTypeName = getDataTypeName(beforeDataType);
this.beforeTitle = beforeTitle;
this.beforeRoundNo = beforeRoundNo;
this.beforeCompareYyyy = beforeCompareYyyy;
this.beforeTargetYyyy = beforeTargetYyyy;
this.beforeMemo = beforeMemo;
this.beforeBuildingCnt = beforeBuildingCnt;
this.beforeContainerCnt = beforeContainerCnt;
}
public String getDataTypeName(String groupTitleCd) {
LearnDataType type = Enums.fromId(LearnDataType.class, groupTitleCd);
return type == null ? null : type.getText();
}
public String getYear() {
return this.compareYyyy + "-" + this.targetYyyy;
}
public String getBeforeYear() {
if (this.beforeCompareYyyy == null || this.beforeTargetYyyy == null) {
return null;
}
return this.beforeCompareYyyy + "-" + this.beforeTargetYyyy;
}
}
@Getter
@Setter
@NoArgsConstructor

View File

@@ -21,20 +21,30 @@ import com.kamco.cd.training.dataset.dto.DatasetObjDto.SearchReq;
import com.kamco.cd.training.postgres.core.DatasetCoreService;
import jakarta.validation.Valid;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -50,6 +60,8 @@ public class DatasetService {
private String datasetDir;
private static final List<String> LABEL_DIRS = List.of("label-json", "label", "input1", "input2");
private static final List<String> REQUIRED_DIRS = Arrays.asList("train", "val", "test");
private static final List<String> CHECK_DIRS = List.of("label", "input1", "input2");
/**
* 데이터셋 목록 조회
@@ -164,9 +176,22 @@ public class DatasetService {
Long datasetUid = null; // master id 값, 등록하면서 가져올 예정
try {
// 같은 uid 로 등록한 파일이 있는지 확인
Long existsCnt =
datasetCoreService.findDatasetByUidExistsCnt(addReq.getFileName().replace(".zip", ""));
if (existsCnt > 0) {
return new ResponseObj(ApiResponseCode.DUPLICATE_DATA, "이미 등록된 회차 데이터 파일입니다. 확인 부탁드립니다.");
}
// 압축 해제
FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath());
// 압축 해제한 폴더 하위에 train,val,test 폴더 모두 존재하는지 확인
validateTrainValTestDirs(addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
// 압축 해제한 폴더의 갯수 맞는지 log 찍기
validateDirFileCount(addReq.getFilePath() + addReq.getFileName().replace(".zip", ""));
// 해제한 폴더 읽어서 데이터 저장
List<Map<String, Object>> list =
getUnzipDatasetFiles(
@@ -179,6 +204,17 @@ public class DatasetService {
idx++;
}
List<Map<String, Object>> valList =
getUnzipDatasetFiles(
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "val");
int valIdx = 0;
for (Map<String, Object> valid : valList) {
datasetUid =
this.insertTrainTestData(valid, addReq, valIdx, datasetUid, "val"); // val 데이터 insert
valIdx++;
}
List<Map<String, Object>> testList =
getUnzipDatasetFiles(
addReq.getFilePath() + addReq.getFileName().replace(".zip", ""), "test");
@@ -285,6 +321,8 @@ public class DatasetService {
if (subDir.equals("train")) {
datasetCoreService.insertDatasetObj(objRegDto);
} else if (subDir.equals("val")) {
datasetCoreService.insertDatasetValObj(objRegDto);
} else {
datasetCoreService.insertDatasetTestObj(objRegDto);
}
@@ -303,7 +341,10 @@ public class DatasetService {
Path dir = root.resolve(dirName);
if (!Files.isDirectory(dir)) {
throw new IllegalStateException("폴더가 존재하지 않습니다 : " + dir);
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
"폴더가 존재하지 않습니다. 업로드 된 파일을 확인하세요. : " + dir);
}
try (Stream<Path> stream = Files.list(dir)) {
@@ -356,4 +397,131 @@ public class DatasetService {
public String getFilePathByUUIDPathType(UUID uuid, String pathType) {
return datasetCoreService.getFilePathByUUIDPathType(uuid, pathType);
}
private String readRemoteFileAsString(String remoteFilePath) {
String command = "cat " + escape(remoteFilePath);
List<String> lines = FIleChecker.execCommandAndReadLines(command);
return String.join("\n", lines);
}
private JsonNode parseJson(String json) {
try {
ObjectMapper mapper = new ObjectMapper();
return mapper.readTree(json);
} catch (IOException e) {
throw new RuntimeException("JSON 파싱 실패", e);
}
}
private String escape(String path) {
// 쉘 커맨드에서 안전하게 사용할 수 있도록 문자열을 작은따옴표로 감싸면서, 내부의 작은따옴표를 이스케이프 처리
return "'" + path.replace("'", "'\"'\"'") + "'";
}
private static String normalizeLinuxPath(String path) {
return path.replace("\\", "/");
}
public ResponseEntity<Resource> getFilePathByFile(String remoteFilePath) {
try {
Path path = Paths.get(remoteFilePath);
InputStream inputStream = Files.newInputStream(path);
InputStreamResource resource =
new InputStreamResource(inputStream) {
@Override
public long contentLength() {
return -1; // 알 수 없으면 -1
}
};
String fileName = Paths.get(remoteFilePath.replace("\\", "/")).getFileName().toString();
return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.body(resource);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/** unzipRootDir: 압축 해제된 폴더 경로 (ex: /data/xxx/myzipname) */
public static void validateTrainValTestDirs(String unzipRootDir) {
Path root = Paths.get(unzipRootDir);
// 루트 폴더 자체 존재 확인
if (!Files.exists(root) || !Files.isDirectory(root)) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
"압축 해제 폴더가 존재하지 않습니다: " + unzipRootDir);
}
// 필요한 폴더들 존재/디렉토리 여부 확인
List<String> missing =
REQUIRED_DIRS.stream()
.filter(d -> !Files.isDirectory(root.resolve(d)))
.collect(Collectors.toList());
if (!missing.isEmpty()) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
"데이터 폴더 구조가 올바르지 않습니다. 누락된 폴더: "
+ String.join(", ", missing)
+ " (필수: train, val, test)");
}
}
public static void validateDirFileCount(String unzipRootDir) {
Path root = Paths.get(unzipRootDir);
for (String split : REQUIRED_DIRS) {
Path splitPath = root.resolve(split);
Map<String, Long> fileCountMap = new HashMap<>();
for (String subDir : CHECK_DIRS) { // input1, input2, label 폴더만 수행하기
Path subDirPath = splitPath.resolve(subDir);
if (!Files.isDirectory(subDirPath)) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
split + " 폴더 하위에 " + subDir + " 폴더가 존재하지 않습니다.");
}
long count;
try (Stream<Path> files = Files.list(subDirPath)) {
count = files.filter(Files::isRegularFile).count();
log.info("dir: " + subDirPath + ", count: " + count);
} catch (IOException e) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
split + "/" + subDir + " 파일 개수 확인 중 오류 발생");
}
fileCountMap.put(subDir, count);
}
// 모든 폴더 파일 개수가 동일한지 확인
Set<Long> uniqueCounts = new HashSet<>(fileCountMap.values());
if (uniqueCounts.size() != 1) {
throw new CustomApiException(
ApiResponseCode.NOT_FOUND_DATA.getId(),
HttpStatus.CONFLICT,
split + " 데이터 파일 개수가 일치하지 않습니다. " + fileCountMap.toString());
}
}
}
}

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.hyperparam;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
@@ -65,7 +66,7 @@ public class HyperParamApiController {
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청", content = @Content),
@ApiResponse(responseCode = "422", description = "HPs_0001 수정 불가", content = @Content),
@ApiResponse(responseCode = "422", description = "default는 삭제불가", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PutMapping("/{uuid}")
@@ -96,10 +97,13 @@ public class HyperParamApiController {
String type,
@Parameter(description = "시작일", example = "2026-02-01") @RequestParam(required = false)
LocalDate startDate,
@Parameter(description = "종료일", example = "2026-02-28") @RequestParam(required = false)
@Parameter(description = "종료일", example = "2026-03-31") @RequestParam(required = false)
LocalDate endDate,
@Parameter(description = "버전명", example = "HPs_0001") @RequestParam(required = false)
@Parameter(description = "버전명", example = "G1_000019") @RequestParam(required = false)
String hyperVer,
@Parameter(description = "모델 타입 (G1, G2, G3 중 하나)", example = "G1")
@RequestParam(required = false)
ModelType model,
@Parameter(
description = "정렬",
example = "createdDttm desc",
@@ -124,7 +128,7 @@ public class HyperParamApiController {
searchReq.setSort(sort);
searchReq.setPage(page);
searchReq.setSize(size);
Page<List> list = hyperParamService.getHyperParamList(searchReq);
Page<List> list = hyperParamService.getHyperParamList(model, searchReq);
return ApiResponseDto.ok(list);
}
@@ -133,12 +137,12 @@ public class HyperParamApiController {
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "422", description = "HPs_0001 삭제 불가", content = @Content),
@ApiResponse(responseCode = "422", description = "default 삭제 불가", content = @Content),
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
})
@DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteHyperParam(
@Parameter(description = "하이퍼파라미터 uuid", example = "c3b5a285-8f68-42af-84f0-e6d09162deb5")
@Parameter(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
@PathVariable
UUID uuid) {
hyperParamService.deleteHyperParam(uuid);
@@ -160,7 +164,7 @@ public class HyperParamApiController {
})
@GetMapping("/{uuid}")
public ApiResponseDto<HyperParamDto.Basic> getHyperParam(
@Parameter(description = "하이퍼파라미터 uuid", example = "c3b5a285-8f68-42af-84f0-e6d09162deb5")
@Parameter(description = "하이퍼파라미터 uuid", example = "57fc9170-64c1-4128-aa7b-0657f08d6d10")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(hyperParamService.getHyperParam(uuid));
@@ -179,8 +183,9 @@ public class HyperParamApiController {
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/init")
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam() {
return ApiResponseDto.ok(hyperParamService.getInitHyperParam());
@GetMapping("/init/{model}")
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(@PathVariable ModelType model) {
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
}
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.hyperparam.dto;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
@@ -24,9 +25,12 @@ public class HyperParamDto {
@AllArgsConstructor
public static class Basic {
private ModelType model; // 20250212 modeltype추가
private UUID uuid;
private String hyperVer;
@JsonFormatDttm private ZonedDateTime createdDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Integer totalUseCnt;
// -------------------------
// Important
@@ -98,6 +102,8 @@ public class HyperParamDto {
private Integer gpuCnt;
private String gpuIds;
private Integer masterPort;
private Boolean isDefault;
}
@Getter
@@ -106,13 +112,12 @@ public class HyperParamDto {
@AllArgsConstructor
public static class List {
private UUID uuid;
private ModelType model;
private String hyperVer;
@JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private Long m1UseCnt;
private Long m2UseCnt;
private Long m3UseCnt;
private Long totalCnt;
private String memo;
private Integer totalUseCnt;
}
@Getter

View File

@@ -1,8 +1,10 @@
package com.kamco.cd.training.hyperparam.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.List;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@@ -20,11 +22,12 @@ public class HyperParamService {
/**
* 하이퍼 파라미터 목록 조회
*
* @param model
* @param req
* @return 목록
*/
public Page<List> getHyperParamList(HyperParamDto.SearchReq req) {
return hyperParamCoreService.findByHyperVerList(req);
public Page<List> getHyperParamList(ModelType model, SearchReq req) {
return hyperParamCoreService.findByHyperVerList(model, req);
}
/**
@@ -59,8 +62,8 @@ public class HyperParamService {
}
/** 하이퍼파라미터 최적화 설정값 조회 */
public HyperParamDto.Basic getInitHyperParam() {
return hyperParamCoreService.getInitHyperParam();
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
return hyperParamCoreService.getInitHyperParam(model);
}
/**

View File

@@ -0,0 +1,99 @@
package com.kamco.cd.training.log;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.log.dto.AuditLogDto;
import com.kamco.cd.training.log.service.AuditLogService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.LocalDate;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@Tag(name = "감사 로그", description = "감사 로그 관리 API")
@RequiredArgsConstructor
@RestController
@RequestMapping("/api/logs/audit")
public class AuditLogApiController {
private final AuditLogService auditLogService;
@Operation(summary = "일자별 로그 조회")
@GetMapping("/daily")
public ApiResponseDto<Page<AuditLogDto.DailyAuditList>> getDailyLogs(
@RequestParam(required = false) LocalDate startDate,
@RequestParam(required = false) LocalDate endDate,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.DailyAuditList> result =
auditLogService.getLogByDaily(searchReq, startDate, endDate);
return ApiResponseDto.ok(result);
}
@Operation(summary = "일자별 로그 상세")
@GetMapping("/daily/result")
public ApiResponseDto<Page<AuditLogDto.DailyDetail>> getDailyResultLogs(
@RequestParam LocalDate logDate,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.DailyDetail> result = auditLogService.getLogByDailyResult(searchReq, logDate);
return ApiResponseDto.ok(result);
}
@Operation(summary = "메뉴별 로그 조회")
@GetMapping("/menu")
public ApiResponseDto<Page<AuditLogDto.MenuAuditList>> getMenuLogs(
@RequestParam(required = false) String searchValue,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.MenuAuditList> result = auditLogService.getLogByMenu(searchReq, searchValue);
return ApiResponseDto.ok(result);
}
@Operation(summary = "메뉴별 로그 상세")
@GetMapping("/menu/result")
public ApiResponseDto<Page<AuditLogDto.MenuDetail>> getMenuResultLogs(
@RequestParam String menuId,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.MenuDetail> result = auditLogService.getLogByMenuResult(searchReq, menuId);
return ApiResponseDto.ok(result);
}
@Operation(summary = "사용자별 로그 조회")
@GetMapping("/account")
public ApiResponseDto<Page<AuditLogDto.UserAuditList>> getAccountLogs(
@RequestParam(required = false) String searchValue,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.UserAuditList> result =
auditLogService.getLogByAccount(searchReq, searchValue);
return ApiResponseDto.ok(result);
}
@Operation(summary = "사용자별 로그 상세")
@GetMapping("/account/result")
public ApiResponseDto<Page<AuditLogDto.UserDetail>> getAccountResultLogs(
@RequestParam Long userUid,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
AuditLogDto.searchReq searchReq = new AuditLogDto.searchReq(page, size, "created_dttm,desc");
Page<AuditLogDto.UserDetail> result = auditLogService.getLogByAccountResult(searchReq, userUid);
return ApiResponseDto.ok(result);
}
}

View File

@@ -0,0 +1,40 @@
package com.kamco.cd.training.log;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.log.dto.ErrorLogDto;
import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.log.service.ErrorLogService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.LocalDate;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@Tag(name = "에러 로그", description = "에러 로그 관리 API")
@RequiredArgsConstructor
@RestController
@RequestMapping({"/api/logs/system"})
public class ErrorLogApiController {
private final ErrorLogService errorLogService;
@Operation(summary = "에러로그 조회")
@GetMapping("/error")
public ApiResponseDto<Page<ErrorLogDto.Basic>> getErrorLogs(
@RequestParam(required = false) ErrorLogDto.LogErrorLevel logErrorLevel,
@RequestParam(required = false) EventType eventType,
@RequestParam(required = false) LocalDate startDate,
@RequestParam(required = false) LocalDate endDate,
@RequestParam int page,
@RequestParam(defaultValue = "20") int size) {
ErrorLogDto.ErrorSearchReq searchReq =
new ErrorLogDto.ErrorSearchReq(
logErrorLevel, eventType, startDate, endDate, page, size, "created_dttm,desc");
Page<ErrorLogDto.Basic> result = errorLogService.findLogByError(searchReq);
return ApiResponseDto.ok(result);
}
}

View File

@@ -3,7 +3,9 @@ package com.kamco.cd.training.log.dto;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDate;
import java.time.ZonedDateTime;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
@@ -58,6 +60,7 @@ public class AuditLogDto {
@Getter
@AllArgsConstructor
public static class AuditCommon {
private int readCount;
private int cudCount;
private int printCount;
@@ -68,6 +71,7 @@ public class AuditLogDto {
@Schema(name = "DailyAuditList", description = "일자별 목록")
@Getter
public static class DailyAuditList extends AuditCommon {
private final String baseDate;
public DailyAuditList(
@@ -85,6 +89,7 @@ public class AuditLogDto {
@Schema(name = "MenuAuditList", description = "메뉴별 목록")
@Getter
public static class MenuAuditList extends AuditCommon {
private final String menuId;
private final String menuName;
@@ -105,6 +110,7 @@ public class AuditLogDto {
@Schema(name = "UserAuditList", description = "사용자별 목록")
@Getter
public static class UserAuditList extends AuditCommon {
private final Long accountId;
private final String loginId;
private final String username;
@@ -129,6 +135,7 @@ public class AuditLogDto {
@Getter
@AllArgsConstructor
public static class AuditDetail {
private Long logId;
private EventType eventType;
private LogDetail detail;
@@ -137,9 +144,11 @@ public class AuditLogDto {
@Schema(name = "DailyDetail", description = "일자별 로그 상세")
@Getter
public static class DailyDetail extends AuditDetail {
private final String userName;
private final String loginId;
private final String menuName;
private final String logDateTime;
public DailyDetail(
Long logId,
@@ -147,17 +156,20 @@ public class AuditLogDto {
String loginId,
String menuName,
EventType eventType,
String logDateTime,
LogDetail detail) {
super(logId, eventType, detail);
this.userName = userName;
this.loginId = loginId;
this.menuName = menuName;
this.logDateTime = logDateTime;
}
}
@Schema(name = "MenuDetail", description = "메뉴별 로그 상세")
@Getter
public static class MenuDetail extends AuditDetail {
private final String logDateTime;
private final String userName;
private final String loginId;
@@ -179,6 +191,7 @@ public class AuditLogDto {
@Schema(name = "UserDetail", description = "사용자별 로그 상세")
@Getter
public static class UserDetail extends AuditDetail {
private final String logDateTime;
private final String menuNm;
@@ -194,6 +207,7 @@ public class AuditLogDto {
@Setter
@AllArgsConstructor
public static class LogDetail {
String serviceName;
String parentMenuName;
String menuName;
@@ -226,4 +240,26 @@ public class AuditLogDto {
return PageRequest.of(page, size);
}
}
@Getter
@Setter
public static class DownloadReq {
UUID uuid;
LocalDate startDate;
LocalDate endDate;
String searchValue;
String menuId;
String requestUri;
}
@Getter
@Setter
@AllArgsConstructor
public static class DownloadRes {
String name;
String employeeNo;
@JsonFormatDttm ZonedDateTime downloadDttm;
}
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.log.dto;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDate;
@@ -77,6 +78,7 @@ public class ErrorLogDto {
}
}
@CodeExpose
public enum LogErrorLevel implements EnumType {
WARNING("Warning"),
ERROR("Error"),

View File

@@ -1,22 +1,35 @@
package com.kamco.cd.training.log.dto;
import com.kamco.cd.training.common.utils.enums.CodeExpose;
import com.kamco.cd.training.common.utils.enums.EnumType;
import lombok.AllArgsConstructor;
import lombok.Getter;
@CodeExpose
@Getter
@AllArgsConstructor
public enum EventType implements EnumType {
CREATE("생성"),
READ("조회"),
UPDATE("수정"),
DELETE("삭제"),
LIST("목록"),
DETAIL("상세"),
POPUP("팝업"),
STATUS("상태"),
ADDED("추가"),
MODIFIED("수정"),
REMOVE("삭제"),
DOWNLOAD("다운로드"),
PRINT("출력"),
LOGIN("로그인"),
OTHER("기타");
private final String desc;
public static EventType fromName(String name) {
try {
return EventType.valueOf(name.toUpperCase());
} catch (Exception e) {
return OTHER;
}
}
@Override
public String getId() {
return name();

View File

@@ -1,22 +1,39 @@
package com.kamco.cd.training.model;
import com.kamco.cd.training.common.download.RangeDownloadResponder;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.model.service.ModelTrainDetailService;
import com.kamco.cd.training.model.service.ModelTrainMngService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.enums.ParameterIn;
import io.swagger.v3.oas.annotations.media.ArraySchema;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.apache.coyote.BadRequestException;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
@@ -28,6 +45,11 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping("/api/models")
public class ModelTrainDetailApiController {
private final ModelTrainDetailService modelTrainDetailService;
private final ModelTrainMngService modelTrainMngService;
private final RangeDownloadResponder rangeDownloadResponder;
@Value("${train.docker.responseDir}")
private String responseDir;
@Operation(summary = "모델학습관리> 모델관리 > 상세정보탭 > 학습 진행정보", description = "학습 진행정보, 모델학습 정보 API")
@ApiResponses(
@@ -112,7 +134,28 @@ public class ModelTrainDetailApiController {
return ApiResponseDto.ok(modelTrainDetailService.getByModelMappingDataset(uuid));
}
@Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
// @Operation(summary = "모델관리 > 전이 학습 실행설정 > 모델선택", description = "모델선택 정보 API")
// @ApiResponses(
// value = {
// @ApiResponse(
// responseCode = "200",
// description = "조회 성공",
// content =
// @Content(
// mediaType = "application/json",
// schema = @Schema(implementation = TransferDetailDto.class))),
// @ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
// @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
// })
// @GetMapping("/transfer/detail/{uuid}")
// public ApiResponseDto<TransferDetailDto> getTransferDetail(
// @Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
// @PathVariable
// UUID uuid) {
// return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
// }
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Train)", description = "모델 상세 > 성능 정보 (Train) API")
@ApiResponses(
value = {
@ApiResponse(
@@ -125,11 +168,162 @@ public class ModelTrainDetailApiController {
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/transfer/detail/{uuid}")
public ApiResponseDto<TransferDetailDto> getTransferDetail(
@Parameter(description = "모델 uuid", example = "7fbdff54-ea87-4b02-90d1-955fa2a3457e")
@GetMapping("/metrics/train/{uuid}")
public ApiResponseDto<List<ModelTrainMetrics>> getModelTrainMetricResult(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getTransferDetail(uuid));
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainMetricResult(uuid));
}
@Operation(
summary = "모델관리 > 모델 상세 > 성능 정보 (Validation)",
description = "모델 상세 > 성능 정보 (Validation) API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/metrics/validation/{uuid}")
public ApiResponseDto<List<ModelValidationMetrics>> getModelValidationMetricResult(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelValidationMetricResult(uuid));
}
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Test)", description = "모델 상세 > 성능 정보 (Test) API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/metrics/test/{uuid}")
public ApiResponseDto<List<ModelTestMetrics>> getModelTestMetricResult(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTestMetricResult(uuid));
}
@Operation(summary = "모델관리 > 모델 상세 > 성능 정보 (Test)", description = "모델 상세 > 성능 정보 (Test) API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/best-epoch/{uuid}")
public ApiResponseDto<ModelBestEpoch> getModelTrainBestEpoch(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainBestEpoch(uuid));
}
@Operation(
summary = "학습데이터 파일 다운로드",
description = "학습데이터 파일 다운로드",
parameters = {
@Parameter(
name = "kamco-download-uuid",
in = ParameterIn.HEADER,
required = true,
description = "다운로드 요청 UUID",
schema =
@Schema(
type = "string",
format = "uuid",
example = "6d8d49dc-0c9d-4124-adc7-b9ca610cc394"))
})
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "학습데이터 zip파일 다운로드",
content =
@Content(
mediaType = "application/octet-stream",
schema = @Schema(type = "string", format = "binary"))),
@ApiResponse(responseCode = "404", description = "파일 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/download/{uuid}")
public ResponseEntity<?> download(@PathVariable UUID uuid, HttpServletRequest request)
throws IOException {
Basic info = modelTrainDetailService.findByModelByUUID(uuid);
Path zipPath =
Paths.get(responseDir)
.resolve(String.valueOf(info.getUuid()))
.resolve(info.getModelVer() + ".zip");
if (!Files.isRegularFile(zipPath)) {
throw new BadRequestException();
}
return rangeDownloadResponder.buildZipResponse(zipPath, info.getModelVer() + ".zip", request);
}
@Operation(summary = "모델관리 > 모델 상세 > 파일 정보", description = "모델 상세 > 파일 정보 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = TransferDetailDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/file-info/{uuid}")
public ApiResponseDto<ModelFileInfo> getModelTrainFileInfo(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainFileInfo(uuid));
}
@Operation(summary = "모델관리 > 모델별 진행 상황", description = "모델관리 > 모델별 진행 상황 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelProgressStepDto.class))),
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/progress/{uuid}")
public ApiResponseDto<List<ModelProgressStepDto>> findModelTrainProgressInfo(
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(modelTrainDetailService.findModelTrainProgressInfo(uuid));
}
}

View File

@@ -6,8 +6,10 @@ import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.model.service.ModelTrainMngService;
import com.kamco.cd.training.train.service.ModelTestMetricsJobService;
import com.kamco.cd.training.train.service.ModelTrainMetricsJobService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
@@ -16,6 +18,7 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import java.io.IOException;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@@ -35,6 +38,8 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping("/api/models")
public class ModelTrainMngApiController {
private final ModelTrainMngService modelTrainMngService;
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
private final ModelTestMetricsJobService modelTestMetricsJobService;
@Operation(summary = "모델학습 목록 조회", description = "모델학습 목록 조회 API")
@ApiResponses(
@@ -50,7 +55,7 @@ public class ModelTrainMngApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/list")
public ApiResponseDto<Page<Basic>> findByModelList(
public ApiResponseDto<Page<ListDto>> findByModelList(
@Parameter(
description = "상태코드",
example = "IN_PROGRESS",
@@ -74,7 +79,7 @@ public class ModelTrainMngApiController {
@ApiResponses(
value = {
@ApiResponse(responseCode = "200", description = "삭제 성공", content = @Content),
@ApiResponse(responseCode = "409", description = "HPs_0001 삭제 불가", content = @Content)
@ApiResponse(responseCode = "409", description = "G1_000001 삭제 불가", content = @Content)
})
@DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteModelTrain(
@@ -92,9 +97,8 @@ public class ModelTrainMngApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping
public ApiResponseDto<String> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
modelTrainMngService.createModelTrain(req);
return ApiResponseDto.ok("ok");
public ApiResponseDto<UUID> createModelTrain(@Valid @RequestBody ModelTrainMngDto.AddReq req) {
return ApiResponseDto.ok(modelTrainMngService.createModelTrain(req));
}
@Operation(summary = "모델학습 config 정보 조회", description = "모델학습 config 정보 조회 API")
@@ -150,4 +154,64 @@ public class ModelTrainMngApiController {
req.setDataType(selectType);
return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req));
}
@Operation(
summary = "모델학습 1단계/2단계 실행중인 것이 있는지 count",
description = "모델학습 1단계/2단계 실행중인 것이 있는지 count")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Long.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/ing-training-cnt")
public ApiResponseDto<Long> findModelStep1InProgressCnt() {
return ApiResponseDto.ok(modelTrainMngService.findModelStep1InProgressCnt());
}
@Operation(
summary = "스케줄러 findTrainValidMetricCsvFiles",
description = "스케줄러 findTrainValidMetricCsvFiles")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Long.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/schedule-trainvalid")
public ApiResponseDto<Long> findTrainValidMetricCsvFiles() {
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
return ApiResponseDto.ok(null);
}
@Operation(summary = "스케줄러 findTestMetricCsvFiles", description = "스케줄러 findTestMetricCsvFiles")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Long.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/schedule-test")
public ApiResponseDto<Long> findTestValidMetricCsvFiles() throws IOException {
modelTestMetricsJobService.findTestValidMetricCsvFiles();
return ApiResponseDto.ok(null);
}
}

View File

@@ -20,4 +20,25 @@ public class ModelConfigDto {
private Float testPercent;
private String memo;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class TransferBasic {
private Long configId;
private Long modelId;
private Integer epochCount;
private Float trainPercent;
private Float validationPercent;
private Float testPercent;
private String memo;
private Long beforeConfigId;
private Long beforeModelId;
private Integer beforeEpochCount;
private Float beforeTrainPercent;
private Float beforeValidationPercent;
private Float beforeTestPercent;
private String beforeMemo;
}
}

View File

@@ -6,7 +6,7 @@ import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.utils.enums.Enums;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.Duration;
import java.time.ZonedDateTime;
@@ -35,6 +35,7 @@ public class ModelTrainDetailDto {
@JsonFormatDttm private ZonedDateTime step2EndDttm;
private String statusCd;
private String trainType;
private UUID beforeUuid;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
@@ -176,8 +177,83 @@ public class ModelTrainDetailDto {
@NoArgsConstructor
@AllArgsConstructor
public static class TransferDetailDto {
private ModelConfigDto.Basic etcConfig;
private ModelConfigDto.TransferBasic etcConfig;
private TransferHyperSummary modelTrainHyper;
private List<SelectDataSet> modelTrainDataset;
private List<SelectTransferDataSet> modelTrainDataset;
// private List<SelectDataSet> beforeTrainDataset;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelTrainMetrics {
private Integer epoch;
private Long iteration;
private Double loss;
private Double lr;
private Float durationTime;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelValidationMetrics {
private Integer epoch;
private Float aAcc;
private Float mFscore;
private Float mPrecision;
private Float mRecall;
private Float mIou;
private Float mAcc;
private Float changedFscore;
private Float changedPrecision;
private Float changedRecall;
private Float unchangedFscore;
private Float unchangedPrecision;
private Float unchangedRecall;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelTestMetrics {
private String model;
private Long tp;
private Long fp;
private Long fn;
private Float precision;
private Float recall;
private Float f1Score;
private Float accuracy;
private Float iou;
private Long detectionCount;
private Long gtCount;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelBestEpoch {
private Integer epoch;
private Double loss;
private Float f1Score;
private Float precision;
private Float recall;
private Float iou;
private Float accuracy;
}
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ModelFileInfo {
private Boolean fileExistsYn;
private String fileName;
}
}

View File

@@ -40,6 +40,14 @@ public class ModelTrainMngDto {
private String statusCd;
private String trainType;
private String modelNo;
private Long currentAttemptId;
private String requestPath;
private String packingState;
private ZonedDateTime packingStrtDttm;
private ZonedDateTime packingEndDttm;
private Long beforeModelId;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
@@ -59,7 +67,7 @@ public class ModelTrainMngDto {
}
}
public String getStep2StatusNAme() {
public String getStep2StatusName() {
if (this.step2Status == null || this.step2Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()
@@ -98,6 +106,10 @@ public class ModelTrainMngDto {
public String getStep2Duration() {
return formatDuration(this.step2StrtDttm, this.step2EndDttm);
}
public String getPackingDuration() {
return formatDuration(this.packingStrtDttm, this.packingEndDttm);
}
}
@Schema(name = "searchReq", description = "모델학습 관리 목록조회 파라미터")
@@ -154,6 +166,17 @@ public class ModelTrainMngDto {
ModelConfig modelConfig;
}
@Schema(name = "addReq", description = "모델학습 관리 등록 파라미터")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class UpdateReq {
private String requestPath;
private String responsePath;
}
@Getter
@Setter
public static class TrainingDataset {
@@ -197,4 +220,111 @@ public class ModelTrainMngDto {
@Schema(description = "메모", example = "메모 입니다.")
private String memo;
}
@Schema(name = "모델학습관리 목록", description = "모델학습관리 목록")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class ListDto {
private Long id;
private UUID uuid;
private String modelVer;
@JsonFormatDttm private ZonedDateTime startDttm;
@JsonFormatDttm private ZonedDateTime step1StrtDttm;
@JsonFormatDttm private ZonedDateTime step1EndDttm;
@JsonFormatDttm private ZonedDateTime step2StrtDttm;
@JsonFormatDttm private ZonedDateTime step2EndDttm;
private String step1Status;
private String step2Status;
private String statusCd;
private String trainType;
private String modelNo;
private Long currentAttemptId;
private String requestPath;
private String packingState;
private ZonedDateTime packingStrtDttm;
private ZonedDateTime packingEndDttm;
private String memo;
private String userNm;
private UUID beforeUuid;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.statusCd).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.statusCd; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getStep1StatusName() {
if (this.step1Status == null || this.step1Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step1Status).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.step1Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getStep2StatusName() {
if (this.step2Status == null || this.step2Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.step2Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getTrainTypeName() {
if (this.trainType == null || this.trainType.isBlank()) return null;
try {
return TrainType.valueOf(this.trainType).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.trainType; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
private String formatDuration(ZonedDateTime start, ZonedDateTime end) {
if (start == null || end == null) {
return null;
}
long totalSeconds = Math.abs(Duration.between(start, end).getSeconds());
long hours = totalSeconds / 3600;
long minutes = (totalSeconds % 3600) / 60;
long seconds = totalSeconds % 60;
return String.format("%d시간 %d분 %d초", hours, minutes, seconds);
}
public String getStep1Duration() {
return formatDuration(this.step1StrtDttm, this.step1EndDttm);
}
public String getStep2Duration() {
return formatDuration(this.step2StrtDttm, this.step2EndDttm);
}
public String getPackingDuration() {
return formatDuration(this.packingStrtDttm, this.packingEndDttm);
}
}
@Getter
@Builder
@AllArgsConstructor
public static class ModelProgressStepDto {
private int step;
private String status;
@JsonFormatDttm private ZonedDateTime startTime;
@JsonFormatDttm private ZonedDateTime endTime;
private boolean isError;
}
}

View File

@@ -1,14 +1,21 @@
package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.core.ModelTrainDetailCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.util.ArrayList;
@@ -66,11 +73,11 @@ public class ModelTrainDetailService {
Basic modelInfo = modelTrainDetailCoreService.findByModelByUUID(uuid);
// config 정보 조회
ModelConfigDto.Basic configInfo = mngCoreService.findModelConfigByModelId(uuid);
ModelConfigDto.TransferBasic configInfo = mngCoreService.findModelTransferConfigByModelId(uuid);
// 하이파라미터 정보 조회
TransferHyperSummary hyperSummary = modelTrainDetailCoreService.getTransferHyperSummary(uuid);
List<SelectDataSet> dataSets = new ArrayList<>();
List<SelectTransferDataSet> dataSets = new ArrayList<>();
DatasetReq datasetReq = new DatasetReq();
List<Long> datasetIds = new ArrayList<>();
@@ -83,12 +90,37 @@ public class ModelTrainDetailService {
datasetReq.setIds(datasetIds);
datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals("G1")) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
if (modelInfo.getModelNo().equals(ModelType.G1.getId())) {
dataSets = mngCoreService.getDatasetTransferSelectG1List(modelInfo.getId());
} else {
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq);
dataSets =
mngCoreService.getDatasetTransferSelectG2G3List(
modelInfo.getId(), modelInfo.getModelNo());
}
// DatasetReq beforeDatasetReq = new DatasetReq();
// List<Long> beforeDatasetIds = new ArrayList<>();
// List<SelectDataSet> beforeDataSets = new ArrayList<>();
//
// Long beforeModelId = modelInfo.getBeforeModelId();
// if (beforeModelId != null) {
// Basic beforeInfo = modelTrainDetailCoreService.findByModelBeforeId(beforeModelId);
// List<MappingDataset> beforeDatasets =
// modelTrainDetailCoreService.getByModelMappingDataset(beforeInfo.getUuid());
//
// for (MappingDataset before : beforeDatasets) {
// beforeDatasetIds.add(before.getDatasetId());
// }
// beforeDatasetReq.setIds(beforeDatasetIds);
// beforeDatasetReq.setModelNo(modelInfo.getModelNo());
//
// if (beforeInfo.getModelNo().equals(ModelType.G1.getId())) {
// beforeDataSets = mngCoreService.getDatasetSelectG1List(beforeDatasetReq);
// } else {
// beforeDataSets = mngCoreService.getDatasetSelectG2G3List(beforeDatasetReq);
// }
// }
TransferDetailDto transferDetailDto = new TransferDetailDto();
transferDetailDto.setEtcConfig(configInfo);
transferDetailDto.setModelTrainHyper(hyperSummary);
@@ -96,4 +128,28 @@ public class ModelTrainDetailService {
return transferDetailDto;
}
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
return modelTrainDetailCoreService.getModelTrainMetricResult(uuid);
}
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
return modelTrainDetailCoreService.getModelValidationMetricResult(uuid);
}
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
return modelTrainDetailCoreService.getModelTestMetricResult(uuid);
}
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
return modelTrainDetailCoreService.getModelTrainBestEpoch(uuid);
}
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
return modelTrainDetailCoreService.getModelTrainFileInfo(uuid);
}
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
return modelTrainDetailCoreService.findModelTrainProgressInfo(uuid);
}
}

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
@@ -12,6 +13,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.service.TrainJobService;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@@ -29,6 +31,7 @@ public class ModelTrainMngService {
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final HyperParamCoreService hyperParamCoreService;
private final TrainJobService trainJobService;
/**
* 모델학습 조회
@@ -36,7 +39,7 @@ public class ModelTrainMngService {
* @param searchReq 검색 조건
* @return 페이징 처리된 모델 목록
*/
public Page<ModelTrainMngDto.Basic> getModelList(SearchReq searchReq) {
public Page<ModelTrainMngDto.ListDto> getModelList(SearchReq searchReq) {
return modelTrainMngCoreService.findByModelList(searchReq);
}
@@ -57,13 +60,13 @@ public class ModelTrainMngService {
* @return
*/
@Transactional
public void createModelTrain(ModelTrainMngDto.AddReq req) {
public UUID createModelTrain(ModelTrainMngDto.AddReq req) {
HyperParam hyperParam = req.getHyperParam();
HyperParamDto.Basic hyper = new HyperParamDto.Basic();
// 전이 학습은 모델 선택 필수
if (req.getTrainType().equals(TrainType.TRANSFER.getId())) {
if (req.getBeforeModelId() != null) {
if (TrainType.TRANSFER.getId().equals(req.getTrainType())) {
if (req.getBeforeModelId() == null) {
throw new CustomApiException("BAD_REQUEST", HttpStatus.BAD_REQUEST, "모델을 선택해 주세요.");
}
}
@@ -76,7 +79,10 @@ public class ModelTrainMngService {
}
// 모델학습 테이블 저장
Long modelId = modelTrainMngCoreService.saveModel(req);
ModelTrainMngDto.Basic modelDto = modelTrainMngCoreService.saveModel(req);
Long modelId = modelDto.getId();
UUID modelUuid = modelDto.getUuid();
// 모델학습 데이터셋 저장
modelTrainMngCoreService.saveModelDataset(modelId, req);
@@ -87,6 +93,10 @@ public class ModelTrainMngService {
// 모델 config 저장
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
// 데이터셋 임시파일 생성
trainJobService.createTmpFile(modelUuid);
return modelUuid;
}
/**
@@ -106,10 +116,14 @@ public class ModelTrainMngService {
* @return
*/
public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (req.getModelNo().equals("G1")) {
if (req.getModelNo().equals(ModelType.G1.getId())) {
return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else {
return modelTrainMngCoreService.getDatasetSelectG2G3List(req);
}
}
public Long findModelStep1InProgressCnt() {
return modelTrainMngCoreService.findModelStep1InProgressCnt();
}
}

View File

@@ -2,6 +2,7 @@ package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.service.BaseCoreService;
import com.kamco.cd.training.log.dto.AuditLogDto;
import com.kamco.cd.training.log.dto.AuditLogDto.DownloadReq;
import com.kamco.cd.training.postgres.repository.log.AuditLogRepository;
import java.time.LocalDate;
import lombok.RequiredArgsConstructor;
@@ -45,6 +46,11 @@ public class AuditLogCoreService
return auditLogRepository.findLogByAccount(searchRange, searchValue);
}
public Page<AuditLogDto.DownloadRes> findLogByAccount(
AuditLogDto.searchReq searchReq, DownloadReq downloadReq) {
return auditLogRepository.findDownloadLog(searchReq, downloadReq);
}
public Page<AuditLogDto.DailyDetail> getLogByDailyResult(
AuditLogDto.searchReq searchRange, LocalDate logDate) {
return auditLogRepository.findLogByDailyResult(searchRange, logDate);

View File

@@ -242,4 +242,12 @@ public class DatasetCoreService
entity.setStatus(LearnDataRegister.COMPLETED.getId());
}
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
datasetObjRepository.insertDatasetValObj(objRegDto);
}
public Long findDatasetByUidExistsCnt(String uid) {
return datasetRepository.findDatasetByUidExistsCnt(uid);
}
}

View File

@@ -1,10 +1,12 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.Basic;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.kamco.cd.training.postgres.repository.hyperparam.HyperParamRepository;
import java.time.ZonedDateTime;
@@ -17,6 +19,7 @@ import org.springframework.stereotype.Service;
@Service
@RequiredArgsConstructor
public class HyperParamCoreService {
private final HyperParamRepository hyperParamRepository;
private final UserUtil userUtil;
@@ -27,11 +30,10 @@ public class HyperParamCoreService {
* @return 등록된 버전명
*/
public Basic createHyperParam(HyperParam createReq) {
String firstVersion = getFirstHyperParamVersion();
String firstVersion = getFirstHyperParamVersion(createReq.getModel());
ModelHyperParamEntity entity = new ModelHyperParamEntity();
entity.setHyperVer(firstVersion);
applyHyperParam(entity, createReq);
// user
@@ -57,7 +59,7 @@ public class HyperParamCoreService {
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getHyperVer().equals("HPs_0001")) {
if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
}
applyHyperParam(entity, createReq);
@@ -69,11 +71,112 @@ public class HyperParamCoreService {
return entity.getHyperVer();
}
/**
* 하이퍼파라미터 삭제
*
* @param uuid
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// if (entity.getHyperVer().equals("HPs_0001")) {
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
// }
// 디폴트면 삭제불가
if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
}
entity.setDelYn(true);
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
}
/**
* 하이퍼파라미터 최적화 설정값 조회
*
* @return
*/
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
ModelHyperParamEntity entity =
hyperParamRepository.getHyperParamByType(model).stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst()
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 상세 조회
*
* @return
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 목록 조회
*
* @param model
* @param req
* @return
*/
public Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req) {
return hyperParamRepository.findByHyperVerList(model, req);
}
/**
* 하이퍼파라미터 버전 조회
*
* @param model 모델 타입
* @return ver
*/
public String getFirstHyperParamVersion(ModelType model) {
return hyperParamRepository
.findHyperParamVerByModelType(model)
.map(ModelHyperParamEntity::getHyperVer)
.map(ver -> increase(ver, model))
.orElse(model.name() + "_000001");
}
/**
* 하이퍼 파라미터의 버전을 증가시킨다.
*
* @param hyperVer 현재 버전
* @param modelType 모델 타입
* @return 증가된 버전
*/
private String increase(String hyperVer, ModelType modelType) {
String prefix = modelType.name() + "_";
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
return prefix + String.format("%06d", num + 1);
}
private void applyHyperParam(ModelHyperParamEntity entity, HyperParam src) {
ModelType model = src.getModel();
// 하드코딩 모델별로 다른경우 250212 bbn 하드코딩
if (model == ModelType.G3) {
entity.setCropSize("512,512");
} else {
entity.setCropSize("256,256");
}
entity.setCropSize(src.getCropSize());
// Important
entity.setModelType(model); // 20250212 modeltype추가
entity.setBackbone(src.getBackbone());
entity.setInputSize(src.getInputSize());
entity.setCropSize(src.getCropSize());
entity.setBatchSize(src.getBatchSize());
// Data
@@ -110,79 +213,4 @@ public class HyperParamCoreService {
// memo
entity.setMemo(src.getMemo());
}
/**
* 하이퍼파라미터 삭제
*
* @param uuid
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getHyperVer().equals("HPs_0001")) {
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
}
entity.setDelYn(true);
entity.setUpdatedUid(userUtil.getId());
entity.setUpdatedDttm(ZonedDateTime.now());
}
/**
* 하이퍼파라미터 최적화 설정값 조회
*
* @return
*/
public HyperParamDto.Basic getInitHyperParam() {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByHyperVer("HPs_0001")
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 상세 조회
*
* @return
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
/**
* 하이퍼파라미터 목록 조회
*
* @param req
* @return
*/
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
return hyperParamRepository.findByHyperVerList(req);
}
/**
* 하이퍼파라미터 버전 조회
*
* @return ver
*/
public String getFirstHyperParamVersion() {
return hyperParamRepository
.findHyperParamVer()
.map(ModelHyperParamEntity::getHyperVer)
.map(this::increase)
.orElse("HPs_0001");
}
private String increase(String hyperVer) {
String prefix = "HPs_";
int num = Integer.parseInt(hyperVer.substring(prefix.length()));
return prefix + String.format("%04d", num + 1);
}
}

View File

@@ -0,0 +1,50 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.time.ZonedDateTime;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
public class ModelTestMetricsJobCoreService {
private final ModelTestMetricsJobRepository modelTestMetricsJobRepository;
@Transactional
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
modelTestMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo);
}
// Test 로직 시작
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
return modelTestMetricsJobRepository.getTestMetricSaveNotYetModelIds();
}
public void insertModelMetricsTest(List<Object[]> batchArgs) {
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
}
public ModelMetricJsonDto getTestMetricPackingInfo(Long modelId) {
return modelTestMetricsJobRepository.getTestMetricPackingInfo(modelId);
}
public ModelTestFileName findModelTestFileNames(Long modelId) {
return modelTestMetricsJobRepository.findModelTestFileNames(modelId);
}
@Transactional
public void updatePackingStart(Long modelId, ZonedDateTime now) {
modelTestMetricsJobRepository.updatePackingStart(modelId, now);
}
@Transactional
public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) {
modelTestMetricsJobRepository.updatePackingEnd(modelId, now, failSuccState);
}
}

View File

@@ -7,8 +7,14 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDetailRepository;
@@ -51,6 +57,12 @@ public class ModelTrainDetailCoreService {
return modelDetailRepository.getModelDetailSummary(uuid);
}
/**
* 하이퍼 파리미터 요약정보
*
* @param uuid 모델마스터 uuid
* @return
*/
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
return modelDetailRepository.getByModelHyperParamSummary(uuid);
}
@@ -77,4 +89,33 @@ public class ModelTrainDetailCoreService {
public ModelConfigDto.Basic findModelConfig(Long modelId) {
return modelConfigRepository.findModelConfigByModelId(modelId).orElse(null);
}
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
return modelDetailRepository.getModelTrainMetricResult(uuid);
}
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
return modelDetailRepository.getModelValidationMetricResult(uuid);
}
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
return modelDetailRepository.getModelTestMetricResult(uuid);
}
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
return modelDetailRepository.getModelTrainBestEpoch(uuid);
}
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
return modelDetailRepository.getModelTrainFileInfo(uuid);
}
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
return modelDetailRepository.findModelTrainProgressInfo(uuid);
}
public Basic findByModelBeforeId(Long beforeModelId) {
ModelMasterEntity entity = modelDetailRepository.findByModelBeforeId(beforeModelId);
return entity.toDto();
}
}

View File

@@ -0,0 +1,187 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Log4j2
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class ModelTrainJobCoreService {
private final ModelTrainJobRepository modelTrainJobRepository;
public int findMaxAttemptNo(Long modelId) {
return modelTrainJobRepository.findMaxAttemptNo(modelId);
}
public Optional<ModelTrainJobDto> findLatestByModelId(Long modelId) {
return modelTrainJobRepository.findLatestByModelId(modelId).map(ModelTrainJobEntity::toDto);
}
public Optional<ModelTrainJobDto> findById(Long jobId) {
return modelTrainJobRepository.findById(jobId).map(ModelTrainJobEntity::toDto);
}
/** QUEUED Job 생성 */
@Transactional
public Long createQueuedJob(
Long modelId, int attemptNo, Map<String, Object> paramsJson, ZonedDateTime queuedDttm) {
ModelTrainJobEntity job = new ModelTrainJobEntity();
job.setModelId(modelId);
job.setAttemptNo(attemptNo);
job.setStatusCd("QUEUED");
job.setParamsJson(paramsJson);
job.setQueuedDttm(queuedDttm != null ? queuedDttm : ZonedDateTime.now());
modelTrainJobRepository.save(job);
modelTrainJobRepository.flush();
return job.getId();
}
/** 실행 시작 처리 */
@Transactional
public void markRunning(
Long jobId,
String containerName,
String logPath,
String lockedBy,
Integer totalEpoch,
String jobType) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("RUNNING");
job.setContainerName(containerName);
job.setLogPath(logPath);
job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy);
job.setJobType(jobType);
if (totalEpoch != null) {
job.setTotalEpoch(totalEpoch);
}
}
/**
* 성공 처리
*
* @param jobId
* @param exitCode
*/
@Transactional
public void markSuccess(Long jobId, int exitCode) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("SUCCESS");
job.setExitCode(exitCode);
job.setFinishedDttm(ZonedDateTime.now());
}
/**
* 실패 처리
*
* @param jobId
* @param exitCode
* @param errorMessage
*/
@Transactional
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("FAILED");
job.setExitCode(exitCode);
job.setErrorMessage(errorMessage);
job.setFinishedDttm(ZonedDateTime.now());
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
}
/**
* 중단됨 처리
*
* @param jobId
* @param exitCode
* @param errorMessage
*/
@Transactional
public void markPaused(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("STOPPED");
job.setExitCode(exitCode);
job.setErrorMessage(errorMessage);
job.setFinishedDttm(ZonedDateTime.now());
log.info("[TRAIN JOB FAIL] jobId={}, modelId={}", jobId, errorMessage);
}
/** 취소 처리 */
@Transactional
public void markCanceled(Long jobId) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setStatusCd("STOPPED");
job.setFinishedDttm(ZonedDateTime.now());
}
@Transactional
public void updateEpoch(String containerName, Integer epoch) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findByContainerName(containerName)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
job.setCurrentEpoch(epoch);
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
}
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
modelTrainJobRepository.insertModelTestTrainingRun(modelId, jobId, epoch);
}
/**
* 실행중인 학습이 있는지 조회
*
* @return
*/
public List<ModelTrainJobDto> findRunningJobs() {
List<ModelTrainJobEntity> entity = modelTrainJobRepository.findRunningJobs();
if (entity == null || entity.isEmpty()) {
return Collections.emptyList();
}
return entity.stream().map(ModelTrainJobEntity::toDto).toList();
}
}

View File

@@ -0,0 +1,37 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
public class ModelTrainMetricsJobCoreService {
private final ModelTrainMetricsJobRepository modelTrainMetricsJobRepository;
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
return modelTrainMetricsJobRepository.getTrainMetricSaveNotYetModelIds();
}
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
modelTrainMetricsJobRepository.insertModelMetricsTrain(batchArgs);
}
@Transactional
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
modelTrainMetricsJobRepository.updateModelMetricsTrainSaveYn(modelId, stepNo);
}
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs);
}
@Transactional
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
modelTrainMetricsJobRepository.updateModelSelectedBestEpoch(modelId, epoch);
}
}

View File

@@ -8,9 +8,10 @@ import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.TrainingDataset;
import com.kamco.cd.training.postgres.entity.ModelConfigEntity;
import com.kamco.cd.training.postgres.entity.ModelDatasetEntity;
@@ -23,17 +24,23 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
public class ModelTrainMngCoreService {
private final ModelMngRepository modelMngRepository;
private final ModelDatasetRepository modelDatasetRepository;
private final ModelDatasetMappRepository modelDatasetMapRepository;
@@ -48,9 +55,10 @@ public class ModelTrainMngCoreService {
* @param searchReq 검색 조건
* @return 페이징 처리된 모델 목록
*/
public Page<Basic> findByModelList(ModelTrainMngDto.SearchReq searchReq) {
Page<ModelMasterEntity> entityPage = modelMngRepository.findByModels(searchReq);
return entityPage.map(ModelMasterEntity::toDto);
public Page<ListDto> findByModelList(ModelTrainMngDto.SearchReq searchReq) {
// Page<ModelMasterEntity> entityPage = modelMngRepository.findByModels(searchReq);
// return entityPage.map(ModelMasterEntity::toDto);
return modelMngRepository.findByModels(searchReq);
}
/**
@@ -74,13 +82,19 @@ public class ModelTrainMngCoreService {
* @param addReq
* @return
*/
public Long saveModel(ModelTrainMngDto.AddReq addReq) {
public ModelTrainMngDto.Basic saveModel(ModelTrainMngDto.AddReq addReq) {
ModelMasterEntity entity = new ModelMasterEntity();
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
// 최적화 파라미터는 HPs_0001 사용
// 최적화 파라미터는 모델 type의 디폴트사용
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity =
hyperParamRepository.getHyperParamByType(modelType).stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst()
.orElse(null);
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
} else {
hyperParamEntity =
@@ -90,6 +104,12 @@ public class ModelTrainMngCoreService {
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND);
}
// 하이퍼 파라미터 사용 횟수 업데이트
hyperParamEntity.setTotalUseCnt(
hyperParamEntity.getTotalUseCnt() == null ? 1 : hyperParamEntity.getTotalUseCnt() + 1);
// 최근 사용일시 업데이트
hyperParamEntity.setLastUsedDttm(ZonedDateTime.now());
String modelVer =
String.join(
@@ -100,19 +120,17 @@ public class ModelTrainMngCoreService {
entity.setTrainType(addReq.getTrainType()); // 일반, 전이
entity.setBeforeModelId(addReq.getBeforeModelId());
if (addReq.getIsStart()) {
entity.setModelStep((short) 1);
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStrtDttm(ZonedDateTime.now());
entity.setStep1StrtDttm(ZonedDateTime.now());
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
} else {
entity.setStatusCd(TrainStatusType.READY.getId());
}
entity.setStatusCd(TrainStatusType.READY.getId());
entity.setStep1State(TrainStatusType.READY.getId());
entity.setCreatedUid(userUtil.getId());
ModelMasterEntity resultEntity = modelMngRepository.save(entity);
return resultEntity.getId();
ModelTrainMngDto.Basic result = new ModelTrainMngDto.Basic();
result.setId(resultEntity.getId());
result.setUuid(resultEntity.getUuid());
return result;
}
/**
@@ -144,6 +162,23 @@ public class ModelTrainMngCoreService {
modelDatasetRepository.save(datasetEntity);
}
/**
* 학습모델 수정
*
* @param modelId
* @param req
*/
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// 임시폴더 UID업데이트
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
entity.setRequestPath(req.getRequestPath());
}
}
/**
* 모델 데이터셋 mapping 테이블 저장
*
@@ -172,7 +207,10 @@ public class ModelTrainMngCoreService {
ModelConfigEntity entity = new ModelConfigEntity();
modelMasterEntity.setId(modelId);
entity.setModel(modelMasterEntity);
entity.setEpochCount(req.getEpochCnt());
entity.setEpochCount(
req.getEpochCnt() < 10
? 10
: req.getEpochCnt()); // 에폭이 10 이하이면 10으로 고정하기. 10 이상 에폭으로 해야 best 에폭 파일이 생성되어 내려옴
entity.setTrainPercent(req.getTrainingCnt());
entity.setValidationPercent(req.getValidationCnt());
entity.setTestPercent(req.getTestCnt());
@@ -213,6 +251,20 @@ public class ModelTrainMngCoreService {
}
}
/**
* uuid로 model id 조회
*
* @param uuid
* @return
*/
public Long findModelIdByUuid(UUID uuid) {
ModelMasterEntity entity =
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.getId();
}
/**
* 모델학습 아이디로 config정보 조회
*
@@ -226,6 +278,13 @@ public class ModelTrainMngCoreService {
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
public ModelConfigDto.TransferBasic findModelTransferConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository
.findModelTransferConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
/**
* 데이터셋 G1 목록
*
@@ -236,6 +295,16 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetSelectG1List(req);
}
/**
* 전이학습 데이터셋 G1 목록
*
* @param modelId 모델 Id
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
return datasetRepository.getDatasetTransferSelectG1List(modelId);
}
/**
* 데이터셋 G2, G3 목록
*
@@ -245,4 +314,286 @@ public class ModelTrainMngCoreService {
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
return datasetRepository.getDatasetSelectG2G3List(req);
}
/**
* 전이학습 데이터셋 G2, G3 목록
*
* @param modelId 모델 Id
* @param modelNo G2, G3
* @return
*/
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
return datasetRepository.getDatasetTransferSelectG2G3List(modelId, modelNo);
}
/**
* 모델관리 조회
*
* @param id
* @return
*/
public ModelTrainMngDto.Basic findModelById(Long id) {
ModelMasterEntity entity =
modelMngRepository
.findById(id)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
return entity.toDto();
}
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
@Transactional
public void markInProgress(Long modelId, Long jobId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
master.setCurrentAttemptId(jobId);
// 필요하면 시작시간도 여기서 찍어줌
modelMngRepository.flush();
}
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
@Transactional
public void clearLastError(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setLastError(null);
modelMngRepository.flush();
}
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
@Transactional
public void markStopped(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.STOPPED.getId());
}
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
@Transactional
public void markCompleted(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/**
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep1State(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
master.setUpdatedUid(userUtil.getId());
master.setUpdatedDttm(ZonedDateTime.now());
}
/**
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
@Transactional
public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep2State(TrainStatusType.ERROR.getId());
master.setLastError(errorMessage);
master.setUpdatedUid(userUtil.getId());
master.setUpdatedDttm(ZonedDateTime.now());
}
@Transactional
public void markSuccess(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
// 모델 상태 완료 처리
master.setStatusCd(TrainStatusType.COMPLETED.getId());
// (선택) 마지막 에러 메시지 비우기
master.setLastError(null);
}
/**
* 학습 실행에 필요한 파라미터 조회
*
* @param modelId
* @return
*/
public TrainRunRequest findTrainRunRequest(Long modelId) {
return modelMngRepository.findTrainRunRequest(modelId);
}
/**
* step1 진행중 처리
*
* @param modelId
* @param jobId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStrtDttm(ZonedDateTime.now());
entity.setStep1StrtDttm(ZonedDateTime.now());
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
/**
* step2 진행중 처리
*
* @param modelId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep2StrtDttm(ZonedDateTime.now());
entity.setStep2State(TrainStatusType.IN_PROGRESS.getId());
entity.setCurrentAttemptId(jobId);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
/**
* step1 완료처리
*
* @param modelId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep1State(TrainStatusType.COMPLETED.getId());
entity.setStep1EndDttm(ZonedDateTime.now());
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
/**
* step2 완료처리
*
* @param modelId
*/
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep2State(TrainStatusType.COMPLETED.getId());
entity.setStep2EndDttm(ZonedDateTime.now());
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
}
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setBestEpoch(epoch);
}
/**
* 데이터셋 uid 조회
*
* @param datasetIds
* @return
*/
public List<String> findDatasetUid(List<Long> datasetIds) {
return datasetRepository.findDatasetUid(datasetIds);
}
public List<Long> findModelDatasetMapp(Long modelId) {
List<Long> datasetUids = new ArrayList<>();
List<ModelDatasetMappEntity> entities = modelDatasetMapRepository.findByModelUid(modelId);
for (ModelDatasetMappEntity entity : entities) {
datasetUids.add(entity.getDatasetUid());
}
return datasetUids;
}
public Long findModelStep1InProgressCnt() {
return modelMngRepository.findModelStep1InProgressCnt();
}
/**
* train 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
return modelDatasetMapRepository.findDatasetTrainPath(modelId);
}
/**
* validation 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
return modelDatasetMapRepository.findDatasetValPath(modelId);
}
/**
* test 링크할 파일 경로
*
* @param modelId
* @return
*/
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
return modelDatasetMapRepository.findDatasetTestPath(modelId);
}
}

View File

@@ -5,6 +5,7 @@ import com.kamco.cd.training.log.dto.EventStatus;
import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.postgres.CommonCreateEntity;
import jakarta.persistence.*;
import java.util.UUID;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NoArgsConstructor;
@@ -14,6 +15,7 @@ import lombok.NoArgsConstructor;
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@Table(name = "tb_audit_log")
public class AuditLogEntity extends CommonCreateEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "audit_log_uid", nullable = false)
@@ -43,6 +45,12 @@ public class AuditLogEntity extends CommonCreateEntity {
@Column(name = "error_log_uid")
private Long errorLogUid;
@Column(name = "download_uuid")
private UUID downloadUuid;
@Column(name = "login_attempt_id")
private String loginAttemptId;
public AuditLogEntity(
Long userUid,
EventType eventType,
@@ -51,7 +59,9 @@ public class AuditLogEntity extends CommonCreateEntity {
String ipAddress,
String requestUri,
String requestBody,
Long errorLogUid) {
Long errorLogUid,
UUID downloadUuid,
String loginAttemptId) {
this.userUid = userUid;
this.eventType = eventType;
this.eventStatus = eventStatus;
@@ -60,6 +70,31 @@ public class AuditLogEntity extends CommonCreateEntity {
this.requestUri = requestUri;
this.requestBody = requestBody;
this.errorLogUid = errorLogUid;
this.downloadUuid = downloadUuid;
this.loginAttemptId = loginAttemptId;
}
/** 파일 다운로드 이력 생성 */
public static AuditLogEntity forFileDownload(
Long userId,
String requestUri,
String menuUid,
String ip,
int httpStatus,
UUID downloadUuid) {
return new AuditLogEntity(
userId,
EventType.DOWNLOAD, // 이벤트 타입 고정
httpStatus < 400 ? EventStatus.SUCCESS : EventStatus.FAILED, // 성공 여부
menuUid,
ip,
requestUri,
null, // requestBody 없음
null, // errorLogUid 없음
downloadUuid,
null // loginAttemptId 없음
);
}
public AuditLogDto.Basic toDto() {

View File

@@ -0,0 +1,117 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.dataset.dto.DatasetObjDto.Basic;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.ZonedDateTime;
import java.util.UUID;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
import org.locationtech.jts.geom.Geometry;
@Getter
@Setter
@Entity
@Table(name = "tb_dataset_val_obj")
public class DatasetValObjEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "obj_id", nullable = false)
private Long objId;
@NotNull
@Column(name = "dataset_uid", nullable = false)
private Long datasetUid;
@Column(name = "target_yyyy")
private Integer targetYyyy;
@Size(max = 255)
@Column(name = "target_class_cd")
private String targetClassCd;
@Column(name = "compare_yyyy")
private Integer compareYyyy;
@Size(max = 255)
@Column(name = "compare_class_cd")
private String compareClassCd;
@Size(max = 255)
@Column(name = "target_path")
private String targetPath;
@Size(max = 255)
@Column(name = "compare_path")
private String comparePath;
@Size(max = 255)
@Column(name = "label_path")
private String labelPath;
@Size(max = 255)
@Column(name = "geojson_path")
private String geojsonPath;
@Size(max = 255)
@Column(name = "map_sheet_num")
private String mapSheetNum;
@ColumnDefault("now()")
@Column(name = "created_dttm")
private ZonedDateTime createdDttm;
@Column(name = "created_uid")
private Long createdUid;
@ColumnDefault("false")
@Column(name = "deleted")
private Boolean deleted;
@Column(name = "uuid")
private UUID uuid;
@Size(max = 32)
@Column(name = "uid")
private String uid;
@JdbcTypeCode(SqlTypes.JSON)
@Column(name = "geo_jsonb", columnDefinition = "jsonb")
private String geoJsonb;
@Column(name = "file_name")
private String fileName;
@Column(name = "geom", columnDefinition = "geometry")
private Geometry geom;
public Basic toDto() {
return new Basic(
this.objId,
this.datasetUid,
this.targetYyyy,
this.targetClassCd,
this.compareYyyy,
this.compareClassCd,
this.targetPath,
this.comparePath,
this.labelPath,
this.geojsonPath,
this.mapSheetNum,
this.createdDttm,
this.createdUid,
this.deleted,
this.uuid,
this.geoJsonb);
}
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import jakarta.persistence.*;
import jakarta.validation.constraints.NotNull;
@@ -191,10 +192,10 @@ public class ModelHyperParamEntity {
@Column(name = "save_best_rule", nullable = false, length = 10)
private String saveBestRule = "greater";
/** Default: 10 */
/** Default: 1 */
@NotNull
@Column(name = "val_interval", nullable = false)
private Integer valInterval = 10;
private Integer valInterval = 1;
/** Default: 400 */
@NotNull
@@ -302,20 +303,24 @@ public class ModelHyperParamEntity {
@Column(name = "last_used_dttm")
private ZonedDateTime lastUsedDttm;
@Column(name = "m1_use_cnt")
private Long m1UseCnt = 0L;
@Column(name = "model_type")
@Enumerated(EnumType.STRING)
private ModelType modelType;
@Column(name = "m2_use_cnt")
private Long m2UseCnt = 0L;
@Column(name = "default_param")
private Boolean isDefault = false;
@Column(name = "m3_use_cnt")
private Long m3UseCnt = 0L;
@Column(name = "total_use_cnt")
private Integer totalUseCnt = 0;
public HyperParamDto.Basic toDto() {
return new HyperParamDto.Basic(
this.modelType,
this.uuid,
this.hyperVer,
this.createdDttm,
this.lastUsedDttm,
this.totalUseCnt,
// -------------------------
// Important
// -------------------------
@@ -385,6 +390,7 @@ public class ModelHyperParamEntity {
// -------------------------
this.gpuCnt,
this.gpuIds,
this.masterPort);
this.masterPort,
this.isDefault);
}
}

View File

@@ -91,6 +91,36 @@ public class ModelMasterEntity {
@Column(name = "before_model_id")
private Long beforeModelId;
@Column(name = "step1_metric_save_yn")
private Boolean step1MetricSaveYn;
@Column(name = "step2_metric_save_yn")
private Boolean step2MetricSaveYn;
@Column(name = "current_attempt_id")
private Long currentAttemptId;
@Column(name = "last_error")
private String lastError;
@Column(name = "best_epoch")
private Integer bestEpoch;
@Column(name = "request_path")
private String requestPath;
@Column(name = "response_path")
private String responsePath;
@Column(name = "packing_state")
private String packingState;
@Column(name = "packing_strt_dttm")
private ZonedDateTime packingStrtDttm;
@Column(name = "packing_end_dttm")
private ZonedDateTime packingEndDttm;
public ModelTrainMngDto.Basic toDto() {
return new ModelTrainMngDto.Basic(
this.id,
@@ -105,6 +135,12 @@ public class ModelMasterEntity {
this.step2State,
this.statusCd,
this.trainType,
this.modelNo);
this.modelNo,
this.currentAttemptId,
this.requestPath,
this.packingState,
this.packingStrtDttm,
this.packingEndDttm,
this.beforeModelId);
}
}

View File

@@ -19,8 +19,8 @@ import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_matrics_test")
public class ModelMatricsTestEntity {
@Table(name = "tb_model_metrics_test")
public class ModelMetricsTestEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@@ -45,9 +45,6 @@ public class ModelMatricsTestEntity {
@Column(name = "fn")
private Long fn;
@Column(name = "tn")
private Long tn;
@Column(name = "precisions")
private Float precisions;
@@ -63,8 +60,11 @@ public class ModelMatricsTestEntity {
@Column(name = "iou")
private Float iou;
@Column(name = "processed_images")
private Long processedImages;
@Column(name = "detection_count")
private Long detectionCount;
@Column(name = "gt_count")
private Long gtCount;
@ColumnDefault("now()")
@Column(name = "created_dttm")

View File

@@ -18,8 +18,8 @@ import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_matrics_train")
public class ModelMatricsTrainEntity {
@Table(name = "tb_model_metrics_train")
public class ModelMetricsTrainEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)

View File

@@ -18,8 +18,8 @@ import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_matrics_validation")
public class ModelMatricsValidationEntity {
@Table(name = "tb_model_metrics_validation")
public class ModelMetricsValidationEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)

View File

@@ -0,0 +1,42 @@
package com.kamco.cd.training.postgres.entity;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import java.time.OffsetDateTime;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
@Getter
@Setter
@Entity
@Table(name = "tb_model_test_training_run")
public class ModelTestTrainingRunEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "tsr_id", nullable = false)
private Long id;
@NotNull
@Column(name = "model_id", nullable = false)
private Long modelId;
@Column(name = "attempt_no")
private Integer attemptNo;
@Column(name = "job_id")
private Long jobId;
@Column(name = "epoch")
private Integer epoch;
@ColumnDefault("now()")
@Column(name = "created_dttm")
private OffsetDateTime createdDttm;
}

View File

@@ -0,0 +1,105 @@
package com.kamco.cd.training.postgres.entity;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.time.ZonedDateTime;
import java.util.Map;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.annotations.ColumnDefault;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
@Getter
@Setter
@Entity
@Table(name = "tb_model_train_job")
public class ModelTrainJobEntity {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "id", nullable = false)
private Long id;
@NotNull
@Column(name = "model_id", nullable = false)
private Long modelId;
@NotNull
@Column(name = "attempt_no", nullable = false)
private Integer attemptNo;
@Size(max = 30)
@NotNull
@Column(name = "status_cd", nullable = false, length = 30)
private String statusCd;
@NotNull
@Column(name = "params_json", nullable = false)
@JdbcTypeCode(SqlTypes.JSON)
private Map<String, Object> paramsJson;
@Size(max = 200)
@Column(name = "container_name", length = 200)
private String containerName;
@Size(max = 500)
@Column(name = "log_path", length = 500)
private String logPath;
@Column(name = "exit_code")
private Integer exitCode;
@Column(name = "error_message", columnDefinition = "TEXT")
private String errorMessage;
@ColumnDefault("now()")
@Column(name = "queued_dttm")
private ZonedDateTime queuedDttm;
@Column(name = "started_dttm")
private ZonedDateTime startedDttm;
@Column(name = "finished_dttm")
private ZonedDateTime finishedDttm;
@Column(name = "locked_dttm")
private ZonedDateTime lockedDttm;
@Size(max = 100)
@Column(name = "locked_by", length = 100)
private String lockedBy;
@Column(name = "total_epoch")
private Integer totalEpoch;
@Column(name = "current_epoch")
private Integer currentEpoch;
@Column(name = "job_type")
private String jobType;
public ModelTrainJobDto toDto() {
return new ModelTrainJobDto(
this.id,
this.modelId,
this.attemptNo,
this.statusCd,
this.exitCode,
this.errorMessage,
this.containerName,
this.paramsJson,
this.queuedDttm,
this.startedDttm,
this.finishedDttm,
this.totalEpoch,
this.currentEpoch);
}
}

View File

@@ -22,4 +22,6 @@ public interface DatasetObjRepositoryCustom {
String getFilePathByUUIDPathType(UUID uuid, String pathType);
void insertDatasetTestObj(DatasetObjRegDto objRegDto);
void insertDatasetValObj(DatasetObjRegDto objRegDto);
}

View File

@@ -97,6 +97,49 @@ public class DatasetObjRepositoryImpl implements DatasetObjRepositoryCustom {
}
}
@Override
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
ObjectMapper objectMapper = new ObjectMapper();
String json;
String geometryJson;
try {
json = objectMapper.writeValueAsString(objRegDto.getGeojson());
geometryJson =
objectMapper.writeValueAsString(
objRegDto.getGeojson().path("features").get(0).path("geometry"));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
try {
em.createNativeQuery(
"""
insert into tb_dataset_val_obj
(dataset_uid, target_yyyy, target_class_cd,
compare_yyyy, compare_class_cd,
target_path, compare_path, label_path, geo_jsonb, map_sheet_num, file_name, geom, geojson_path)
values
(?, ?, ?, ?, ?, ?, ?, ?, cast(? as jsonb), ?, ?, ST_SetSRID(ST_GeomFromGeoJSON(?), 5186), ?)
""")
.setParameter(1, objRegDto.getDatasetUid())
.setParameter(2, objRegDto.getTargetYyyy())
.setParameter(3, objRegDto.getTargetClassCd())
.setParameter(4, objRegDto.getCompareYyyy())
.setParameter(5, objRegDto.getCompareClassCd())
.setParameter(6, objRegDto.getTargetPath())
.setParameter(7, objRegDto.getComparePath())
.setParameter(8, objRegDto.getLabelPath())
.setParameter(9, json)
.setParameter(10, objRegDto.getMapSheetNum())
.setParameter(11, objRegDto.getFileName())
.setParameter(12, geometryJson)
.setParameter(13, objRegDto.getGeojsonPath())
.executeUpdate();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public Page<DatasetObjEntity> searchDatasetObjectList(SearchReq searchReq) {
Pageable pageable = searchReq.toPageable();

View File

@@ -4,6 +4,7 @@ import com.kamco.cd.training.dataset.dto.DatasetDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity;
import java.util.List;
import java.util.Optional;
@@ -17,9 +18,17 @@ public interface DatasetRepositoryCustom {
List<SelectDataSet> getDatasetSelectG1List(DatasetReq req);
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId);
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(Long modelId, String modelNo);
List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req);
Long getDatasetMaxStage(int compareYyyy, int targetYyyy);
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
List<String> findDatasetUid(List<Long> datasetIds);
Long findDatasetByUidExistsCnt(String uid);
}

View File

@@ -1,14 +1,20 @@
package com.kamco.cd.training.postgres.repository.dataset;
import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SearchReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectTransferDataSet;
import com.kamco.cd.training.postgres.entity.DatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.CaseBuilder;
@@ -67,7 +73,11 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
// Count 쿼리 별도 실행 (null safe handling)
long total =
Optional.ofNullable(
queryFactory.select(dataset.count()).from(dataset).where(builder).fetchOne())
queryFactory
.select(dataset.count())
.from(dataset)
.where(builder.and(dataset.deleted.isFalse()))
.fetchOne())
.orElse(0L);
return new PageImpl<>(content, pageable, total);
@@ -138,6 +148,103 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch();
}
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG1List(Long modelId) {
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum(),
// ===== before (join으로) =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("building"))
.then(1)
.otherwise(0)
.sum(),
new CaseBuilder()
.when(beforeObj.targetClassCd.eq("container"))
.then(1)
.otherwise(0)
.sum()))
.from(modelMasterEntity)
// ===== 현재 dataset join =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before 모델 join =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId))
.groupBy(
modelMasterEntity.modelNo,
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
@@ -201,6 +308,116 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.fetch();
}
@Override
public List<SelectTransferDataSet> getDatasetTransferSelectG2G3List(
Long modelId, String modelNo) {
// before join용
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
QModelDatasetMappEntity beforeMapp = new QModelDatasetMappEntity("beforeMapp");
QDatasetEntity beforeDataset = new QDatasetEntity("beforeDataset");
QDatasetObjEntity beforeObj = new QDatasetObjEntity("beforeObj");
BooleanBuilder builder = new BooleanBuilder();
NumberExpression<Long> wasteCnt =
datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> elseCnt =
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> selectedCnt = ModelType.G2.getId().equals(modelNo) ? wasteCnt : elseCnt;
// before도 동일 로직으로 cnt 계산
NumberExpression<Long> beforeWasteCnt =
beforeObj.targetClassCd.when("waste").then(1L).otherwise(0L).sum();
NumberExpression<Long> beforeElseCnt =
new CaseBuilder()
.when(beforeObj.targetClassCd.notIn("building", "container", "waste"))
.then(1L)
.otherwise(0L)
.sum();
NumberExpression<Long> beforeSelectedCnt =
ModelType.G2.getId().equals(modelNo) ? beforeWasteCnt : beforeElseCnt;
return queryFactory
.select(
Projections.constructor(
SelectTransferDataSet.class,
// ===== 현재 =====
modelMasterEntity.modelNo, // modelNo 파라미터 사용 (req.getModelNo() 제거)
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
selectedCnt, // classCount 자리에 들어가는 cnt (Long)
// ===== before =====
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo,
beforeSelectedCnt))
.from(modelMasterEntity)
// ===== 현재 dataset =====
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(dataset)
.on(modelDatasetMappEntity.datasetUid.eq(dataset.id))
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
// ===== before dataset =====
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeMapp)
.on(beforeMapp.modelUid.eq(beforeMaster.id))
.leftJoin(beforeDataset)
.on(beforeMapp.datasetUid.eq(beforeDataset.id))
.leftJoin(beforeObj)
.on(beforeDataset.id.eq(beforeObj.datasetUid))
.where(modelMasterEntity.id.eq(modelId).and(builder))
// sum() 때문에 groupBy 필요
.groupBy(
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
beforeMaster.modelNo,
beforeDataset.id,
beforeDataset.uuid,
beforeDataset.dataType,
beforeDataset.title,
beforeDataset.roundNo,
beforeDataset.compareYyyy,
beforeDataset.targetYyyy,
beforeDataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
@Override
public Long getDatasetMaxStage(int compareYyyy, int targetYyyy) {
return queryFactory
@@ -239,7 +456,21 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
return queryFactory
.select(dataset.id)
.from(dataset)
.where(dataset.uid.eq(mngRegDto.getUid()))
.where(dataset.uid.eq(mngRegDto.getUid()), dataset.deleted.isFalse())
.fetchOne();
}
@Override
public List<String> findDatasetUid(List<Long> datasetIds) {
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
}
@Override
public Long findDatasetByUidExistsCnt(String uid) {
return queryFactory
.select(dataset.id.count())
.from(dataset)
.where(dataset.uid.eq(uid), dataset.deleted.isFalse())
.fetchOne();
}
}

View File

@@ -1,7 +1,10 @@
package com.kamco.cd.training.postgres.repository.hyperparam;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.springframework.data.domain.Page;
@@ -13,11 +16,41 @@ public interface HyperParamRepositoryCustom {
*
* @return
*/
@Deprecated
Optional<ModelHyperParamEntity> findHyperParamVer();
/**
* 모델 타입별 마지막 버전 조회
*
* @param modelType 모델 타입
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType);
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
/**
* 하이퍼 파라미터 상세조회
*
* @param uuid
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req);
/**
* 하이퍼 파라미터 목록 조회
*
* @param model
* @param req
* @return
*/
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
/**
* 하이퍼 파라미터 모델타입으로 조회
*
* @param modelType
* @return
*/
List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType);
}

View File

@@ -2,12 +2,13 @@ package com.kamco.cd.training.postgres.repository.hyperparam;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.HyperType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.NumberExpression;
import com.querydsl.jpa.impl.JPAQuery;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.ZoneId;
@@ -41,6 +42,23 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
.fetchOne());
}
@Override
public Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType) {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(
modelHyperParamEntity
.delYn
.isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType)))
.orderBy(modelHyperParamEntity.hyperVer.desc())
.limit(1)
.fetchOne());
}
@Override
public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) {
@@ -63,17 +81,22 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid)))
.where(modelHyperParamEntity.uuid.eq(uuid))
.fetchOne());
}
@Override
public Page<HyperParamDto.List> findByHyperVerList(HyperParamDto.SearchReq req) {
public Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req) {
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
builder.and(modelHyperParamEntity.delYn.isFalse());
if (model != null) {
builder.and(modelHyperParamEntity.modelType.eq(model));
}
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
// 버전
builder.and(modelHyperParamEntity.hyperVer.contains(req.getHyperVer()));
@@ -96,26 +119,18 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
}
}
NumberExpression<Long> totalUseCnt =
modelHyperParamEntity
.m1UseCnt
.coalesce(0L)
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
JPAQuery<HyperParamDto.List> query =
queryFactory
.select(
Projections.constructor(
HyperParamDto.List.class,
modelHyperParamEntity.uuid,
modelHyperParamEntity.modelType.as("model"),
modelHyperParamEntity.hyperVer,
modelHyperParamEntity.createdDttm,
modelHyperParamEntity.lastUsedDttm,
modelHyperParamEntity.m1UseCnt,
modelHyperParamEntity.m2UseCnt,
modelHyperParamEntity.m3UseCnt,
totalUseCnt.as("totalUseCnt")))
modelHyperParamEntity.memo,
modelHyperParamEntity.totalUseCnt))
.from(modelHyperParamEntity)
.where(builder);
@@ -140,8 +155,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
asc
? modelHyperParamEntity.lastUsedDttm.asc()
: modelHyperParamEntity.lastUsedDttm.desc());
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
case "totalUseCnt" ->
query.orderBy(
asc
? modelHyperParamEntity.totalUseCnt.asc()
: modelHyperParamEntity.totalUseCnt.desc());
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
}
@@ -161,4 +179,17 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
return new PageImpl<>(content, pageable, totalCount);
}
@Override
public List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType) {
return queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(
modelHyperParamEntity
.delYn
.isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType)))
.fetch();
}
}

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.postgres.repository.log;
import com.kamco.cd.training.log.dto.AuditLogDto;
import com.kamco.cd.training.log.dto.AuditLogDto.DownloadReq;
import java.time.LocalDate;
import org.springframework.data.domain.Page;
@@ -15,6 +16,9 @@ public interface AuditLogRepositoryCustom {
Page<AuditLogDto.UserAuditList> findLogByAccount(
AuditLogDto.searchReq searchReq, String searchValue);
Page<AuditLogDto.DownloadRes> findDownloadLog(
AuditLogDto.searchReq searchReq, DownloadReq downloadReq);
Page<AuditLogDto.DailyDetail> findLogByDailyResult(
AuditLogDto.searchReq searchReq, LocalDate logDate);

View File

@@ -6,32 +6,42 @@ import static com.kamco.cd.training.postgres.entity.QMemberEntity.memberEntity;
import static com.kamco.cd.training.postgres.entity.QMenuEntity.menuEntity;
import com.kamco.cd.training.log.dto.AuditLogDto;
import com.kamco.cd.training.log.dto.AuditLogDto.DownloadReq;
import com.kamco.cd.training.log.dto.AuditLogDto.searchReq;
import com.kamco.cd.training.log.dto.ErrorLogDto;
import com.kamco.cd.training.log.dto.EventStatus;
import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.postgres.entity.AuditLogEntity;
import com.kamco.cd.training.postgres.entity.QMenuEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.*;
import com.querydsl.jpa.impl.JPAQueryFactory;
import io.micrometer.common.util.StringUtils;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Objects;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
public class AuditLogRepositoryImpl extends QuerydslRepositorySupport
implements AuditLogRepositoryCustom {
private static final ZoneId ZONE = ZoneId.of("Asia/Seoul");
private final JPAQueryFactory queryFactory;
private final StringExpression NULL_STRING = Expressions.stringTemplate("cast(null as text)");
public AuditLogRepositoryImpl(JPAQueryFactory queryFactory) {
super(AuditLogEntity.class);
this.queryFactory = queryFactory;
}
@Override
public Page<AuditLogDto.DailyAuditList> findLogByDaily(
AuditLogDto.searchReq searchReq, LocalDate startDate, LocalDate endDate) {
@@ -87,7 +97,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
.from(auditLogEntity)
.leftJoin(menuEntity)
.on(auditLogEntity.menuUid.eq(menuEntity.menuUid))
.where(menuNameEquals(searchValue))
.where(auditLogEntity.menuUid.ne("SYSTEM"), menuNameEquals(searchValue))
.groupBy(auditLogEntity.menuUid)
.offset(pageable.getOffset())
.limit(pageable.getPageSize())
@@ -128,7 +138,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
.from(auditLogEntity)
.leftJoin(memberEntity)
.on(auditLogEntity.userUid.eq(memberEntity.id))
.where(loginIdOrUsernameContains(searchValue))
.where(auditLogEntity.userUid.isNotNull(), loginIdOrUsernameContains(searchValue))
.groupBy(auditLogEntity.userUid, memberEntity.employeeNo, memberEntity.name)
.offset(pageable.getOffset())
.limit(pageable.getPageSize())
@@ -147,6 +157,62 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
return new PageImpl<>(foundContent, pageable, countQuery);
}
@Override
public Page<AuditLogDto.DownloadRes> findDownloadLog(
AuditLogDto.searchReq searchReq, DownloadReq req) {
Pageable pageable = searchReq.toPageable();
BooleanBuilder whereBuilder = new BooleanBuilder();
whereBuilder.and(auditLogEntity.eventStatus.ne(EventStatus.valueOf("FAILED")));
whereBuilder.and(auditLogEntity.eventType.eq(EventType.valueOf("DOWNLOAD")));
// if (req.getMenuId() != null && !req.getMenuId().isEmpty()) {
// whereBuilder.and(auditLogEntity.menuUid.eq(req.getMenuId()));
// }
if (req.getUuid() != null) {
whereBuilder.and(auditLogEntity.requestUri.contains(req.getRequestUri()));
whereBuilder.and(auditLogEntity.downloadUuid.eq(req.getUuid()));
}
if (req.getSearchValue() != null && !req.getSearchValue().isEmpty()) {
whereBuilder.and(
memberEntity
.name
.contains(req.getSearchValue())
.or(memberEntity.employeeNo.contains(req.getSearchValue())));
}
List<AuditLogDto.DownloadRes> foundContent =
queryFactory
.select(
Projections.constructor(
AuditLogDto.DownloadRes.class,
memberEntity.name,
memberEntity.employeeNo,
auditLogEntity.createdDate.as("downloadDttm")))
.from(auditLogEntity)
.leftJoin(memberEntity)
.on(auditLogEntity.userUid.eq(memberEntity.id))
.where(whereBuilder, createdDateBetween(req.getStartDate(), req.getEndDate()))
.offset(pageable.getOffset())
.limit(pageable.getPageSize())
.orderBy(auditLogEntity.createdDate.desc())
.fetch();
Long countQuery =
queryFactory
.select(auditLogEntity.userUid.countDistinct())
.from(auditLogEntity)
.leftJoin(memberEntity)
.on(auditLogEntity.userUid.eq(memberEntity.id))
.where(whereBuilder, createdDateBetween(req.getStartDate(), req.getEndDate()))
.fetchOne();
return new PageImpl<>(foundContent, pageable, countQuery);
}
@Override
public Page<AuditLogDto.DailyDetail> findLogByDailyResult(
AuditLogDto.searchReq searchReq, LocalDate logDate) {
@@ -176,6 +242,9 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
memberEntity.employeeNo.as("loginId"),
menuEntity.menuNm.as("menuName"),
auditLogEntity.eventType.as("eventType"),
Expressions.stringTemplate(
"to_char({0}, 'YYYY-MM-DD HH24:MI')", auditLogEntity.createdDate)
.as("logDateTime"),
Projections.constructor(
AuditLogDto.LogDetail.class,
Expressions.constant("한국자산관리공사"), // serviceName
@@ -184,7 +253,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
menuEntity.menuUrl.as("menuUrl"),
menuEntity.description.as("menuDescription"),
menuEntity.menuOrder.as("sortOrder"),
menuEntity.isUse.as("used"))))
menuEntity.isUse.as("used")))) // TODO
.from(auditLogEntity)
.leftJoin(menuEntity)
.on(auditLogEntity.menuUid.eq(menuEntity.menuUid))
@@ -238,8 +307,8 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
AuditLogDto.MenuDetail.class,
auditLogEntity.id.as("logId"),
Expressions.stringTemplate(
"to_char({0}, 'YYYY-MM-DD')", auditLogEntity.createdDate)
.as("logDateTime"), // ??
"to_char({0}, 'YYYY-MM-DD HH24:MI')", auditLogEntity.createdDate)
.as("logDateTime"),
memberEntity.name.as("userName"),
memberEntity.employeeNo.as("loginId"),
auditLogEntity.eventType.as("eventType"),
@@ -305,7 +374,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
AuditLogDto.UserDetail.class,
auditLogEntity.id.as("logId"),
Expressions.stringTemplate(
"to_char({0}, 'YYYY-MM-DD')", auditLogEntity.createdDate)
"to_char({0}, 'YYYY-MM-DD HH24:MI')", auditLogEntity.createdDate)
.as("logDateTime"),
menuEntity.menuNm.as("menuName"),
auditLogEntity.eventType.as("eventType"),
@@ -349,12 +418,23 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
if (Objects.isNull(startDate) || Objects.isNull(endDate)) {
return null;
}
LocalDateTime startDateTime = startDate.atStartOfDay();
LocalDateTime endDateTime = endDate.plusDays(1).atStartOfDay();
ZoneId zoneId = ZoneId.of("Asia/Seoul");
ZonedDateTime startDateTime = startDate.atStartOfDay(zoneId);
ZonedDateTime endDateTime = endDate.plusDays(1).atStartOfDay(zoneId);
return auditLogEntity
.createdDate
.goe(ZonedDateTime.from(startDateTime))
.and(auditLogEntity.createdDate.lt(ZonedDateTime.from(endDateTime)));
.goe(startDateTime)
.and(auditLogEntity.createdDate.lt(endDateTime));
}
private BooleanExpression createdDateBetween(LocalDate startDate, LocalDate endDate) {
if (startDate == null || endDate == null) {
return null;
}
ZonedDateTime start = startDate.atStartOfDay(ZONE);
ZonedDateTime endExclusive = endDate.plusDays(1).atStartOfDay(ZONE);
return auditLogEntity.createdDate.goe(start).and(auditLogEntity.createdDate.lt(endExclusive));
}
private BooleanExpression menuNameEquals(String searchValue) {
@@ -393,11 +473,11 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
}
private BooleanExpression eventEndedAtEqDate(LocalDate logDate) {
StringExpression eventEndedDate =
Expressions.stringTemplate("to_char({0}, 'YYYY-MM-DD')", auditLogEntity.createdDate);
LocalDateTime comparisonDate = logDate.atStartOfDay();
ZoneId zoneId = ZoneId.of("Asia/Seoul");
ZonedDateTime start = logDate.atStartOfDay(zoneId);
ZonedDateTime end = logDate.plusDays(1).atStartOfDay(zoneId);
return eventEndedDate.eq(comparisonDate.toString());
return auditLogEntity.createdDate.goe(start).and(auditLogEntity.createdDate.lt(end));
}
private BooleanExpression menuUidEq(String menuUid) {
@@ -410,7 +490,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
private NumberExpression<Integer> readCount() {
return new CaseBuilder()
.when(auditLogEntity.eventType.eq(EventType.READ))
.when(auditLogEntity.eventType.in(EventType.LIST, EventType.DETAIL))
.then(1)
.otherwise(0)
.sum();
@@ -418,7 +498,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
private NumberExpression<Integer> cudCount() {
return new CaseBuilder()
.when(auditLogEntity.eventType.in(EventType.CREATE, EventType.UPDATE, EventType.DELETE))
.when(auditLogEntity.eventType.in(EventType.ADDED, EventType.MODIFIED, EventType.REMOVE))
.then(1)
.otherwise(0)
.sum();
@@ -426,7 +506,7 @@ public class AuditLogRepositoryImpl implements AuditLogRepositoryCustom {
private NumberExpression<Integer> printCount() {
return new CaseBuilder()
.when(auditLogEntity.eventType.eq(EventType.PRINT))
.when(auditLogEntity.eventType.eq(EventType.OTHER))
.then(1)
.otherwise(0)
.sum();

View File

@@ -8,29 +8,35 @@ import static com.kamco.cd.training.postgres.entity.QMenuEntity.menuEntity;
import com.kamco.cd.training.log.dto.ErrorLogDto;
import com.kamco.cd.training.log.dto.EventStatus;
import com.kamco.cd.training.log.dto.EventType;
import com.kamco.cd.training.postgres.entity.AuditLogEntity;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.BooleanExpression;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.StringExpression;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Objects;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class ErrorLogRepositoryImpl implements ErrorLogRepositoryCustom {
public class ErrorLogRepositoryImpl extends QuerydslRepositorySupport
implements ErrorLogRepositoryCustom {
private final JPAQueryFactory queryFactory;
private final StringExpression NULL_STRING = Expressions.stringTemplate("cast(null as text)");
public ErrorLogRepositoryImpl(JPAQueryFactory queryFactory) {
super(AuditLogEntity.class);
this.queryFactory = queryFactory;
}
@Override
public Page<ErrorLogDto.Basic> findLogByError(ErrorLogDto.ErrorSearchReq searchReq) {
Pageable pageable = searchReq.toPageable();
@@ -52,7 +58,7 @@ public class ErrorLogRepositoryImpl implements ErrorLogRepositoryCustom {
errorLogEntity.errorMessage.as("errorMessage"),
errorLogEntity.stackTrace.as("errorDetail"),
Expressions.stringTemplate(
"to_char({0}, 'YYYY-MM-DD')", errorLogEntity.createdDate)))
"to_char({0}, 'YYYY-MM-DD HH24:MI:SS.FF3')", errorLogEntity.createdDate)))
.from(errorLogEntity)
.leftJoin(auditLogEntity)
.on(errorLogEntity.id.eq(auditLogEntity.errorLogUid))
@@ -94,12 +100,14 @@ public class ErrorLogRepositoryImpl implements ErrorLogRepositoryCustom {
if (Objects.isNull(startDate) || Objects.isNull(endDate)) {
return null;
}
LocalDateTime startDateTime = startDate.atStartOfDay();
LocalDateTime endDateTime = endDate.plusDays(1).atStartOfDay();
ZoneId zoneId = ZoneId.of("Asia/Seoul");
ZonedDateTime startDateTime = startDate.atStartOfDay(zoneId);
ZonedDateTime endDateTime = endDate.plusDays(1).atStartOfDay(zoneId);
return auditLogEntity
.createdDate
.goe(ZonedDateTime.from(startDateTime))
.and(auditLogEntity.createdDate.lt(ZonedDateTime.from(endDateTime)));
.goe(startDateTime)
.and(auditLogEntity.createdDate.lt(endDateTime));
}
private BooleanExpression eventStatusEqFailed() {

View File

@@ -5,4 +5,6 @@ import java.util.Optional;
public interface ModelConfigRepositoryCustom {
Optional<ModelConfigDto.Basic> findModelConfigByModelId(Long modelId);
Optional<ModelConfigDto.TransferBasic> findModelTransferConfigByModelId(Long modelId);
}

View File

@@ -1,8 +1,12 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.model.dto.ModelConfigDto.Basic;
import com.kamco.cd.training.model.dto.ModelConfigDto.TransferBasic;
import com.kamco.cd.training.postgres.entity.QModelConfigEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.Optional;
@@ -34,4 +38,44 @@ public class ModelConfigRepositoryImpl implements ModelConfigRepositoryCustom {
.where(modelConfigEntity.model.id.eq(modelId))
.fetchOne());
}
@Override
public Optional<TransferBasic> findModelTransferConfigByModelId(Long modelId) {
QModelConfigEntity beforeConfig = new QModelConfigEntity("beforeConfig");
QModelMasterEntity beforeMaster = new QModelMasterEntity("beforeMaster");
return Optional.ofNullable(
queryFactory
.select(
Projections.constructor(
TransferBasic.class,
// ===== 현재 =====
modelConfigEntity.id,
modelConfigEntity.model.id,
modelConfigEntity.epochCount,
modelConfigEntity.trainPercent,
modelConfigEntity.validationPercent,
modelConfigEntity.testPercent,
modelConfigEntity.memo,
// ===== before =====
beforeConfig.id,
beforeConfig.model.id,
beforeConfig.epochCount,
beforeConfig.trainPercent,
beforeConfig.validationPercent,
beforeConfig.testPercent,
beforeConfig.memo))
.from(modelConfigEntity)
.innerJoin(modelConfigEntity.model, modelMasterEntity)
// before 모델 조인
.leftJoin(beforeMaster)
.on(beforeMaster.id.eq(modelMasterEntity.beforeModelId))
.leftJoin(beforeConfig)
.on(beforeConfig.model.id.eq(beforeMaster.id))
.where(modelMasterEntity.id.eq(modelId))
.fetchOne());
}
}

View File

@@ -6,4 +6,5 @@ import org.springframework.stereotype.Repository;
@Repository
public interface ModelDatasetMappRepository
extends JpaRepository<ModelDatasetMappEntity, ModelDatasetMappEntity.ModelDatasetMappId> {}
extends JpaRepository<ModelDatasetMappEntity, ModelDatasetMappEntity.ModelDatasetMappId>,
ModelDatasetMappRepositoryCustom {}

View File

@@ -0,0 +1,15 @@
package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import java.util.List;
public interface ModelDatasetMappRepositoryCustom {
List<ModelDatasetMappEntity> findByModelUid(Long modelId);
List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId);
List<ModelTrainLinkDto> findDatasetValPath(Long modelId);
List<ModelTrainLinkDto> findDatasetTestPath(Long modelId);
}

View File

@@ -0,0 +1,165 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QDatasetEntity.datasetEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetTestObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetValObjEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositoryCustom {
private final JPAQueryFactory queryFactory;
@Override
public List<ModelDatasetMappEntity> findByModelUid(Long modelId) {
return queryFactory
.select(modelDatasetMappEntity)
.from(modelDatasetMappEntity)
.where(modelDatasetMappEntity.modelUid.eq(modelId))
.fetch();
}
@Override
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
QDatasetObjEntity datasetObjEntity = QDatasetObjEntity.datasetObjEntity;
return queryFactory
.select(
Projections.constructor(
ModelTrainLinkDto.class,
modelMasterEntity.id,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelDatasetMappEntity.datasetUid,
datasetObjEntity.targetClassCd,
datasetObjEntity.comparePath,
datasetObjEntity.targetPath,
datasetObjEntity.labelPath,
datasetObjEntity.geojsonPath,
datasetEntity.uid))
.from(modelMasterEntity)
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(datasetEntity)
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
.leftJoin(datasetObjEntity)
.on(
datasetObjEntity
.datasetUid
.eq(modelDatasetMappEntity.datasetUid)
.and(
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId))
.fetch();
}
@Override
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
QDatasetValObjEntity datasetValObjEntity = QDatasetValObjEntity.datasetValObjEntity;
return queryFactory
.select(
Projections.constructor(
ModelTrainLinkDto.class,
modelMasterEntity.id,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelDatasetMappEntity.datasetUid,
datasetValObjEntity.targetClassCd,
datasetValObjEntity.comparePath,
datasetValObjEntity.targetPath,
datasetValObjEntity.labelPath,
datasetValObjEntity.geojsonPath,
datasetEntity.uid))
.from(modelMasterEntity)
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(datasetEntity)
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
.leftJoin(datasetValObjEntity)
.on(
datasetValObjEntity
.datasetUid
.eq(modelDatasetMappEntity.datasetUid)
.and(
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetValObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetValObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId))
.fetch();
}
@Override
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
QDatasetTestObjEntity datasetTestObjEntity = QDatasetTestObjEntity.datasetTestObjEntity;
return queryFactory
.select(
Projections.constructor(
ModelTrainLinkDto.class,
modelMasterEntity.id,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelDatasetMappEntity.datasetUid,
datasetTestObjEntity.targetClassCd,
datasetTestObjEntity.comparePath,
datasetTestObjEntity.targetPath,
datasetTestObjEntity.labelPath,
datasetTestObjEntity.geojsonPath,
datasetEntity.uid))
.from(modelMasterEntity)
.leftJoin(modelDatasetMappEntity)
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
.leftJoin(datasetEntity)
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
.leftJoin(datasetTestObjEntity)
.on(
datasetTestObjEntity
.datasetUid
.eq(modelDatasetMappEntity.datasetUid)
.and(
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetTestObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetTestObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId))
.fetch();
}
}

View File

@@ -3,7 +3,13 @@ package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import java.util.List;
import java.util.Optional;
@@ -22,4 +28,18 @@ public interface ModelDetailRepositoryCustom {
List<MappingDataset> getByModelMappingDataset(UUID uuid);
ModelMasterEntity findByModelByUUID(UUID uuid);
List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid);
List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid);
List<ModelTestMetrics> getModelTestMetricResult(UUID uuid);
ModelBestEpoch getModelTrainBestEpoch(UUID uuid);
ModelFileInfo getModelTrainFileInfo(UUID uuid);
List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid);
ModelMasterEntity findByModelBeforeId(Long beforeModelId);
}

View File

@@ -5,23 +5,37 @@ import static com.kamco.cd.training.postgres.entity.QModelDatasetEntity.modelDat
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsValidationEntity.modelMetricsValidationEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferHyperSummary;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ModelProgressStepDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelHyperParamEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.JPAExpressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Repository;
@Slf4j
@Repository
@RequiredArgsConstructor
public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
@@ -46,6 +60,13 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
@Override
public DetailSummary getModelDetailSummary(UUID uuid) {
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
Expression<UUID> beforeModelUuid =
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
.from(beforeModel)
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
return queryFactory
.select(
Projections.constructor(
@@ -57,7 +78,8 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
modelMasterEntity.step1StrtDttm,
modelMasterEntity.step2EndDttm,
modelMasterEntity.statusCd,
modelMasterEntity.trainType))
modelMasterEntity.trainType,
beforeModelUuid))
.from(modelMasterEntity)
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
@@ -154,4 +176,200 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
}
@Override
public List<ModelTrainMetrics> getModelTrainMetricResult(UUID uuid) {
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
if (modelMasterEntity == null) {
return List.of();
}
return queryFactory
.select(
Projections.constructor(
ModelTrainMetrics.class,
modelMetricsTrainEntity.epoch,
modelMetricsTrainEntity.iteration,
modelMetricsTrainEntity.loss,
modelMetricsTrainEntity.lr,
modelMetricsTrainEntity.durationTime))
.from(modelMetricsTrainEntity)
.where(modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId()))
.fetch();
}
@Override
public List<ModelValidationMetrics> getModelValidationMetricResult(UUID uuid) {
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
if (modelMasterEntity == null) {
return List.of();
}
return queryFactory
.select(
Projections.constructor(
ModelValidationMetrics.class,
modelMetricsValidationEntity.epoch,
modelMetricsValidationEntity.aAcc,
modelMetricsValidationEntity.mFscore,
modelMetricsValidationEntity.mPrecision,
modelMetricsValidationEntity.mRecall,
modelMetricsValidationEntity.mIou,
modelMetricsValidationEntity.mAcc,
modelMetricsValidationEntity.changedFscore,
modelMetricsValidationEntity.changedPrecision,
modelMetricsValidationEntity.changedRecall,
modelMetricsValidationEntity.unchangedFscore,
modelMetricsValidationEntity.unchangedPrecision,
modelMetricsValidationEntity.unchangedRecall))
.from(modelMetricsValidationEntity)
.where(modelMetricsValidationEntity.model.id.eq(modelMasterEntity.getId()))
.fetch();
}
@Override
public List<ModelTestMetrics> getModelTestMetricResult(UUID uuid) {
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
if (modelMasterEntity == null) {
return List.of();
}
return queryFactory
.select(
Projections.constructor(
ModelTestMetrics.class,
modelMetricsTestEntity.model1,
modelMetricsTestEntity.tp,
modelMetricsTestEntity.fp,
modelMetricsTestEntity.fn,
modelMetricsTestEntity.precisions,
modelMetricsTestEntity.recall,
modelMetricsTestEntity.f1Score,
modelMetricsTestEntity.accuracy,
modelMetricsTestEntity.iou,
modelMetricsTestEntity.detectionCount,
modelMetricsTestEntity.gtCount))
.from(modelMetricsTestEntity)
.where(modelMetricsTestEntity.model.id.eq(modelMasterEntity.getId()))
.fetch();
}
@Override
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
ModelMasterEntity modelMasterEntity = findByModelByUUID(uuid);
if (modelMasterEntity == null) {
return null;
}
return queryFactory
.select(
Projections.constructor(
ModelBestEpoch.class,
modelMetricsTrainEntity.epoch,
modelMetricsTrainEntity.loss,
modelMetricsValidationEntity.mFscore,
modelMetricsValidationEntity.mPrecision,
modelMetricsValidationEntity.mRecall,
modelMetricsValidationEntity.mIou,
modelMetricsValidationEntity.mAcc))
.from(modelMetricsTrainEntity)
.leftJoin(modelMetricsValidationEntity)
.on(
modelMetricsTrainEntity.model.eq(modelMetricsValidationEntity.model),
modelMetricsTrainEntity.epoch.eq(modelMetricsValidationEntity.epoch))
.where(
modelMetricsTrainEntity.model.id.eq(modelMasterEntity.getId()),
modelMetricsTrainEntity.epoch.eq(modelMasterEntity.getBestEpoch()))
.fetchOne();
}
@Override
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
return queryFactory
.select(
Projections.constructor(
ModelFileInfo.class,
modelMasterEntity
.packingState
.eq(TrainStatusType.COMPLETED.getId())
.coalesce(false),
modelMasterEntity.modelVer))
.from(modelMasterEntity)
.where(modelMasterEntity.uuid.eq(uuid))
.fetchOne();
}
@Override
public List<ModelProgressStepDto> findModelTrainProgressInfo(UUID uuid) {
ModelMasterEntity entity = findByModelByUUID(uuid);
if (entity == null) {
return List.of();
}
List<ModelProgressStepDto> steps = new ArrayList<>();
// 0단계 : 대기 상태
steps.add(
ModelProgressStepDto.builder()
.step(0)
.status(TrainStatusType.READY.getId())
.startTime(entity.getCreatedDttm())
.endTime(null)
.isError(false)
.build());
// 1단계 : Train/Validation 실행
boolean step1Active =
entity.getStep1StrtDttm() != null
&& !TrainStatusType.READY.getId().equals(entity.getStep1State());
if (step1Active) {
steps.add(
ModelProgressStepDto.builder()
.step(1)
.status(entity.getStep1State())
.startTime(entity.getStep1StrtDttm())
.endTime(entity.getStep1EndDttm())
.isError(TrainStatusType.ERROR.getId().equals(entity.getStep1State()))
.build());
}
// 2단계 : Test 실행
boolean step2Done = entity.getStep2State() != null;
if (step2Done) {
steps.add(
ModelProgressStepDto.builder()
.step(2)
.status(entity.getStep2State())
.startTime(entity.getStep2StrtDttm())
.endTime(entity.getStep2EndDttm())
.isError(TrainStatusType.ERROR.getId().equals(entity.getStep2State()))
.build());
}
// 3단계 : 패키징
boolean step3Done = entity.getPackingState() != null;
if (step3Done) {
steps.add(
ModelProgressStepDto.builder()
.step(3)
.status(entity.getPackingState())
.startTime(entity.getPackingStrtDttm())
.endTime(entity.getPackingEndDttm())
.isError(TrainStatusType.ERROR.getId().equals(entity.getPackingState()))
.build());
}
return steps;
}
@Override
public ModelMasterEntity findByModelBeforeId(Long beforeModelId) {
return queryFactory
.selectFrom(modelMasterEntity)
.where(modelMasterEntity.id.eq(beforeModelId))
.fetchOne();
}
}

View File

@@ -1,7 +1,9 @@
package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.util.Optional;
import java.util.UUID;
import org.springframework.data.domain.Page;
@@ -14,9 +16,13 @@ public interface ModelMngRepositoryCustom {
* @param searchReq
* @return
*/
Page<ModelMasterEntity> findByModels(ModelTrainMngDto.SearchReq searchReq);
Page<ListDto> findByModels(ModelTrainMngDto.SearchReq searchReq);
Optional<ModelMasterEntity> findByUuid(UUID uuid);
Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn);
TrainRunRequest findTrainRunRequest(Long modelId);
Long findModelStep1InProgressCnt();
}

View File

@@ -1,10 +1,20 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QMemberEntity.memberEntity;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.ListDto;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import com.kamco.cd.training.postgres.entity.QModelMasterEntity;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import java.util.Optional;
@@ -28,21 +38,62 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
* @return
*/
@Override
public Page<ModelMasterEntity> findByModels(ModelTrainMngDto.SearchReq req) {
public Page<ListDto> findByModels(ModelTrainMngDto.SearchReq req) {
QModelMasterEntity beforeModel = new QModelMasterEntity("beforeModel"); // alias
Expression<UUID> beforeModelUuid =
com.querydsl.jpa.JPAExpressions.select(beforeModel.uuid)
.from(beforeModel)
.where(beforeModel.id.eq(modelMasterEntity.beforeModelId));
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
if (req.getStatus() != null && !req.getStatus().isEmpty()) {
builder.and(modelMasterEntity.statusCd.eq(req.getStatus()));
builder.and(
modelMasterEntity
.step1State
.eq(req.getStatus())
.or(modelMasterEntity.step2State.eq(req.getStatus())));
}
if (req.getModelNo() != null && !req.getModelNo().isEmpty()) {
builder.and(modelMasterEntity.modelNo.eq(req.getModelNo()));
}
List<ModelMasterEntity> content =
builder.and(modelMasterEntity.delYn.isFalse());
List<ListDto> content =
queryFactory
.selectFrom(modelMasterEntity)
.select(
Projections.constructor(
ListDto.class,
modelMasterEntity.id,
modelMasterEntity.uuid,
modelMasterEntity.modelVer,
modelMasterEntity.strtDttm,
modelMasterEntity.step1StrtDttm,
modelMasterEntity.step1EndDttm,
modelMasterEntity.step2StrtDttm,
modelMasterEntity.step2EndDttm,
modelMasterEntity.step1State,
modelMasterEntity.step2State,
modelMasterEntity.statusCd,
modelMasterEntity.trainType,
modelMasterEntity.modelNo,
modelMasterEntity.currentAttemptId,
modelMasterEntity.requestPath,
modelMasterEntity.packingState,
modelMasterEntity.packingStrtDttm,
modelMasterEntity.packingEndDttm,
modelConfigEntity.memo,
memberEntity.name,
beforeModelUuid))
.from(modelMasterEntity)
.innerJoin(modelConfigEntity)
.on(modelMasterEntity.id.eq(modelConfigEntity.model.id))
.leftJoin(memberEntity)
.on(modelMasterEntity.createdUid.eq(memberEntity.id))
.where(builder)
.offset(pageable.getOffset())
.limit(pageable.getPageSize())
@@ -54,6 +105,10 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
queryFactory
.select(modelMasterEntity.count())
.from(modelMasterEntity)
.innerJoin(modelConfigEntity)
.on(modelMasterEntity.id.eq(modelConfigEntity.model.id))
.leftJoin(memberEntity)
.on(modelMasterEntity.createdUid.eq(memberEntity.id))
.where(builder)
.fetchOne();
@@ -82,4 +137,75 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
public Optional<ModelMasterEntity> findFirstByStatusCdAndDelYn(String statusCd, Boolean delYn) {
return Optional.empty();
}
@Override
public TrainRunRequest findTrainRunRequest(Long modelId) {
return queryFactory
.select(
Projections.constructor(
TrainRunRequest.class,
modelMasterEntity.requestPath, // datasetFolder
modelMasterEntity.uuid, // outputFolder
modelHyperParamEntity.inputSize,
modelHyperParamEntity.cropSize,
modelHyperParamEntity.batchSize,
modelHyperParamEntity.gpuIds,
modelHyperParamEntity.gpuCnt,
modelHyperParamEntity.learningRate,
modelHyperParamEntity.backbone,
modelConfigEntity.epochCount,
modelHyperParamEntity.trainNumWorkers,
modelHyperParamEntity.valNumWorkers,
modelHyperParamEntity.testNumWorkers,
modelHyperParamEntity.trainShuffle,
modelHyperParamEntity.trainPersistent,
modelHyperParamEntity.valPersistent,
modelHyperParamEntity.dropPathRate,
modelHyperParamEntity.frozenStages,
modelHyperParamEntity.neckPolicy,
modelHyperParamEntity.classWeight,
modelHyperParamEntity.decoderChannels,
modelHyperParamEntity.weightDecay,
modelHyperParamEntity.layerDecayRate,
modelHyperParamEntity.ignoreIndex,
modelHyperParamEntity.ddpFindUnusedParams,
modelHyperParamEntity.numLayers,
modelHyperParamEntity.metrics,
modelHyperParamEntity.saveBest,
modelHyperParamEntity.saveBestRule,
modelHyperParamEntity.valInterval,
modelHyperParamEntity.logInterval,
modelHyperParamEntity.visInterval,
modelHyperParamEntity.rotProb,
modelHyperParamEntity.rotDegree,
modelHyperParamEntity.flipProb,
modelHyperParamEntity.exchangeProb,
modelHyperParamEntity.brightnessDelta,
modelHyperParamEntity.contrastRange,
modelHyperParamEntity.saturationRange,
modelHyperParamEntity.hueDelta,
Expressions.nullExpression(Integer.class),
Expressions.nullExpression(String.class),
modelHyperParamEntity.uuid))
.from(modelMasterEntity)
.leftJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
.leftJoin(modelConfigEntity)
.on(modelConfigEntity.model.id.eq(modelMasterEntity.id))
.where(modelMasterEntity.id.eq(modelId))
.fetchOne();
}
@Override
public Long findModelStep1InProgressCnt() {
return queryFactory
.select(modelMasterEntity.id.count())
.from(modelMasterEntity)
.where(
modelMasterEntity
.step1State
.eq(TrainStatusType.IN_PROGRESS.getId())
.or(modelMasterEntity.step2State.eq(TrainStatusType.IN_PROGRESS.getId())))
.fetchOne();
}
}

View File

@@ -0,0 +1,9 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface ModelTestMetricsJobRepository
extends JpaRepository<ModelMetricsTestEntity, Long>, ModelTestMetricsJobRepositoryCustom {}

View File

@@ -0,0 +1,24 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.time.ZonedDateTime;
import java.util.List;
public interface ModelTestMetricsJobRepositoryCustom {
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
void insertModelMetricsTest(List<Object[]> batchArgs);
ModelMetricJsonDto getTestMetricPackingInfo(Long modelId);
ModelTestFileName findModelTestFileNames(Long modelId);
void updatePackingStart(Long modelId, ZonedDateTime now);
void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState);
}

View File

@@ -0,0 +1,171 @@
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.Properties;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.ZonedDateTime;
import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.jdbc.core.JdbcTemplate;
public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
implements ModelTestMetricsJobRepositoryCustom {
private final JPAQueryFactory queryFactory;
private final JdbcTemplate jdbcTemplate;
public ModelTestMetricsJobRepositoryImpl(
JPAQueryFactory queryFactory, JdbcTemplate jdbcTemplate) {
super(ModelMetricsTestEntity.class);
this.queryFactory = queryFactory;
this.jdbcTemplate = jdbcTemplate;
}
@Override
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
queryFactory
.update(modelMasterEntity)
.set(
stepNo.equals("step1")
? modelMasterEntity.step1MetricSaveYn
: modelMasterEntity.step2MetricSaveYn,
true)
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
@Override
public List<ResponsePathDto> getTestMetricSaveNotYetModelIds() {
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class,
modelMasterEntity.id,
modelMasterEntity.responsePath,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity.step2EndDttm.isNotNull(),
modelMasterEntity.step2State.eq(TrainStatusType.COMPLETED.getId()),
modelMasterEntity
.step2MetricSaveYn
.isNull()
.or(modelMasterEntity.step2MetricSaveYn.isFalse()))
.fetch();
}
@Override
public void insertModelMetricsTest(List<Object[]> batchArgs) {
// AS-IS
// String sql =
// """
// insert into tb_model_metrics_test
// (model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
// detection_count, gt_count
// )
// values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
// """;
//
// jdbcTemplate.batchUpdate(sql, batchArgs);
// TO-BE: modelId, model(best_fscore_10) 같은 데이터가 있으면 update, 없으면 insert
String updateSql =
"""
UPDATE tb_model_metrics_test
SET tp=?, fp=?, fn=?, precisions=?, recall=?, f1_score=?, accuracy=?, iou=?,
detection_count=?, gt_count=?
WHERE model_id=? AND model=?
""";
String insertSql =
"""
INSERT INTO tb_model_metrics_test
(model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
detection_count, gt_count)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""";
// row 단위 처리 (batch 안에서 upsert)
for (Object[] row : batchArgs) {
// row 순서: (model_id, model, tp, fp, fn, precisions, recall, f1_score, accuracy, iou,
// detection_count, gt_count)
int updated =
jdbcTemplate.update(
updateSql, row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9], row[10],
row[11], row[0], row[1]);
if (updated == 0) {
jdbcTemplate.update(insertSql, row);
}
}
}
@Override
public ModelMetricJsonDto getTestMetricPackingInfo(Long modelId) {
return queryFactory
.select(
Projections.constructor(
ModelMetricJsonDto.class,
modelMasterEntity.modelNo,
modelMasterEntity.modelVer,
Projections.constructor(
Properties.class,
modelMetricsTestEntity.f1Score,
modelMetricsTestEntity.precisions,
modelMetricsTestEntity.recall,
modelMetricsTestEntity.iou,
modelMetricsTrainEntity.loss)))
.from(modelMetricsTestEntity)
.innerJoin(modelMasterEntity)
.on(modelMetricsTestEntity.model.id.eq(modelMasterEntity.id))
.innerJoin(modelMetricsTrainEntity)
.on(
modelMetricsTestEntity.model.eq(modelMetricsTrainEntity.model),
modelMasterEntity.bestEpoch.eq(modelMetricsTrainEntity.epoch))
.where(modelMetricsTestEntity.model.id.eq(modelId))
.orderBy(modelMetricsTestEntity.createdDttm.desc())
.fetchFirst();
}
@Override
public ModelTestFileName findModelTestFileNames(Long modelId) {
return queryFactory
.select(
Projections.constructor(
ModelTestFileName.class, modelMetricsTestEntity.model1, modelMasterEntity.modelVer))
.from(modelMetricsTestEntity)
.innerJoin(modelMasterEntity)
.on(modelMetricsTestEntity.model.id.eq(modelMasterEntity.id))
.where(modelMetricsTestEntity.model.id.eq(modelId))
.fetchOne();
}
@Override
public void updatePackingStart(Long modelId, ZonedDateTime now) {
queryFactory
.update(modelMasterEntity)
.set(modelMasterEntity.packingStrtDttm, ZonedDateTime.now())
.set(modelMasterEntity.packingState, TrainStatusType.READY.getId())
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
@Override
public void updatePackingEnd(Long modelId, ZonedDateTime now, String failSuccState) {
queryFactory
.update(modelMasterEntity)
.set(modelMasterEntity.packingEndDttm, ZonedDateTime.now())
.set(modelMasterEntity.packingState, failSuccState)
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
}

View File

@@ -0,0 +1,7 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import org.springframework.data.jpa.repository.JpaRepository;
public interface ModelTrainJobRepository
extends JpaRepository<ModelTrainJobEntity, Long>, ModelTrainJobRepositoryCustom {}

View File

@@ -0,0 +1,17 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import java.util.List;
import java.util.Optional;
public interface ModelTrainJobRepositoryCustom {
int findMaxAttemptNo(Long modelId);
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch);
List<ModelTrainJobEntity> findRunningJobs();
}

View File

@@ -0,0 +1,99 @@
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelTestTrainingRunEntity.modelTestTrainingRunEntity;
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
import com.kamco.cd.training.common.enums.JobStatusType;
import com.kamco.cd.training.common.enums.JobType;
import com.kamco.cd.training.postgres.entity.ModelTrainJobEntity;
import com.kamco.cd.training.postgres.entity.QModelTrainJobEntity;
import com.querydsl.jpa.impl.JPAQueryFactory;
import jakarta.persistence.EntityManager;
import java.util.List;
import java.util.Optional;
import org.springframework.stereotype.Repository;
@Repository
public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCustom {
private final JPAQueryFactory queryFactory;
public ModelTrainJobRepositoryImpl(EntityManager em) {
this.queryFactory = new JPAQueryFactory(em);
}
/** modelId의 attempt_no 최대값. (없으면 0) */
@Override
public int findMaxAttemptNo(Long modelId) {
QModelTrainJobEntity j = modelTrainJobEntity;
Integer max =
queryFactory.select(j.attemptNo.max()).from(j).where(j.modelId.eq(modelId)).fetchOne();
return max != null ? max : 0;
}
/**
* modelId의 최신 job 1건 (보통 id desc / queuedDttm desc 등) - attemptNo 기준으로도 가능하지만, 여기선 id desc가 가장
* 단순.
*/
@Override
public Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId) {
QModelTrainJobEntity j = modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory.selectFrom(j).where(j.modelId.eq(modelId)).orderBy(j.id.desc()).fetchFirst();
return Optional.ofNullable(job);
}
@Override
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
QModelTrainJobEntity j = modelTrainJobEntity;
ModelTrainJobEntity job =
queryFactory
.selectFrom(j)
.where(j.containerName.eq(containerName))
.orderBy(j.id.desc())
.fetchFirst();
return Optional.ofNullable(job);
}
@Override
public void insertModelTestTrainingRun(Long modelId, Long jobId, int epoch) {
Integer maxAttemptNo =
queryFactory
.select(modelTestTrainingRunEntity.attemptNo.max().coalesce(0))
.from(modelTestTrainingRunEntity)
.where(modelTestTrainingRunEntity.modelId.eq(modelId))
.fetchOne();
int nextAttemptNo = (maxAttemptNo == null ? 1 : maxAttemptNo + 1);
queryFactory
.insert(modelTestTrainingRunEntity)
.columns(
modelTestTrainingRunEntity.modelId,
modelTestTrainingRunEntity.attemptNo,
modelTestTrainingRunEntity.jobId,
modelTestTrainingRunEntity.epoch)
.values(modelId, nextAttemptNo, jobId, epoch)
.execute();
}
@Override
public List<ModelTrainJobEntity> findRunningJobs() {
return queryFactory
.select(modelTrainJobEntity)
.from(modelTrainJobEntity)
.where(
modelTrainJobEntity
.statusCd
.eq(JobStatusType.RUNNING.getId())
.and(modelTrainJobEntity.jobType.eq(JobType.TRAIN.getId())))
.orderBy(modelTrainJobEntity.id.desc())
.fetch();
}
}

View File

@@ -0,0 +1,9 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface ModelTrainMetricsJobRepository
extends JpaRepository<ModelMetricsTrainEntity, Long>, ModelTrainMetricsJobRepositoryCustom {}

View File

@@ -0,0 +1,17 @@
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.util.List;
public interface ModelTrainMetricsJobRepositoryCustom {
List<ResponsePathDto> getTrainMetricSaveNotYetModelIds();
void insertModelMetricsTrain(List<Object[]> batchArgs);
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
void insertModelMetricsValidation(List<Object[]> batchArgs);
void updateModelSelectedBestEpoch(Long modelId, Integer epoch);
}

View File

@@ -0,0 +1,94 @@
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import org.springframework.data.jpa.repository.support.QuerydslRepositorySupport;
import org.springframework.jdbc.core.JdbcTemplate;
public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySupport
implements ModelTrainMetricsJobRepositoryCustom {
private final JPAQueryFactory queryFactory;
private final JdbcTemplate jdbcTemplate;
public ModelTrainMetricsJobRepositoryImpl(
JPAQueryFactory queryFactory, JdbcTemplate jdbcTemplate) {
super(ModelMetricsTrainEntity.class);
this.queryFactory = queryFactory;
this.jdbcTemplate = jdbcTemplate;
}
@Override
public List<ResponsePathDto> getTrainMetricSaveNotYetModelIds() {
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class,
modelMasterEntity.id,
modelMasterEntity.responsePath,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity.step1EndDttm.isNotNull(),
modelMasterEntity.step1State.eq(TrainStatusType.COMPLETED.getId()),
modelMasterEntity
.step1MetricSaveYn
.isNull()
.or(modelMasterEntity.step1MetricSaveYn.isFalse()))
.fetch();
}
@Override
public void insertModelMetricsTrain(List<Object[]> batchArgs) {
String sql =
"""
insert into tb_model_metrics_train
(model_id, epoch, iteration, loss, lr, duration_time)
values (?, ?, ?, ?, ?, ?)
""";
jdbcTemplate.batchUpdate(sql, batchArgs);
}
@Override
public void updateModelMetricsTrainSaveYn(Long modelId, String stepNo) {
queryFactory
.update(modelMasterEntity)
.set(
stepNo.equals("step1")
? modelMasterEntity.step1MetricSaveYn
: modelMasterEntity.step2MetricSaveYn,
true)
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
@Override
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
String sql =
"""
insert into tb_model_metrics_validation
(model_id, epoch, a_acc, m_fscore, m_precision, m_recall, m_iou, m_acc, changed_fscore, changed_precision, changed_recall,
unchanged_fscore, unchanged_precision, unchanged_recall
)
values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""";
jdbcTemplate.batchUpdate(sql, batchArgs);
}
@Override
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
queryFactory
.update(modelMasterEntity)
.set(modelMasterEntity.bestEpoch, epoch)
.where(modelMasterEntity.id.eq(modelId))
.execute();
}
}

View File

@@ -0,0 +1,216 @@
package com.kamco.cd.training.train;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.train.service.DataSetCountersService;
import com.kamco.cd.training.train.service.TestJobService;
import com.kamco.cd.training.train.service.TrainJobService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@Tag(name = "학습 실행 API", description = "모델학습관리 > 학습 실행 API")
@RequiredArgsConstructor
@RestController
@RequestMapping("/api/train")
public class TrainApiController {
private final TrainJobService trainJobService;
private final TestJobService testJobService;
private final DataSetCountersService dataSetCountersService;
@Operation(summary = "학습 실행", description = "학습 실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "실행 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/run/{uuid}")
public ApiResponseDto<String> run(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.enqueue(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 재실행", description = "학습 재실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "재실행 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/restart/{uuid}")
public ApiResponseDto<String> restart(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
Long jobId = trainJobService.restart(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 이어하기", description = "학습 이어하기 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "이어하기 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/resume/{uuid}")
public ApiResponseDto<String> resume(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
Long jobId = trainJobService.resume(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "학습 취소", description = "학습 취소 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "취소 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/cancel/{uuid}")
public ApiResponseDto<String> cancel(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
trainJobService.cancel(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "test 실행", description = "test 실행 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "test 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/test/run/{epoch}/{uuid}")
public ApiResponseDto<String> run(
@Parameter(description = "best 에폭", example = "1") @PathVariable int epoch,
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
testJobService.enqueue(modelId, uuid, epoch);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "test 학습 취소", description = "학습 취소 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "취소 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/test/cancel/{uuid}")
public ApiResponseDto<String> cancelTest(
@Parameter(description = "uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
testJobService.cancel(modelId);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "데이터셋 tmp 파일생성", description = "데이터셋 tmp 파일생성 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "데이터셋 tmp 파일생성 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@PostMapping("/create-tmp/{uuid}")
public ApiResponseDto<UUID> createTmpFile(
@Parameter(description = "model uuid", example = "80a0e544-36ed-4999-b705-97427f23337d")
@PathVariable
UUID uuid) {
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
}
@Operation(summary = "getCount", description = "getCount 서버 로그확인")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "데이터셋 tmp 파일생성 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping(path = "/counts/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE)
public ApiResponseDto<String> getCount(
@Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49")
@PathVariable
UUID uuid) {
Long modelId = trainJobService.getModelIdByUuid(uuid);
return ApiResponseDto.ok(dataSetCountersService.getCount(modelId));
}
}

View File

@@ -0,0 +1,22 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
public class EvalRunRequest {
private String uuid;
private int epoch; // best_changed_fscore_epoch_1.pth
private Integer timeoutSeconds;
private String datasetFolder;
private String outputFolder;
public String getOutputFolder() {
return this.outputFolder.toString();
}
}

View File

@@ -0,0 +1,25 @@
package com.kamco.cd.training.train.dto;
import java.time.ZonedDateTime;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public class ModelTrainJobDto {
private Long id;
private Long modelId;
private Integer attemptNo;
private String statusCd;
private Integer exitCode;
private String errorMessage;
private String containerName;
private Map<String, Object> paramsJson;
private ZonedDateTime queuedDttm;
private ZonedDateTime startedDttm;
private ZonedDateTime finishedDttm;
private Integer totalEpoch;
private Integer currentEpoch;
}

View File

@@ -0,0 +1,15 @@
package com.kamco.cd.training.train.dto;
/** 학습 실행이 예약되었음을 알리는 이벤트 객체 */
public class ModelTrainJobQueuedEvent {
private final Long jobId;
public ModelTrainJobQueuedEvent(Long jobId) {
this.jobId = jobId;
}
public Long getJobId() {
return jobId;
}
}

View File

@@ -0,0 +1,24 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class ModelTrainLinkDto {
private Long modelId;
private String trainType;
private String modelNo;
private Long datasetId;
private String targetClassCd;
private String comparePath;
private String targetPath;
private String labelPath;
private String geoJsonPath;
private String datasetUid;
}

View File

@@ -0,0 +1,59 @@
package com.kamco.cd.training.train.dto;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.Properties;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
public class ModelTrainMetricsDto {
@Schema(name = "ResponsePathDto", description = "AI 결과 저장된 path 경로 정보")
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class ResponsePathDto {
private Long modelId;
private String responsePath;
private UUID uuid;
}
@Getter
@AllArgsConstructor
public static class ModelMetricJsonDto {
@JsonProperty("cd_model_type")
private String cdModelType;
@JsonProperty("model_version")
private String modelVersion;
private Properties properties;
}
@Getter
@AllArgsConstructor
public static class Properties {
@JsonProperty("f1_score")
private Float f1Score;
private Float precision;
private Float recall;
private Float loss;
private Double iou;
}
@Getter
@AllArgsConstructor
public static class ModelTestFileName {
private String bestEpochFileName;
private String modelVersion;
}
}

View File

@@ -0,0 +1,94 @@
package com.kamco.cd.training.train.dto;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class TrainRunRequest {
// ========================
// 기본
// ========================
private String datasetFolder;
private UUID outputFolder;
private String inputSize;
private String cropSize;
private Integer batchSize;
private String gpuIds;
private Integer gpus;
private Double learningRate;
private String backbone;
private Integer epochs;
// ========================
// Data
// ========================
private Integer trainNumWorkers;
private Integer valNumWorkers;
private Integer testNumWorkers;
private Boolean trainShuffle;
private Boolean trainPersistent;
private Boolean valPersistent;
// ========================
// Model Architecture
// ========================
private Double dropPathRate;
private Integer frozenStages;
private String neckPolicy;
private String classWeight;
private String decoderChannels;
// ========================
// Loss & Optimization
// ========================
private Double weightDecay;
private Double layerDecayRate;
private Integer ignoreIndex;
private Boolean ddpFindUnusedParams;
private Integer numLayers;
// ========================
// Evaluation
// ========================
private String metrics;
private String saveBest;
private String saveBestRule;
private Integer valInterval;
private Integer logInterval;
private Integer visInterval;
// ========================
// Augmentation
// ========================
private Double rotProb;
private String rotDegree;
private Double flipProb;
private Double exchangeProb;
private Integer brightnessDelta;
private String contrastRange;
private String saturationRange;
private Integer hueDelta;
// ========================
// 실행 타임아웃
// ========================
private Integer timeoutSeconds;
private String resumeFrom;
private UUID uuid;
public String getOutputFolder() {
return String.valueOf(this.outputFolder);
}
public String getUuid() {
return String.valueOf(this.uuid);
}
}

View File

@@ -0,0 +1,20 @@
package com.kamco.cd.training.train.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
/** 학습 실행 결과 반환 객체 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class TrainRunResult {
private String jobId;
private String containerName;
private int exitCode;
private String status;
private String logs;
}

View File

@@ -0,0 +1,228 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
/** 학습실행 파일 하드링크 */
@Service
@Log4j2
@RequiredArgsConstructor
public class DataSetCountersService {
private final ModelTrainMngCoreService modelTrainMngCoreService;
@Value("${train.docker.requestDir}")
private String requestDir;
@Value("${train.docker.basePath}")
private String trainBaseDir;
public String getCount(Long modelId) {
ModelTrainMngDto.Basic basic = modelTrainMngCoreService.findModelById(modelId);
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
StringBuilder allLogs = new StringBuilder();
try {
// request 폴더
for (String uid : uids) {
Path path = Path.of(requestDir, uid);
DatasetCounters counters = countTmpAfterBuild(path);
allLogs.append(counters.prints(uid, "REQUEST")).append(System.lineSeparator());
}
// tmp
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
// 차이나는거
diffMergedRequestsVsTmp(uids, tmpPath);
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
allLogs
.append(counters2.prints(basic.getRequestPath(), "TMP"))
.append(System.lineSeparator());
} catch (IOException e) {
log.error(e.getMessage());
}
return allLogs.toString();
}
private int countTif(Path dir) throws IOException {
if (!Files.isDirectory(dir)) return 0;
try (var stream = Files.walk(dir)) {
return (int)
stream.filter(Files::isRegularFile).filter(p -> p.toString().endsWith(".tif")).count();
}
/*
대소문자 및 geojson 필요시
* try (var stream = Files.walk(dir)) {
return (int)
stream
.filter(Files::isRegularFile)
.filter(p -> {
String name = p.getFileName().toString().toLowerCase();
return name.endsWith(".tif") || name.endsWith(".geojson");
})
.count();
}
* */
}
public DatasetCounters countTmpAfterBuild(Path path) throws IOException {
// input1
int in1Train = countTif(path.resolve("train/input1"));
int in1Val = countTif(path.resolve("val/input1"));
int in1Test = countTif(path.resolve("test/input1"));
// input2
int in2Train = countTif(path.resolve("train/input2"));
int in2Val = countTif(path.resolve("val/input2"));
int in2Test = countTif(path.resolve("test/input2"));
List<DatasetCounter> input1List = new ArrayList<>();
List<DatasetCounter> input2List = new ArrayList<>();
input1List.add(new DatasetCounter(in1Train, in1Test, in1Val));
input2List.add(new DatasetCounter(in2Train, in2Test, in2Val));
return new DatasetCounters(input1List, input2List);
}
@Getter
public static class DatasetCounter {
private int inputTrain = 0;
private int inputTest = 0;
private int inputVal = 0;
public DatasetCounter(int inputTrain, int inputTest, int inputVal) {
this.inputTrain = inputTrain;
this.inputTest = inputTest;
this.inputVal = inputVal;
}
}
@Getter
public static class DatasetCounters {
private List<DatasetCounter> input1 = new ArrayList<>();
private List<DatasetCounter> input2 = new ArrayList<>();
public DatasetCounters(List<DatasetCounter> input1, List<DatasetCounter> input2) {
this.input1 = input1;
this.input2 = input2;
}
public String prints(String uuid, String type) {
int train = 0, test = 0, val = 0;
int train2 = 0, test2 = 0, val2 = 0;
for (DatasetCounter datasetCounter : input1) {
train += datasetCounter.inputTrain;
test += datasetCounter.inputTest;
val += datasetCounter.inputVal;
}
for (DatasetCounter datasetCounter : input2) {
train2 += datasetCounter.inputTrain;
test2 += datasetCounter.inputTest;
val2 += datasetCounter.inputVal;
}
log.info("======== UUID FOLDER COUNT {} : {}", type, uuid);
log.info("input 1 = train : {} | val : {} | test : {} ", train, val, test);
log.info("input 2 = train : {} | val : {} | test : {} ", train2, val2, test2);
log.info(
"*total* = train : {} | val : {} | test : {} ",
train + train2,
val + val2,
test + test2);
return String.format(
"======== UUID FOLDER COUNT %s : %s%n"
+ "input 1 = train : %s | val : %s | test : %s%n"
+ "input 2 = train : %s | val : %s | test : %s%n"
+ "*total* = train : %s | val : %s | test : %s",
type,
uuid,
train,
val,
test,
train2,
val2,
test2,
train + train2,
val + val2,
test + test2);
}
}
private Set<String> listTifRelative(Path root) throws IOException {
if (!Files.isDirectory(root)) return Set.of();
try (var stream = Files.walk(root)) {
return stream
.filter(Files::isRegularFile)
.filter(p -> p.getFileName().toString().toLowerCase().endsWith(".tif"))
.map(p -> root.relativize(p).toString().replace("\\", "/"))
.collect(Collectors.toSet());
}
}
private Set<String> listTifFileNameOnly(Path root) throws IOException {
if (!Files.isDirectory(root)) return Set.of();
try (var stream = Files.walk(root)) {
return stream
.filter(Files::isRegularFile)
.filter(p -> p.getFileName().toString().toLowerCase().endsWith(".tif"))
.map(p -> p.getFileName().toString()) // 파일명만
.collect(Collectors.toSet());
}
}
public void diffMergedRequestsVsTmp(List<String> uids, Path tmpRoot) throws IOException {
// 1) 요청 uids 전체를 합친 tif "파일명" 집합
Set<String> reqAll = new HashSet<>();
for (String uid : uids) {
Path reqRoot = Path.of(requestDir, uid);
// ★합본 tmp는 보통 폴더 구조가 바뀌므로 "상대경로" 비교보다 파일명 비교가 먼저 유용합니다.
reqAll.addAll(listTifFileNameOnly(reqRoot));
}
// 2) tmp tif 파일명 집합
Set<String> tmpAll = listTifFileNameOnly(tmpRoot);
Set<String> missing = new HashSet<>(reqAll);
missing.removeAll(tmpAll);
Set<String> extra = new HashSet<>(tmpAll);
extra.removeAll(reqAll);
log.info("==== MERGED DIFF (filename-based) ====");
log.info("request(all uids) tif = {}", reqAll.size());
log.info("tmp tif = {}", tmpAll.size());
log.info("missing = {}", missing.size());
log.info("extra = {}", extra.size());
missing.stream().sorted().limit(50).forEach(f -> log.warn("[MISSING] {}", f));
extra.stream().sorted().limit(50).forEach(f -> log.warn("[EXTRA] {}", f));
}
}

View File

@@ -0,0 +1,486 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Log4j2
@Service
@RequiredArgsConstructor
public class DockerTrainService {
// 실행할 Docker 이미지명
@Value("${train.docker.image}")
private String image;
// 학습 요청 데이터가 위치한 호스트 디렉토리
@Value("${train.docker.requestDir}")
private String requestDir;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
// 컨테이너 이름 prefix
@Value("${train.docker.containerPrefix}")
private String containerPrefix;
// 공유메모리 사이즈 설정 (대용량 학습시 필요)
@Value("${train.docker.shmSize:16g}")
private String shmSize;
// IPC host 사용 여부
@Value("${train.docker.ipcHost:true}")
private boolean ipcHost;
private final ModelTrainJobCoreService modelTrainJobCoreService;
/**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param req
* @param containerName
* @return
* @throws Exception
*/
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerRunCommand(containerName, req);
log.info("=== Docker Train Command ===");
log.info("Container: {}", containerName);
log.info("Command: {}", String.join(" ", cmd));
log.info("================================");
ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectErrorStream(true);
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
StringBuilder logBuilder = new StringBuilder();
Process p = pb.start();
log.info("[TRAIN-BOOT] docker run started. container={}", containerName);
try {
log.info("[TRAIN-BOOT] pid={}", p.pid()); // Java 9+
} catch (Throwable ignore) {
}
try {
// 바로 죽었는지 100ms만 체크
if (p.waitFor(100, TimeUnit.MILLISECONDS)) {
int exit = p.exitValue();
String earlyLogs;
synchronized (logBuilder) {
earlyLogs = logBuilder.toString();
}
log.error(
"[TRAIN-BOOT] docker run exited immediately. container={} exit={}",
containerName,
exit);
log.error("[TRAIN-BOOT] early logs:\n{}", earlyLogs);
} else {
log.info("[TRAIN-BOOT] docker run is still running. container={}", containerName);
}
} catch (Exception e) {
log.warn("[TRAIN-BOOT] early-exit check failed: {}", e.toString(), e);
}
Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]");
// 너무 잦은 업데이트 방지용
AtomicInteger lastEpoch = new AtomicInteger(0);
AtomicInteger lastIter = new AtomicInteger(0);
Thread logThread =
new Thread(
() -> {
try (BufferedReader br =
new BufferedReader(
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
synchronized (logBuilder) {
logBuilder.append(line).append('\n');
}
Matcher m = epochPattern.matcher(line);
if (m.find()) {
int epoch = Integer.parseInt(m.group(1));
int iter = Integer.parseInt(m.group(2));
int totalIter = Integer.parseInt(m.group(3));
// (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용
int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0;
// 쓰로틀링: 에폭 끝 or 10 iter마다
boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0);
// 중복 방지
if (shouldUpdate) {
int prevEpoch = lastEpoch.get();
int prevIter = lastIter.get();
if (epoch != prevEpoch || iter != prevIter) {
lastEpoch.set(epoch);
lastIter.set(iter);
log.info(
"[TRAIN] container={} epoch={} iter={}/{}",
containerName,
epoch,
iter,
totalIter);
modelTrainJobCoreService.updateEpoch(containerName, epoch);
}
}
}
}
} catch (Exception e) {
log.warn("logThread error: {}", e.toString());
}
},
"train-log-" + containerName);
logThread.setDaemon(true);
logThread.start();
int timeoutSeconds = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
boolean finished = p.waitFor(timeoutSeconds, TimeUnit.SECONDS);
if (!finished) {
// docker run 프로세스도 같이 끊어야 readLine이 풀림
p.destroy();
if (!p.waitFor(2, TimeUnit.SECONDS)) {
p.destroyForcibly();
}
killContainer(containerName);
String logs;
synchronized (logBuilder) {
logs = logBuilder.toString();
}
return new TrainRunResult(
null, // jobId (없으면 null)
containerName,
-1,
"TIMEOUT",
logs);
}
int exit = p.exitValue();
// 로그 스레드가 마무리할 시간을 조금 줌(없어도 되지만 로그 누락 방지용)
logThread.join(500);
String logs;
synchronized (logBuilder) {
logs = logBuilder.toString();
}
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
}
/**
* 학습 docker run command
*
* @param containerName
* @param req
* @return
*/
private List<String> buildDockerRunCommand(String containerName, TrainRunRequest req) {
List<String> c = new ArrayList<>();
c.add("docker");
c.add("run");
// 컨테이너 이름 지정
c.add("--name");
c.add(containerName);
// 실행 종료 시 자동 삭제
c.add("--rm");
// GPU 전체 사용
c.add("--gpus");
c.add("all");
// IPC host 사용 여부
if (ipcHost) {
c.add("--ipc=host");
}
// 공유메모리 설정
c.add("--shm-size=" + shmSize);
// 메모리 관련 ulimit 설정
c.add("--ulimit");
c.add("memlock=-1");
c.add("--ulimit");
c.add("stack=67108864");
// 환경변수 설정
c.add("-e");
c.add("OPENCV_LOG_LEVEL=ERROR");
c.add("-e");
c.add("NCCL_DEBUG=INFO");
c.add("-e");
c.add("NCCL_IB_DISABLE=1");
c.add("-e");
c.add("NCCL_P2P_DISABLE=0");
c.add("-e");
c.add("NCCL_SOCKET_IFNAME=eth0");
// 요청/결과 디렉토리 볼륨 마운트
c.add("-v");
c.add("/home/kcomu/data" + "/tmp:/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");
// 표준입력 유지 (-it 대신 -i만 사용)
c.add("-i");
// 사용할 이미지
c.add(image);
// ===== 컨테이너 내부 실행 명령 =====
c.add("python");
c.add("/workspace/change-detection-code/train_wrapper.py");
// ===== 기본 파라미터 =====
addArg(c, "--dataset-folder", req.getDatasetFolder());
addArg(c, "--output-folder", req.getOutputFolder());
addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize());
addArg(c, "--batch-size", req.getBatchSize());
addArg(c, "--gpu-ids", req.getGpuIds()); // null
addArg(c, "--lr", req.getLearningRate());
addArg(c, "--backbone", req.getBackbone());
addArg(c, "--epochs", req.getEpochs());
// ===== Data =====
addArg(c, "--train-num-workers", req.getTrainNumWorkers());
addArg(c, "--val-num-workers", req.getValNumWorkers());
addArg(c, "--test-num-workers", req.getTestNumWorkers());
addArg(c, "--train-shuffle", req.getTrainShuffle());
addArg(c, "--train-persistent", req.getTrainPersistent());
addArg(c, "--val-persistent", req.getValPersistent());
// ===== Model Architecture =====
addArg(c, "--drop-path-rate", req.getDropPathRate());
addArg(c, "--frozen-stages", req.getFrozenStages());
addArg(c, "--neck-policy", req.getNeckPolicy());
addArg(c, "--class-weight", req.getClassWeight());
addArg(c, "--decoder-channels", req.getDecoderChannels());
// ===== Loss & Optimization =====
addArg(c, "--weight-decay", req.getWeightDecay());
addArg(c, "--layer-decay-rate", req.getLayerDecayRate());
addArg(c, "--ignore-index", req.getIgnoreIndex());
addArg(c, "--ddp-find-unused-params", req.getDdpFindUnusedParams());
addArg(c, "--num-layers", req.getNumLayers());
// ===== Evaluation =====
addArg(c, "--metrics", req.getMetrics());
addArg(c, "--save-best", req.getSaveBest());
addArg(c, "--save-best-rule", req.getSaveBestRule());
addArg(c, "--val-interval", req.getValInterval());
addArg(c, "--log-interval", req.getLogInterval());
addArg(c, "--vis-interval", req.getVisInterval());
// ===== Augmentation =====
addArg(c, "--rot-prob", req.getRotProb());
addArg(c, "--rot-degree", req.getRotDegree());
addArg(c, "--flip-prob", req.getFlipProb());
addArg(c, "--exchange-prob", req.getExchangeProb());
addArg(c, "--brightness-delta", req.getBrightnessDelta());
addArg(c, "--contrast-range", req.getContrastRange());
addArg(c, "--saturation-range", req.getSaturationRange());
addArg(c, "--hue-delta", req.getHueDelta());
if (req.getResumeFrom() != null && !req.getResumeFrom().isBlank()) {
c.add("--resume");
addArg(c, "--load-from", req.getResumeFrom());
}
addArg(c, "--save-interval", 1);
return c;
}
/** 인자 추가(키 + 값) - null / blank면 아예 추가 안 함 */
private void addArg(List<String> c, String key, Object value) {
if (value == null) return;
String s = String.valueOf(value).trim();
if (s.isEmpty()) return;
c.add(key + "=" + s);
}
/** 컨테이너 강제 종료 및 제거 */
public void killContainer(String containerName) {
try {
new ProcessBuilder("docker", "rm", "-f", containerName)
.redirectErrorStream(true)
.start()
.waitFor(10, TimeUnit.SECONDS);
} catch (Exception ignored) {
}
}
/**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param containerName
* @param req
* @return
* @throws Exception
*/
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
List<String> cmd = buildDockerEvalCommand(containerName, req);
log.info("=== Docker Test Command ===");
log.info("Container: {}", containerName);
log.info("Command: {}", String.join(" ", cmd));
log.info("================================");
ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectErrorStream(true);
Process p = pb.start();
StringBuilder log = new StringBuilder();
Thread logThread =
new Thread(
() -> {
try (BufferedReader br =
new BufferedReader(
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
synchronized (log) {
log.append(line).append('\n');
}
}
} catch (Exception ignored) {
}
});
logThread.setDaemon(true);
logThread.start();
int timeout = req.getTimeoutSeconds() != null ? req.getTimeoutSeconds() : 7200;
boolean finished = p.waitFor(timeout, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
killContainer(containerName);
String logs;
synchronized (log) {
logs = log.toString();
}
return new TrainRunResult(null, containerName, -1, "TIMEOUT", logs);
}
int exit = p.exitValue();
logThread.join(500);
String logs;
synchronized (log) {
logs = log.toString();
}
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
}
/**
* 테스트 docker run command
*
* @param containerName
* @param req
* @return
*/
private List<String> buildDockerEvalCommand(String containerName, EvalRunRequest req) {
String uuid = req.getUuid();
Integer epoch = req.getEpoch();
if (uuid == null || uuid.isBlank()) throw new IllegalArgumentException("uuid is required");
if (epoch == null || epoch <= 0) throw new IllegalArgumentException("epoch must be > 0");
Path epochPath = Paths.get(responseDir, req.getOutputFolder());
// 결과 폴더에 파라미터로 받은 베스트 epoch이 best_changed_fscore_epoch_ 로 시작하는 파일이 있는지 확인 후 pth 파일명 반환
String modelFile = findCheckpoint(epochPath, epoch);
List<String> c = new ArrayList<>();
c.add("docker");
c.add("run");
c.add("--rm");
c.add("--gpus");
c.add("all");
c.add("--ipc=host");
c.add("--shm-size=" + shmSize);
c.add("-v");
c.add("/home/kcomu/data" + "/tmp:/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");
c.add("kamco-cd-train:latest");
c.add("python");
c.add("/workspace/change-detection-code/run_evaluation_pipeline.py");
addArg(c, "--dataset-folder", req.getDatasetFolder());
addArg(c, "--output-folder", req.getOutputFolder());
c.add("--epoch");
c.add(modelFile);
return c;
}
public String findCheckpoint(Path dir, int epoch) {
String bestFileName = String.format("best_changed_fscore_epoch_%d.pth", epoch);
String normalFileName = String.format("epoch_%d.pth", epoch);
Path bestPath = dir.resolve(bestFileName);
Path normalPath = dir.resolve(normalFileName);
// 1. best 파일이 존재하면 그거 사용
if (Files.isRegularFile(bestPath)) {
return bestFileName;
}
// 2. 없으면 일반 epoch 파일 사용
if (Files.isRegularFile(normalPath)) {
return normalFileName;
}
throw new IllegalStateException("Checkpoint 파일이 없습니다. epoch=" + epoch);
}
}

View File

@@ -0,0 +1,449 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
/**
* 서버 재기동 시 "RUNNING 상태로 남아있는 학습 Job"을 복구(정리)하기 위한 서비스.
*
* <p>상황 예시: - 서버가 강제 재기동/장애로 내려감 - DB 상에서는 job_state가 RUNNING(진행중)으로 남아있음 - 실제 docker 컨테이너는: 1) 아직
* 살아있거나(running=true) 2) 종료되었거나(exited) 3) --rm 옵션으로 인해 컨테이너가 이미 삭제되어 존재하지 않을 수 있음
*
* <p>이 클래스는 ApplicationReadyEvent(스프링 부팅 완료) 시점에 실행되어, DB의 RUNNING 잡들을 조회한 뒤 컨테이너 상태를 점검하고,
* SUCCESS/FAILED 처리를 수행합니다.
*/
@Profile("!local")
@Component
@RequiredArgsConstructor
@Log4j2
public class JobRecoveryOnStartupService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
/**
* Docker 컨테이너가 쓰는 response(산출물) 디렉토리의 "호스트 측" 베이스 경로. 예) /data/train/response
*
* <p>컨테이너가 --rm 으로 삭제된 경우에도 이 경로에 val.csv / *.pth 등이 남아있으면 정상 종료 여부를 "파일 기반"으로 판정합니다.
*/
@Value("${train.docker.responseDir}")
private String responseDir;
/**
* 스프링 부팅 완료 시점(빈 생성/초기화 모두 끝난 뒤)에 복구 로직 실행.
*
* <p>@Transactional: - recover() 메서드 전체가 하나의 트랜잭션으로 감싸집니다. - Job 하나씩 처리하다가 예외가 발생하면 전체 롤백이 될 수
* 있으므로 "잡 단위로 확실히 커밋"이 필요하면 (권장) 잡 단위로 분리 트랜잭션(REQUIRES_NEW) 고려하세요.
*/
// @EventListener(ApplicationReadyEvent.class)
@Transactional
public void recover() {
// 1) DB에서 "RUNNING(진행중) 상태"로 남아있는 job 목록을 조회
List<ModelTrainJobDto> runningJobs = modelTrainJobCoreService.findRunningJobs();
// 실행중 job이 없으면 할 일 없음
if (runningJobs == null || runningJobs.isEmpty()) {
return;
}
// 2) 각 job에 대해 docker 컨테이너 상태를 확인하고, 상태에 따라 조치
for (ModelTrainJobDto job : runningJobs) {
String containerName = job.getContainerName();
try {
// 2-1) docker inspect로 컨테이너 상태 조회
DockerInspectState state = inspectContainer(containerName);
// 3) 컨테이너가 "없음"
// - docker run --rm 로 실행한 컨테이너는 정상 종료 시 바로 삭제될 수 있음
// - 즉 "컨테이너 없음"이 무조건 실패는 아님
if (!state.exists()) {
log.warn(
"[RECOVERY] container missing. try file-based reconcile. container={}",
containerName);
// 3-1) 컨테이너가 없을 때는 산출물(responseDir)을 보고 완료 여부를 "추정"
OutputResult out = probeOutputs(job);
// 3-2) 산출물이 충분하면 성공 처리
if (out.completed()) {
log.info("[RECOVERY] outputs look completed. mark SUCCESS. jobId={}", job.getId());
modelTrainJobCoreService.markSuccess(job.getId(), 0);
markStepSuccessByJobType(job);
} else {
// 3-3) 산출물이 부족하면 실패 처리(운영 정책에 따라 "유예"도 가능)
log.warn(
"[RECOVERY] outputs incomplete. mark FAILED. jobId={} reason={}",
job.getId(),
out.reason());
modelTrainJobCoreService.markFailed(
job.getId(), -1, "SERVER_RESTART_CONTAINER_MISSING_OUTPUT_INCOMPLETE");
markStepErrorByJobType(job, out.reason());
}
continue;
}
// 4) 컨테이너는 존재하고, 아직 running=true
// - 서버만 재기동됐고 컨테이너는 그대로 살아있는 케이스
// - 이 경우 DB를 건드리면 오히려 꼬일 수 있으니 RUNNING 유지
if (state.running()) {
log.info("[RECOVERY] container still running. container={}", containerName);
try {
ProcessBuilder pb = new ProcessBuilder("docker", "stop", "-t", "20", containerName);
pb.redirectErrorStream(true);
Process p = pb.start();
boolean finished = p.waitFor(30, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
throw new IOException("docker stop timeout");
}
int code = p.exitValue();
if (code != 0) {
throw new IOException("docker stop failed. exit=" + code);
}
log.info(
"[RECOVERY] container stopped (will be auto removed by --rm). container={}",
containerName);
// 여기서 상태를 PAUSED로 바꿔도 되고
modelTrainJobCoreService.markPaused(job.getId(), -1, "AUTO_STOP_FAILED_ON_RESTART");
} catch (Exception e) {
log.error("[RECOVERY] docker stop failed. container={}", containerName, e);
modelTrainJobCoreService.markFailed(job.getId(), -1, "AUTO_STOP_FAILED_ON_RESTART");
}
continue;
}
// 5) 컨테이너는 존재하지만 running=false
// - exited / dead 등의 상태
Integer exitCode = state.exitCode();
String status = state.status();
// 5-1) exitCode=0이면 정상 종료로 간주 → SUCCESS 처리
if (exitCode != null && exitCode == 0) {
log.info("[RECOVERY] container exited(0). mark SUCCESS. container={}", containerName);
modelTrainJobCoreService.markSuccess(job.getId(), 0);
markStepSuccessByJobType(job);
} else {
// 5-2) exitCode != 0 이거나 null이면 실패로 간주 → FAILED 처리
log.warn(
"[RECOVERY] container exited non-zero. mark FAILED. container={} status={} exitCode={}",
containerName,
status,
exitCode);
modelTrainJobCoreService.markFailed(
job.getId(), exitCode, "SERVER_RESTART_CONTAINER_EXIT_NONZERO");
markStepErrorByJobType(job, "exit=" + exitCode + " status=" + status);
}
} catch (Exception e) {
// 6) docker inspect 자체가 실패한 경우
// - docker 데몬 문제/권한 문제/일시적 오류 가능
// - 운영 정책에 따라 "바로 실패" 대신 "유예" 처리도 고려 가능
log.error("[RECOVERY] container inspect failed. container={}", containerName, e);
modelTrainJobCoreService.markFailed(
job.getId(), -1, "SERVER_RESTART_CONTAINER_INSPECT_ERROR");
markStepErrorByJobType(job, "inspect-error");
}
}
}
/**
* jobType에 따라 학습 관리 테이블의 "성공 단계"를 업데이트.
*
* <p>예: - jobType == "EVAL" → step2(평가 단계) 성공 - 그 외 → step1(학습 단계) 성공
*/
private void markStepSuccessByJobType(ModelTrainJobDto job) {
Map<String, Object> params = job.getParamsJson();
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
if (isEval) {
modelTrainMngCoreService.markStep2Success(job.getModelId());
} else {
modelTrainMngCoreService.markStep1Success(job.getModelId());
}
}
/**
* jobType에 따라 학습 관리 테이블의 "에러 단계"를 업데이트.
*
* <p>예: - jobType == "EVAL" → step2(평가 단계) 에러 - 그 외 → step1 혹은 전체 에러
*/
private void markStepErrorByJobType(ModelTrainJobDto job, String msg) {
Map<String, Object> params = job.getParamsJson();
boolean isEval = params != null && "EVAL".equals(String.valueOf(params.get("jobType")));
if (isEval) {
modelTrainMngCoreService.markStep2Error(job.getModelId(), msg);
} else {
modelTrainMngCoreService.markError(job.getModelId(), msg);
}
}
/**
* docker inspect를 사용해서 컨테이너 상태를 조회합니다.
*
* <p>사용하는 템플릿: {{.State.Status}} {{.State.Running}} {{.State.ExitCode}}
*
* <p>예상 출력 예: - "running true 0" - "exited false 0" - "exited false 137"
*
* <p>주의: - 컨테이너가 없거나 inspect 실패 시 exitCode != 0 또는 output이 비어서 missing() 반환 - 무한 대기 방지를 위해 5초
* 타임아웃을 둠
*/
private DockerInspectState inspectContainer(String containerName)
throws IOException, InterruptedException {
ProcessBuilder pb =
new ProcessBuilder(
"docker",
"inspect",
"-f",
"{{.State.Status}} {{.State.Running}} {{.State.ExitCode}}",
containerName);
// stderr를 stdout으로 합쳐서 한 스트림으로 읽기(에러 메시지도 함께 받음)
pb.redirectErrorStream(true);
Process p = pb.start();
// inspect 출력은 1줄이면 충분하므로 readLine()만 수행
String output;
try (BufferedReader br = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
output = br.readLine();
}
// 무한대기 방지: 5초 내에 종료되지 않으면 강제 종료
boolean finished = p.waitFor(5, TimeUnit.SECONDS);
if (!finished) {
p.destroyForcibly();
throw new IOException("docker inspect timeout");
}
// docker inspect 자체의 프로세스 exit code
int code = p.exitValue();
// 실패(코드 !=0) 또는 출력이 없으면 "컨테이너 없음"으로 간주
if (code != 0 || output == null || output.isBlank()) {
return DockerInspectState.missing();
}
// "status running exitCode" 형태로 split
String[] parts = output.trim().split("\\s+");
// status: running/exited/dead 등
String status = parts.length > 0 ? parts[0] : "unknown";
// running: true/false
boolean running = parts.length > 1 && Boolean.parseBoolean(parts[1]);
// exitCode: 정수 파싱(파싱 실패하면 null)
Integer exitCode = null;
if (parts.length > 2) {
try {
exitCode = Integer.parseInt(parts[2]);
} catch (Exception ignore) {
// ignore
}
}
return new DockerInspectState(true, running, exitCode, status);
}
/**
* docker inspect 결과를 담는 레코드.
*
* <p>exists: - true : docker inspect 성공 (컨테이너 존재) - false : 컨테이너 없음(또는 inspect 실패를 missing으로 간주)
*/
private record DockerInspectState(
boolean exists, boolean running, Integer exitCode, String status) {
static DockerInspectState missing() {
return new DockerInspectState(false, false, null, "missing");
}
}
// ============================================================================================
// 컨테이너가 "없을 때" 파일 기반으로 완료/미완료를 판정하는 로직
// ============================================================================================
/**
* 컨테이너가 없을 때(responseDir 산출물만 남아있는 상태) 완료 여부를 파일 기반으로 판정합니다.
*
* <p>판정 규칙(보수적으로 설계): 1) total_epoch가 paramsJson에 있어야 함 (없으면 완료 판단 불가) 2) val.csv 존재 + 헤더 제외 라인 수
* >= total_epoch 이어야 함 3) *.pth 파일이 total_epoch 이상 존재하거나, best*.pth(또는 *best*.pth)가 존재해야 함
*
* <p>왜 이렇게? - 어떤 학습은 epoch마다 pth를 남기고 - 어떤 학습은 best만 남기기도 해서 "pthCount >= total_epoch"만 쓰면 정상 종료를
* 실패로 오판할 수 있음.
*/
private OutputResult probeOutputs(ModelTrainJobDto job) {
try {
Path outDir = resolveOutputDir(job);
if (outDir == null || !Files.isDirectory(outDir)) {
return new OutputResult(false, "output-dir-missing");
}
Integer totalEpoch = extractTotalEpoch(job).orElse(null);
if (totalEpoch == null || totalEpoch <= 0) {
return new OutputResult(false, "total-epoch-missing");
}
Path valCsv = outDir.resolve("val.csv");
if (!Files.exists(valCsv)) {
return new OutputResult(false, "val.csv-missing");
}
long lines = countNonHeaderLines(valCsv);
// “같아야 완료” 정책
if (lines == totalEpoch) {
return new OutputResult(true, "ok");
}
return new OutputResult(
false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch);
} catch (Exception e) {
log.error("[RECOVERY] probeOutputs error. jobId={}", job.getId(), e);
return new OutputResult(false, "probe-error");
}
}
/**
* responseDir 아래에서 job 산출물 디렉토리를 찾습니다.
*
* <p>가장 중요한 커스터마이징 포인트: - 실제 운영 환경에서 산출물이 어떤 경로 규칙으로 저장되는지에 따라 여기만 수정하면 됩니다.
*
* <p>현재 기본 탐색 순서: 1) {responseDir}/{jobId} 2) {responseDir}/{modelId} 3)
* {responseDir}/{containerName} 4) 마지막 fallback: responseDir 자체
*
* <p>추천: - 여러분 규칙이 "{responseDir}/{modelId}/{jobId}" 같은 형태라면 base.resolve(modelId).resolve(jobId)
* 형태를 1순위로 두세요.
*/
private Path resolveOutputDir(ModelTrainJobDto job) {
ModelTrainMngDto.Basic model = modelTrainMngCoreService.findModelById(job.getModelId());
Path base = Paths.get(responseDir, model.getUuid().toString(), "metrics");
return Files.isDirectory(base) ? base : null;
}
/**
* paramsJson에서 total_epoch 값을 추출합니다.
*
* <p>키 후보: - "total_epoch" (snake_case) - "totalEpoch" (camelCase)
*
* <p>예: paramsJson = {"jobType":"TRAIN","total_epoch":50,...}
*/
private Optional<Integer> extractTotalEpoch(ModelTrainJobDto job) {
Map<String, Object> params = job.getParamsJson();
if (params == null) return Optional.empty();
Object v = params.get("total_epoch");
if (v == null) v = params.get("totalEpoch");
if (v == null) return Optional.empty();
try {
return Optional.of(Integer.parseInt(String.valueOf(v)));
} catch (Exception ignore) {
return Optional.empty();
}
}
/**
* CSV 파일에서 "헤더(첫 줄)"를 제외한 라인 수를 계산합니다.
*
* <p>가정: - val.csv 첫 줄은 헤더 - 이후 라인들이 epoch별 기록(또는 유사한 누적 기록)
*
* <p>주의: - 파일 인코딩은 UTF-8로 가정 - 빈 줄은 제외
*/
private long countNonHeaderLines(Path csv) throws IOException {
try (Stream<String> lines = Files.lines(csv, StandardCharsets.UTF_8)) {
return lines.skip(1).filter(s -> s != null && !s.isBlank()).count();
}
}
/**
* 디렉토리에서 glob 패턴에 맞는 파일 수를 셉니다.
*
* <p>예: - "*.pth" - "best*.pth"
*/
private long countFilesByGlob(Path dir, String glob) throws IOException {
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
long cnt = 0;
for (Path p : ds) {
if (Files.isRegularFile(p)) cnt++;
}
return cnt;
}
}
/** 디렉토리에서 glob 패턴에 맞는 파일이 "하나라도" 존재하는지 체크합니다. */
private boolean existsByGlob(Path dir, String glob) throws IOException {
try (DirectoryStream<Path> ds = Files.newDirectoryStream(dir, glob)) {
return ds.iterator().hasNext();
}
}
// ============================================================================================
// probeOutputs() 결과 객체
// ============================================================================================
/**
* 컨테이너가 없을 때(responseDir 기반) 완료 여부 판정 결과.
*
* <p>completed: - true : 산출물이 완료로 보임(성공 처리 가능) - false : 산출물이 부족/불명확(실패 또는 유예 판단)
*
* <p>reason: - 실패/미완료 사유(로그/DB 메시지로 남기기 용도)
*/
private static final class OutputResult {
private final boolean completed;
private final String reason;
private OutputResult(boolean completed, String reason) {
this.completed = completed;
this.reason = reason;
}
boolean completed() {
return completed;
}
String reason() {
return reason;
}
}
}

View File

@@ -0,0 +1,216 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Slf4j
@Service
@RequiredArgsConstructor
public class ModelTestMetricsJobService {
private final ModelTestMetricsJobCoreService modelTestMetricsJobCoreService;
@Value("${spring.profiles.active}")
private String profile;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
@Value("${file.pt-path}")
private String ptPathDir;
/** 결과 csv 파일 정보 등록 */
public void findTestValidMetricCsvFiles() {
List<ResponsePathDto> modelIds =
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
if (modelIds.isEmpty()) {
return;
}
for (ResponsePathDto modelInfo : modelIds) {
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
String model = record.get("model");
long TP = Long.parseLong(record.get("TP"));
long FP = Long.parseLong(record.get("FP"));
long FN = Long.parseLong(record.get("FN"));
float precision = Float.parseFloat(record.get("precision"));
float recall = Float.parseFloat(record.get("recall"));
float f1_score = Float.parseFloat(record.get("f1_score"));
float accuracy = Float.parseFloat(record.get("accuracy"));
float iou = Float.parseFloat(record.get("iou"));
long detection_count = Long.parseLong(record.get("detection_count"));
long gt_count = Long.parseLong(record.get("gt_count"));
batchArgs.add(
new Object[] {
modelInfo.getModelId(),
model,
TP,
FP,
FN,
precision,
recall,
f1_score,
accuracy,
iou,
detection_count,
gt_count
});
}
modelTestMetricsJobCoreService.insertModelMetricsTest(batchArgs);
// test.csv 파일 읽어서 저장한 여부로만 사용하기
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step2");
} catch (IOException e) {
throw new RuntimeException(e);
}
// 패키징할 파일 만들기
modelTestMetricsJobCoreService.updatePackingStart(
modelInfo.getModelId(), ZonedDateTime.now());
ModelMetricJsonDto jsonDto =
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
try {
writeJsonFile(
jsonDto,
Paths.get(
responseDir
+ "/"
+ modelInfo.getUuid()
+ "/"
+ jsonDto.getModelVersion()
+ ".json"));
} catch (IOException e) {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
ModelTestFileName fileInfo =
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
Path zipPath =
Paths.get(
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
Set<String> targetNames =
Set.of(
"model_config.py",
fileInfo.getBestEpochFileName() + ".pth",
fileInfo.getModelVersion() + ".json");
List<Path> files = new ArrayList<>();
try (Stream<Path> s = Files.list(responsePath)) {
files.addAll(
s.filter(Files::isRegularFile)
.filter(p -> targetNames.contains(p.getFileName().toString()))
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
files.addAll(
s.filter(Files::isRegularFile)
.limit(1) // yolov8_6th-6m.pt 파일 1개만
.collect(Collectors.toList()));
} catch (IOException e) {
throw new RuntimeException(e);
}
try {
zipFiles(files, zipPath);
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.COMPLETED.getId());
} catch (IOException e) {
modelTestMetricsJobCoreService.updatePackingEnd(
modelInfo.getModelId(), ZonedDateTime.now(), TrainStatusType.ERROR.getId());
throw new RuntimeException(e);
}
}
}
private void writeJsonFile(Object data, Path outputPath) throws IOException {
Path parent = outputPath.getParent();
if (parent != null) {
Files.createDirectories(parent);
}
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
try (OutputStream os =
Files.newOutputStream(
outputPath, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING)) {
objectMapper.writeValue(os, data);
}
}
private void zipFiles(List<Path> files, Path zipPath) throws IOException {
Path parent = zipPath.getParent();
if (parent != null) {
Files.createDirectories(parent);
}
try (ZipOutputStream zos = new ZipOutputStream(Files.newOutputStream(zipPath))) {
for (Path file : files) {
ZipEntry entry = new ZipEntry(file.getFileName().toString());
zos.putNextEntry(entry);
Files.copy(file, zos);
zos.closeEntry();
}
}
}
}

View File

@@ -0,0 +1,156 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Slf4j
@Service
@RequiredArgsConstructor
public class ModelTrainMetricsJobService {
private final ModelTrainMetricsJobCoreService modelTrainMetricsJobCoreService;
@Value("${spring.profiles.active}")
private String profile;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
/** 결과 csv 파일 정보 등록 */
public void findTrainValidMetricCsvFiles() {
List<ResponsePathDto> modelIds =
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
if (modelIds.isEmpty()) {
return;
}
for (ResponsePathDto modelInfo : modelIds) {
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch"));
long iteration = Long.parseLong(record.get("Iteration"));
double Loss = Double.parseDouble(record.get("Loss"));
double LR = Double.parseDouble(record.get("LR"));
float time = Float.parseFloat(record.get("Time"));
batchArgs.add(new Object[] {modelInfo.getModelId(), epoch, iteration, Loss, LR, time});
}
modelTrainMetricsJobCoreService.insertModelMetricsTrain(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader);
List<Object[]> batchArgs = new ArrayList<>();
for (CSVRecord record : parser) {
int epoch = Integer.parseInt(record.get("Epoch"));
float aAcc = Float.parseFloat(record.get("aAcc"));
float mFscore = Float.parseFloat(record.get("mFscore"));
float mPrecision = Float.parseFloat(record.get("mPrecision"));
float mRecall = Float.parseFloat(record.get("mRecall"));
float mIoU = Float.parseFloat(record.get("mIoU"));
float mAcc = Float.parseFloat(record.get("mAcc"));
float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
float changed_precision = Float.parseFloat(record.get("changed_precision"));
float changed_recall = Float.parseFloat(record.get("changed_recall"));
float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
float unchanged_precision = Float.parseFloat(record.get("unchanged_precision"));
float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
batchArgs.add(
new Object[] {
modelInfo.getModelId(),
epoch,
aAcc,
mFscore,
mPrecision,
mRecall,
mIoU,
mAcc,
changed_fscore,
changed_precision,
changed_recall,
unchanged_fscore,
unchanged_precision,
unchanged_recall
});
}
modelTrainMetricsJobCoreService.insertModelMetricsValidation(batchArgs);
} catch (IOException e) {
throw new RuntimeException(e);
}
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
Integer epoch = null;
boolean exists;
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
try (Stream<Path> s = Files.list(responsePath)) {
epoch =
s.filter(Files::isRegularFile)
.map(
p -> {
Matcher matcher = pattern.matcher(p.getFileName().toString());
if (matcher.matches()) {
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
}
return null;
})
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
} catch (IOException e) {
throw new RuntimeException(e);
}
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(), epoch);
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
modelInfo.getModelId(), "step1");
}
}
}

View File

@@ -0,0 +1,100 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class TestJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ApplicationEventPublisher eventPublisher;
private final DataSetCountersService dataSetCounters;
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @param uuid
* @param epoch
* @return
*/
@Transactional
public Long enqueue(Long modelId, UUID uuid, int epoch) {
// 마스터 확인
modelTrainMngCoreService.findModelById(modelId);
// 폴더 카운트
dataSetCounters.getCount(modelId);
// best epoch 업데이트
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
// 파라미터 조회
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
Map<String, Object> params = new java.util.LinkedHashMap<>();
params.put("jobType", "EVAL");
params.put("uuid", String.valueOf(uuid));
params.put("epoch", epoch);
params.put("datasetFolder", trainRunRequest.getDatasetFolder());
params.put("outputFolder", trainRunRequest.getOutputFolder());
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, params, ZonedDateTime.now());
// test training run 테이블에 적재하기
modelTrainJobCoreService.insertModelTestTrainingRun(modelId, jobId, epoch);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
/**
* 취소
*
* @param modelId
*/
@Transactional
public void cancel(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
Long jobId = master.getCurrentAttemptId();
if (jobId == null) {
throw new IllegalStateException("실행중인 작업이 없습니다.");
}
var job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalStateException("Job not found"));
String containerName = job.getContainerName();
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
if (containerName != null && !containerName.isBlank()) {
dockerTrainService.killContainer(containerName);
}
// 2) 상태 업데이트 (항상 수행)
modelTrainJobCoreService.markCanceled(jobId);
modelTrainMngCoreService.markStopped(modelId);
}
}

View File

@@ -0,0 +1,236 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import java.io.IOException;
import java.nio.file.*;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Slf4j
@Service
@RequiredArgsConstructor
public class TmpDatasetService {
@Value("${train.docker.requestDir}")
private String requestDir;
@Value("${train.docker.basePath}")
private String trainBaseDir;
/**
* train, val, test 폴더별로 link
*
* @param uid 임시폴더 uuid
* @param type train, val, test
* @param links tif pull path
* @return
* @throws IOException
*/
public void buildTmpDatasetHardlink(String uid, String type, List<ModelTrainLinkDto> links)
throws IOException {
if (links == null || links.isEmpty()) {
throw new IOException("links is empty");
}
Path tmp = Path.of(trainBaseDir, "tmp", uid);
long hardlinksMade = 0;
for (ModelTrainLinkDto dto : links) {
if (type == null) {
log.warn("SKIP - trainType null: {}", dto);
continue;
}
// type별 디렉토리 생성
Files.createDirectories(tmp.resolve(type).resolve("input1"));
Files.createDirectories(tmp.resolve(type).resolve("input2"));
Files.createDirectories(tmp.resolve(type).resolve("label"));
Files.createDirectories(tmp.resolve(type).resolve("label-json"));
// comparePath → input1
hardlinksMade += link(tmp, type, "input1", dto.getComparePath());
// targetPath → input2
hardlinksMade += link(tmp, type, "input2", dto.getTargetPath());
// labelPath → label
hardlinksMade += link(tmp, type, "label", dto.getLabelPath());
// geoJsonPath -> label-json
hardlinksMade += link(tmp, type, "label-json", dto.getGeoJsonPath());
}
if (hardlinksMade == 0) {
throw new IOException("No hardlinks created.");
}
log.info("tmp dataset created: {}, hardlinksMade={}", tmp, hardlinksMade);
}
private long link(Path tmp, String type, String part, String fullPath) throws IOException {
if (fullPath == null || fullPath.isBlank()) return 0;
Path src = Path.of(fullPath);
if (!Files.isRegularFile(src)) {
log.warn("SKIP (not file): {}", src);
return 0;
}
String fileName = src.getFileName().toString();
Path dst = tmp.resolve(type).resolve(part).resolve(fileName);
// 충돌 시 덮어쓰기
if (Files.exists(dst)) {
Files.delete(dst);
}
Files.createLink(dst, src);
return 1;
}
private String safe(String s) {
return (s == null || s.isBlank()) ? null : s.trim();
}
/**
* request 전체 폴더 link
*
* @param uid
* @param datasetUids
* @return
* @throws IOException
*/
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
log.info("========== buildTmpDatasetHardlink START ==========");
log.info("uid={}", uid);
log.info("datasetUids={}", datasetUids);
log.info("requestDir(raw)={}", requestDir);
Path BASE = toPath(requestDir);
Path tmp = Path.of(trainBaseDir, "tmp", uid);
log.info("BASE={}", BASE);
log.info("BASE exists? {}", Files.isDirectory(BASE));
log.info("tmp={}", tmp);
long noDir = 0, scannedDirs = 0, regularFiles = 0, hardlinksMade = 0;
// tmp 디렉토리 준비
for (String type : List.of("train", "val", "test")) {
for (String part : List.of("input1", "input2", "label", "label-json")) {
Path dir = tmp.resolve(type).resolve(part);
Files.createDirectories(dir);
log.info("createDirectories: {}", dir);
}
}
// 하드링크는 "같은 파일시스템"에서만 가능하므로 BASE/tmp가 같은 FS인지 미리 확인(권장)
try {
var baseStore = Files.getFileStore(BASE);
var tmpStore = Files.getFileStore(tmp.getParent()); // BASE/tmp
if (!baseStore.name().equals(tmpStore.name()) || !baseStore.type().equals(tmpStore.type())) {
throw new IOException(
"Hardlink requires same filesystem. baseStore="
+ baseStore.name()
+ "("
+ baseStore.type()
+ "), tmpStore="
+ tmpStore.name()
+ "("
+ tmpStore.type()
+ ")");
}
} catch (Exception e) {
// FileStore 비교가 환경마다 애매할 수 있어서, 여기서는 경고만 주고 실제 createLink에서 최종 판단하게 둘 수도 있음.
log.warn("FileStore check skipped/failed (will rely on createLink): {}", e.toString());
}
for (String id : datasetUids) {
Path srcRoot = BASE.resolve(id);
log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot));
for (String type : List.of("train", "val", "test")) {
for (String part : List.of("input1", "input2", "label", "label-json")) {
Path srcDir = srcRoot.resolve(type).resolve(part);
if (!Files.isDirectory(srcDir)) {
log.warn("SKIP (not directory): {}", srcDir);
noDir++;
continue;
}
scannedDirs++;
log.info("SCAN dir={}", srcDir);
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
for (Path f : stream) {
if (!Files.isRegularFile(f)) {
log.debug("skip non-regular file: {}", f);
continue;
}
regularFiles++;
String dstName = id + "__" + f.getFileName();
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
// dst가 남아있으면 삭제(심볼릭링크든 파일이든)
if (Files.exists(dst) || Files.isSymbolicLink(dst)) {
Files.delete(dst);
log.debug("deleted existing: {}", dst);
}
try {
// 하드링크 생성 (dst가 새 파일로 생기지만 inode는 f와 동일)
Files.createLink(dst, f);
hardlinksMade++;
log.debug("created hardlink: {} => {}", dst, f);
} catch (IOException e) {
// 여기서 바로 실패시키면 “tmp는 만들었는데 내용은 0개” 같은 상태를 방지할 수 있음
log.error("FAILED create hardlink: {} => {}", dst, f, e);
throw e;
}
}
}
}
}
}
if (hardlinksMade == 0) {
throw new IOException(
"No hardlinks created. regularFiles="
+ regularFiles
+ ", scannedDirs="
+ scannedDirs
+ ", noDir="
+ noDir);
}
log.info("tmp dataset created: {}", tmp);
log.info(
"summary: scannedDirs={}, noDir={}, regularFiles={}, hardlinksMade={}",
scannedDirs,
noDir,
regularFiles,
hardlinksMade);
return uid;
}
private static Path toPath(String p) {
if (p.startsWith("~/")) {
return Paths.get(System.getProperty("user.home")).resolve(p.substring(2)).normalize();
}
return Paths.get(p).toAbsolutePath().normalize();
}
}

View File

@@ -0,0 +1,306 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
@Log4j2
@RequiredArgsConstructor
@Transactional(readOnly = true)
public class TrainJobService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher eventPublisher;
private final TmpDatasetService tmpDatasetService;
private final DataSetCountersService dataSetCounters;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
public Long getModelIdByUuid(UUID uuid) {
return modelTrainMngCoreService.findModelIdByUuid(uuid);
}
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @return
*/
@Transactional
public Long enqueue(Long modelId) {
// 폴더 카운트
dataSetCounters.getCount(modelId);
// 마스터 존재 확인(없으면 예외)
modelTrainMngCoreService.findModelById(modelId);
// 파라미터 조회
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
if (trainRunRequest == null) {
throw new IllegalArgumentException("Model not found: " + modelId);
}
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
@SuppressWarnings("unchecked")
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
paramsMap.put("jobType", "TRAIN");
paramsMap.put("uuid", trainRunRequest.getUuid());
paramsMap.put("totalEpoch", trainRunRequest.getEpochs());
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, paramsMap, ZonedDateTime.now());
modelTrainMngCoreService.clearLastError(modelId);
modelTrainMngCoreService.markInProgress(modelId, jobId);
// 커밋 이후 Worker 실행 트리거(리스너에서 AFTER_COMMIT로 받아야 함)
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
/**
* 재시작
*
* <p>- STOPPED / ERROR 상태에서만 가능 (IN_PROGRESS면 예외) - 이전 params_json 재사용 - 새 attempt 생성
*/
@Transactional
public Long restart(Long modelId) {
return createNextAttempt(modelId, ResumeMode.NONE);
}
/**
* 이어하기
*
* @param modelId
* @return
*/
@Transactional
public Long resume(Long modelId) {
return createNextAttempt(modelId, ResumeMode.REQUIRE);
}
/**
* 중단
*
* <p>- job 상태 CANCELED - master 상태 STOPPED
*
* <p>※ 실제 docker stop은 Worker/Runner가 수행(운영 안정)
*/
@Transactional
public void cancel(Long modelId) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
Long jobId = master.getCurrentAttemptId();
if (jobId == null) {
throw new IllegalStateException("실행중인 작업이 없습니다.");
}
var job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalStateException("Job not found"));
String containerName = job.getContainerName();
// 1) 컨테이너 강제 종료 + 제거 (없거나 이미 죽었어도 괜찮게)
if (containerName != null && !containerName.isBlank()) {
dockerTrainService.killContainer(containerName);
}
// 2) 상태 업데이트 (항상 수행)
modelTrainJobCoreService.markCanceled(jobId);
modelTrainMngCoreService.markStopped(modelId);
}
/**
* 학습 이어하기
*
* @param modelId 모델 id
* @param mode NONE 새로 시작, REQUIRE 이어하기
* @return
*/
private Long createNextAttempt(Long modelId, ResumeMode mode) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
ModelTrainJobDto lastJob =
modelTrainJobCoreService
.findLatestByModelId(modelId)
.orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다."));
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
// 이전 params_json 재사용 (재현성)
Map<String, Object> params = lastJob.getParamsJson();
if (params == null || params.isEmpty()) {
throw new IllegalStateException("이전 실행 params_json이 없습니다.");
}
// mode에 따라 resume 옵션 주입/제거
Map<String, Object> nextParams = new java.util.LinkedHashMap<>(params);
if (mode == ResumeMode.NONE) {
// 이어하기 관련 키가 있다면 제거 (완전 새로 시작 보장)
nextParams.remove("resumeFrom");
nextParams.remove("resume");
} else if (mode == ResumeMode.REQUIRE) {
// 체크포인트 탐지해서 resumeFrom 세팅
String resumeFrom = findResumeFromOrNull(nextParams);
if (resumeFrom == null) {
throw new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND, "이어하기 체크포인트가 없습니다.");
}
nextParams.put("resumeFrom", resumeFrom);
nextParams.put("resume", true);
}
Long jobId =
modelTrainJobCoreService.createQueuedJob(
modelId, nextAttemptNo, nextParams, ZonedDateTime.now());
modelTrainMngCoreService.clearLastError(modelId);
modelTrainMngCoreService.markInProgress(modelId, jobId);
eventPublisher.publishEvent(new ModelTrainJobQueuedEvent(jobId));
return jobId;
}
private enum ResumeMode {
NONE, // 새로 시작
REQUIRE // 이어하기
}
/**
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
*
* @param paramsJson
* @return
*/
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
if (paramsJson == null) return null;
Object out = paramsJson.get("outputFolder");
if (out == null) return null;
String outputFolder = String.valueOf(out).trim();
if (outputFolder.isEmpty()) return null;
Path outDir = Paths.get(responseDir, outputFolder);
log.info("resume outDir response path: {}", outDir);
Path last = outDir.resolve("last_checkpoint");
log.info("resume last response path: {}", last);
if (!Files.isRegularFile(last)) return null;
try {
// last_checkpoint 내용 그대로 읽기
String containerPath = Files.readString(last).trim();
log.info("resume containerPath: {}", containerPath);
if (containerPath.isEmpty()) return null;
// 호스트 경로로 변환해서 실제 파일 존재 확인
String hostPathStr = containerPath.replace("/checkpoints", responseDir);
Path hostPath = Paths.get(hostPathStr);
log.info("resume hostPath: {}", hostPath);
if (!Files.isRegularFile(hostPath)) return null;
// 3컨테이너 경로 그대로 반환
return containerPath;
} catch (Exception e) {
log.error("resume error", e);
return null;
}
}
/**
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
*
* @param modelUuid
* @return
*/
@Transactional
public UUID createTmpFile(UUID modelUuid) {
UUID tmpUuid = UUID.randomUUID();
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
// model id 가져오기
Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
// model 에 연결된 dataset id 가져오기
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
try {
// 데이터셋 심볼링크 생성
// String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
// train path
List<ModelTrainLinkDto> trainList = modelTrainMngCoreService.findDatasetTrainPath(modelId);
// validation path
List<ModelTrainLinkDto> valList = modelTrainMngCoreService.findDatasetValPath(modelId);
// test path
List<ModelTrainLinkDto> testList = modelTrainMngCoreService.findDatasetTestPath(modelId);
// train 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "train", trainList);
// val 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "val", valList);
// test 데이터셋 심볼링크 생성
tmpDatasetService.buildTmpDatasetHardlink(raw, "test", testList);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(raw);
// 학습모델을 수정한다.
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
log.error(
"createTmpFile failed. modelUuid={}, modelId={}, tmpRaw={}, datasetIdsSize={}, uidsSize={}",
modelUuid,
modelId,
raw,
(datasetIds == null ? null : datasetIds.size()),
(uids == null ? null : uids.size()),
e);
// 런타임 예외로 래핑하되, 메시지에 핵심 정보 포함
throw new CustomApiException(
"INTERNAL_SERVER_ERROR", HttpStatus.INTERNAL_SERVER_ERROR, "임시 데이터셋 생성에 실패했습니다.");
}
return modelUuid;
}
}

View File

@@ -0,0 +1,161 @@
package com.kamco.cd.training.train.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.EvalRunRequest;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.kamco.cd.training.train.dto.TrainRunResult;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
/** job 실행 */
@Log4j2
@Component
@RequiredArgsConstructor
public class TrainJobWorker {
private final ModelTrainJobCoreService modelTrainJobCoreService;
private final ModelTrainMngCoreService modelTrainMngCoreService;
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
private final ModelTestMetricsJobService modelTestMetricsJobService;
private final DockerTrainService dockerTrainService;
private final ObjectMapper objectMapper;
@Async("trainJobExecutor")
@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
public void handle(ModelTrainJobQueuedEvent event) {
log.info("[JOB] thread={}, jobId={}", Thread.currentThread().getName(), event.getJobId());
Long jobId = event.getJobId();
ModelTrainJobDto job =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
if (TrainStatusType.STOPPED.getId().equals(job.getStatusCd())) {
return;
}
Long modelId = job.getModelId();
Map<String, Object> params = job.getParamsJson();
String jobType = params != null ? String.valueOf(params.get("jobType")) : null;
boolean isEval = "EVAL".equals(jobType);
String containerName =
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
String type = isEval ? "TEST" : "TRAIN";
Integer totalEpoch = null;
if (params.containsKey("totalEpoch")) {
if (params.get("totalEpoch") != null) {
totalEpoch = Integer.parseInt(params.get("totalEpoch").toString());
}
}
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
// 실행 시작 처리
modelTrainJobCoreService.markRunning(
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
log.info("[JOB] markRunning done jobId={}", jobId);
try {
TrainRunResult result;
if (isEval) {
// step2 진행중 처리
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
String uuid = String.valueOf(params.get("uuid"));
int epoch = (int) params.get("epoch");
String datasetFolder = String.valueOf(params.get("datasetFolder"));
String outputFolder = String.valueOf(params.get("outputFolder"));
EvalRunRequest evalReq = new EvalRunRequest();
evalReq.setUuid(uuid);
evalReq.setEpoch(epoch);
evalReq.setTimeoutSeconds(null);
evalReq.setDatasetFolder(datasetFolder);
evalReq.setOutputFolder(outputFolder);
log.info("[JOB] selected test epoch={}", epoch);
// 도커 실행 후 로그 수집
result = dockerTrainService.runEvalSync(containerName, evalReq);
} else {
// step1 진행중 처리
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
TrainRunRequest trainReq = toTrainRunRequest(params);
// 도커 실행 후 로그 수집
result = dockerTrainService.runTrainSync(trainReq, containerName);
}
ModelTrainJobDto latest =
modelTrainJobCoreService
.findById(jobId)
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + jobId));
if (TrainStatusType.STOPPED.getId().equals(latest.getStatusCd())) {
return;
}
if (result.getExitCode() == 0) {
// 성공 처리
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
if (isEval) {
// step2 완료처리
modelTrainMngCoreService.markStep2Success(modelId);
// 결과 csv 파일 정보 등록
modelTestMetricsJobService.findTestValidMetricCsvFiles();
} else {
modelTrainMngCoreService.markStep1Success(modelId);
// 결과 csv 파일 정보 등록
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
}
} else {
String failMsg = result.getStatus() + "\n" + result.getLogs();
// 실패 처리
modelTrainJobCoreService.markFailed(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
if (isEval) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
} else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
}
}
} catch (Exception e) {
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
if ("EVAL".equals(params.get("jobType"))) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
} else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, e.getMessage());
}
}
}
private TrainRunRequest toTrainRunRequest(Map<String, Object> paramsJson) {
if (paramsJson == null || paramsJson.isEmpty()) {
return null;
}
return objectMapper.convertValue(paramsJson, TrainRunRequest.class);
}
}

View File

@@ -199,9 +199,6 @@ public class UploadDto {
private String fileName;
public double getUploadRate() {
if (this.chunkTotalIndex == 0) {
return 0.0;
}
return (double) (this.chunkIndex + 1) / (this.chunkTotalIndex + 1) * 100.0;
}
}

View File

@@ -234,8 +234,8 @@ public class UploadService {
try {
FIleChecker.deleteFolder(tmpDir);
// 108 에서 86 서버로 이동
log.info("################# server move 108 -> 86");
FIleChecker.uploadTo86(outputPath);
// log.info("################# server move 108 -> 86");
// FIleChecker.uploadTo86(outputPath);
} catch (Exception e) {
log.warn("tmpDir delete failed (merge already succeeded): tmpDir={}", tmpDir, e);
}

Some files were not shown because too many files have changed in this diff Show More