22 Commits

Author SHA1 Message Date
d7766edd24 Merge pull request 'return 형식 수정' (#65) from feat/training_260202 into develop
Reviewed-on: #65
2026-02-12 18:59:37 +09:00
0bc4453c9c hyperparam_with_modeltype 2026-02-12 18:56:32 +09:00
ae0d30e5da return 형식 수정 2026-02-12 18:55:42 +09:00
37d776dd2c Merge pull request 'hyperparam_with_modeltype' (#64) from feat/dean/hyperparam_with_modelType into develop
Reviewed-on: #64
2026-02-12 18:50:32 +09:00
3106d36431 Merge pull request '업로드 시 같은 uid로 업로드하지 못하게 조건 추가' (#63) from feat/training_260202 into develop
Reviewed-on: #63
2026-02-12 18:44:49 +09:00
ed48f697a4 업로드 시 같은 uid로 업로드하지 못하게 조건 추가 2026-02-12 18:44:04 +09:00
da92b28d97 Merge pull request '임시파일생성 소프트링크에서 하드링크로 변경' (#62) from feat/training_260202 into develop
Reviewed-on: #62
2026-02-12 18:20:30 +09:00
6c865d26fd 임시파일생성 소프트링크에서 하드링크로 변경 2026-02-12 18:18:44 +09:00
e3f00876f1 Merge pull request '문제되는 하이퍼파라미터 주석처리' (#61) from feat/training_260202 into develop
Reviewed-on: #61
2026-02-12 17:53:11 +09:00
16e156b5b4 문제되는 하이퍼파라미터 주석처리 2026-02-12 17:52:42 +09:00
60962bbc75 Merge pull request '학습실행 mount 경로 수정' (#60) from feat/training_260202 into develop
Reviewed-on: #60
2026-02-12 17:44:15 +09:00
6a939118ff 임시폴더생성 api 추가 2026-02-12 17:43:41 +09:00
64d37dcc08 Merge pull request '임시폴더생성 api 추가' (#59) from feat/training_260202 into develop
Reviewed-on: #59
2026-02-12 17:23:53 +09:00
0c0ae16c2b 임시폴더생성 api 추가 2026-02-12 17:23:34 +09:00
a2490f30e6 Merge pull request '임시폴더생성 api 수정' (#58) from feat/training_260202 into develop
Reviewed-on: #58
2026-02-12 17:14:52 +09:00
953f95aed6 임시폴더생성 api 추가 2026-02-12 17:14:26 +09:00
bd04e1f4e8 Merge pull request '임시폴더생성 api 추가' (#57) from feat/training_260202 into develop
Reviewed-on: #57
2026-02-12 17:03:39 +09:00
85633c8bab 임시폴더생성 api 추가 2026-02-12 17:03:21 +09:00
5fc15937c0 Merge pull request 'feat/training_260202' (#56) from feat/training_260202 into develop
Reviewed-on: #56
2026-02-12 17:00:08 +09:00
8b3940b446 Merge remote-tracking branch 'origin/feat/training_260202' into feat/training_260202 2026-02-12 16:59:44 +09:00
201cfefb6b 임시폴더생성 api 추가 2026-02-12 16:59:39 +09:00
9958b0999a csv 읽는 경로 수정하기, 변수명 수정 2026-02-12 16:58:28 +09:00
23 changed files with 254 additions and 153 deletions

View File

@@ -17,7 +17,10 @@ public enum ModelType implements EnumType {
private String desc;
public static ModelType getValueData(String modelNo) {
return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1);
return Arrays.stream(ModelType.values())
.filter(m -> m.getId().equals(modelNo))
.findFirst()
.orElse(G1);
}
@Override

View File

@@ -217,7 +217,7 @@ public class DatasetApiController {
public ApiResponseDto<ApiResponseDto.ResponseObj> insertDataset(
@RequestBody @Valid DatasetDto.AddReq addReq) {
return ApiResponseDto.ok(datasetService.insertDataset(addReq));
return ApiResponseDto.okObject(datasetService.insertDataset(addReq));
}
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")

View File

@@ -208,6 +208,13 @@ public class DatasetService {
Long datasetUid = null; // master id 값, 등록하면서 가져올 예정
try {
// 같은 uid 로 등록한 파일이 있는지 확인
Long existsCnt =
datasetCoreService.findDatasetByUidExistsCnt(addReq.getFileName().replace(".zip", ""));
if (existsCnt > 0) {
return new ResponseObj(ApiResponseCode.DUPLICATE_DATA, "이미 등록된 회차 데이터 파일입니다. 확인 부탁드립니다.");
}
// 압축 해제
FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath());

View File

@@ -182,10 +182,8 @@ public class HyperParamApiController {
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/init/{model}")
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(
@PathVariable ModelType model
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(@PathVariable ModelType model) {
) {
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
}
}

View File

@@ -79,7 +79,8 @@ public class ModelTrainMngApiController {
@DeleteMapping("/{uuid}")
public ApiResponseDto<Void> deleteModelTrain(
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
@PathVariable UUID uuid) {
@PathVariable
UUID uuid) {
modelTrainMngService.deleteModelTrain(uuid);
return ApiResponseDto.ok(null);
}

View File

@@ -60,7 +60,7 @@ public class ModelTrainMngDto {
}
}
public String getStep2StatusNAme() {
public String getStep2StatusName() {
if (this.step2Status == null || this.step2Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()

View File

@@ -13,7 +13,6 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@@ -102,9 +101,9 @@ public class ModelTrainMngService {
try {
// 데이터셋 심볼링크 생성
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
String tmpUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(path.toString());
updateReq.setRequestPath(tmpUid);
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
throw new RuntimeException(e);

View File

@@ -7,7 +7,6 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Slf4j
@Service
@@ -17,50 +16,129 @@ public class TmpDatasetService {
@Value("${train.docker.requestDir}")
private String requestDir;
@Transactional(readOnly = true)
public Path buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
// 환경에 맞게 yml로 빼는 걸 추천
Path BASE = Paths.get(requestDir);
log.info("========== buildTmpDatasetHardlink START ==========");
log.info("uid={}", uid);
log.info("datasetUids={}", datasetUids);
log.info("requestDir(raw)={}", requestDir);
Path BASE = toPath(requestDir);
Path tmp = BASE.resolve("tmp").resolve(uid);
// mkdir -p "$TMP"/train/{input1,input2,label} ...
log.info("BASE={}", BASE);
log.info("BASE exists? {}", Files.isDirectory(BASE));
log.info("tmp={}", tmp);
long noDir = 0, scannedDirs = 0, regularFiles = 0, hardlinksMade = 0;
// tmp 디렉토리 준비
for (String type : List.of("train", "val")) {
for (String part : List.of("input1", "input2", "label")) {
Files.createDirectories(tmp.resolve(type).resolve(part));
Path dir = tmp.resolve(type).resolve(part);
Files.createDirectories(dir);
log.info("createDirectories: {}", dir);
}
}
// 하드링크는 "같은 파일시스템"에서만 가능하므로 BASE/tmp가 같은 FS인지 미리 확인(권장)
try {
var baseStore = Files.getFileStore(BASE);
var tmpStore = Files.getFileStore(tmp.getParent()); // BASE/tmp
if (!baseStore.name().equals(tmpStore.name()) || !baseStore.type().equals(tmpStore.type())) {
throw new IOException(
"Hardlink requires same filesystem. baseStore="
+ baseStore.name()
+ "("
+ baseStore.type()
+ "), tmpStore="
+ tmpStore.name()
+ "("
+ tmpStore.type()
+ ")");
}
} catch (Exception e) {
// FileStore 비교가 환경마다 애매할 수 있어서, 여기서는 경고만 주고 실제 createLink에서 최종 판단하게 둘 수도 있음.
log.warn("FileStore check skipped/failed (will rely on createLink): {}", e.toString());
}
for (String id : datasetUids) {
Path srcRoot = BASE.resolve(id);
log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot));
for (String type : List.of("train", "val")) {
for (String part : List.of("input1", "input2", "label")) {
Path srcDir = srcRoot.resolve(type).resolve(part);
// zsh NULL_GLOB: 폴더가 없으면 그냥 continue
if (!Files.isDirectory(srcDir)) continue;
if (!Files.isDirectory(srcDir)) {
log.warn("SKIP (not directory): {}", srcDir);
noDir++;
continue;
}
scannedDirs++;
log.info("SCAN dir={}", srcDir);
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
for (Path f : stream) {
if (!Files.isRegularFile(f)) continue;
if (!Files.isRegularFile(f)) {
log.debug("skip non-regular file: {}", f);
continue;
}
regularFiles++;
String dstName = id + "__" + f.getFileName();
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
// 이미 있으면 스킵(원하면 덮어쓰기 로직으로 바꿀 수 있음)
if (Files.exists(dst)) continue;
// dst가 남아있으면 삭제(심볼릭링크든 파일이든)
if (Files.exists(dst) || Files.isSymbolicLink(dst)) {
Files.delete(dst);
log.debug("deleted existing: {}", dst);
}
// ln -s "$f" "$dst" 와 동일
Files.createSymbolicLink(dst, f.toAbsolutePath());
try {
// 하드링크 생성 (dst가 새 파일로 생기지만 inode는 f와 동일)
Files.createLink(dst, f);
hardlinksMade++;
log.debug("created hardlink: {} => {}", dst, f);
} catch (IOException e) {
// 여기서 바로 실패시키면 “tmp는 만들었는데 내용은 0개” 같은 상태를 방지할 수 있음
log.error("FAILED create hardlink: {} => {}", dst, f, e);
throw e;
}
}
}
}
}
}
if (hardlinksMade == 0) {
throw new IOException(
"No hardlinks created. regularFiles="
+ regularFiles
+ ", scannedDirs="
+ scannedDirs
+ ", noDir="
+ noDir);
}
log.info("tmp dataset created: {}", tmp);
return tmp;
log.info(
"summary: scannedDirs={}, noDir={}, regularFiles={}, hardlinksMade={}",
scannedDirs,
noDir,
regularFiles,
hardlinksMade);
return uid;
}
private static Path toPath(String p) {
if (p.startsWith("~/")) {
return Paths.get(System.getProperty("user.home")).resolve(p.substring(2)).normalize();
}
return Paths.get(p).toAbsolutePath().normalize();
}
}

View File

@@ -246,4 +246,8 @@ public class DatasetCoreService
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
datasetObjRepository.insertDatasetValObj(objRegDto);
}
public Long findDatasetByUidExistsCnt(String uid) {
return datasetRepository.findDatasetByUidExistsCnt(uid);
}
}

View File

@@ -50,15 +50,15 @@ public class HyperParamCoreService {
/**
* 하이퍼파라미터 수정
*
* @param uuid uuid
* @param uuid uuid
* @param createReq 등록 요청
* @return ver
*/
public String updateHyperParam(UUID uuid, HyperParam createReq) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
@@ -72,7 +72,6 @@ public class HyperParamCoreService {
return entity.getHyperVer();
}
/**
* 하이퍼파라미터 삭제
*
@@ -80,15 +79,15 @@ public class HyperParamCoreService {
*/
public void deleteHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
// if (entity.getHyperVer().equals("HPs_0001")) {
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
// }
// if (entity.getHyperVer().equals("HPs_0001")) {
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
// }
//디폴트면 삭제불가
// 디폴트면 삭제불가
if (entity.getIsDefault()) {
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
}
@@ -105,9 +104,9 @@ public class HyperParamCoreService {
*/
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
ModelHyperParamEntity entity =
hyperParamRepository
.getHyperparamByType(model)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
hyperParamRepository
.getHyperparamByType(model)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
@@ -118,9 +117,9 @@ public class HyperParamCoreService {
*/
public HyperParamDto.Basic getHyperParam(UUID uuid) {
ModelHyperParamEntity entity =
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
hyperParamRepository
.findHyperParamByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto();
}
@@ -143,16 +142,16 @@ public class HyperParamCoreService {
*/
public String getFirstHyperParamVersion(ModelType model) {
return hyperParamRepository
.findHyperParamVerByModelType(model)
.map(ModelHyperParamEntity::getHyperVer)
.map(ver -> increase(ver, model))
.orElse(model.name() + "_000001");
.findHyperParamVerByModelType(model)
.map(ModelHyperParamEntity::getHyperVer)
.map(ver -> increase(ver, model))
.orElse(model.name() + "_000001");
}
/**
* 하이퍼 파라미터의 버전을 증가시킨다.
*
* @param hyperVer 현재 버전
* @param hyperVer 현재 버전
* @param modelType 모델 타입
* @return 증가된 버전
*/
@@ -214,5 +213,4 @@ public class HyperParamCoreService {
// memo
entity.setMemo(src.getMemo());
}
}

View File

@@ -65,9 +65,9 @@ public class ModelTrainMngCoreService {
*/
public void deleteModel(UUID uuid) {
ModelMasterEntity entity =
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
entity.setDelYn(true);
entity.setUpdatedDttm(ZonedDateTime.now());
entity.setUpdatedUid(userUtil.getId());
@@ -83,15 +83,15 @@ public class ModelTrainMngCoreService {
ModelMasterEntity entity = new ModelMasterEntity();
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
// 최적화 파라미터는 모델 type의 디폴트사용
// 최적화 파라미터는 모델 type의 디폴트사용
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
} else {
hyperParamEntity =
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
}
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
@@ -99,8 +99,8 @@ public class ModelTrainMngCoreService {
}
String modelVer =
String.join(
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
String.join(
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
entity.setModelVer(modelVer);
entity.setHyperParamId(hyperParamEntity.getId());
entity.setModelNo(addReq.getModelNo());
@@ -132,7 +132,7 @@ public class ModelTrainMngCoreService {
* data set 저장
*
* @param modelId 저장한 모델 학습 id
* @param addReq 요청 파라미터
* @param addReq 요청 파라미터
*/
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
TrainingDataset dataset = addReq.getTrainingDataset();
@@ -165,9 +165,9 @@ public class ModelTrainMngCoreService {
*/
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
entity.setRequestPath(req.getRequestPath());
}
@@ -180,7 +180,7 @@ public class ModelTrainMngCoreService {
/**
* 모델 데이터셋 mapping 테이블 저장
*
* @param modelId 모델학습 id
* @param modelId 모델학습 id
* @param datasetList 선택한 data set
*/
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
@@ -197,7 +197,7 @@ public class ModelTrainMngCoreService {
* 모델학습 config 저장
*
* @param modelId 모델학습 id
* @param req 요청 파라미터
* @param req 요청 파라미터
* @return
*/
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
@@ -217,7 +217,7 @@ public class ModelTrainMngCoreService {
/**
* 데이터셋 매핑 생성
*
* @param modelUid 모델 UID
* @param modelUid 모델 UID
* @param datasetIds 데이터셋 ID 목록
*/
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
@@ -239,8 +239,8 @@ public class ModelTrainMngCoreService {
public ModelMasterEntity findByUuid(UUID uuid) {
try {
return modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} catch (IllegalArgumentException e) {
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
}
@@ -254,9 +254,9 @@ public class ModelTrainMngCoreService {
*/
public Long findModelIdByUuid(UUID uuid) {
ModelMasterEntity entity =
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.getId();
}
@@ -269,8 +269,8 @@ public class ModelTrainMngCoreService {
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository
.findModelConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
.findModelConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
/**
@@ -301,21 +301,19 @@ public class ModelTrainMngCoreService {
*/
public ModelTrainMngDto.Basic findModelById(Long id) {
ModelMasterEntity entity =
modelMngRepository
.findById(id)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
modelMngRepository
.findById(id)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
return entity.toDto();
}
/**
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
*/
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
@Transactional
public void markInProgress(Long modelId, Long jobId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
master.setCurrentAttemptId(jobId);
@@ -323,54 +321,46 @@ public class ModelTrainMngCoreService {
// 필요하면 시작시간도 여기서 찍어줌
}
/**
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
*/
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
@Transactional
public void clearLastError(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setLastError(null);
}
/**
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
*/
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
@Transactional
public void markStopped(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.STOPPED.getId());
}
/**
* 완료 처리(옵션) - Worker가 성공 시 호출
*/
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
@Transactional
public void markCompleted(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/**
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*/
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep1State(TrainStatusType.ERROR.getId());
@@ -379,15 +369,13 @@ public class ModelTrainMngCoreService {
master.setUpdatedDttm(ZonedDateTime.now());
}
/**
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*/
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
@Transactional
public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
master.setStatusCd(TrainStatusType.ERROR.getId());
master.setStep2State(TrainStatusType.ERROR.getId());
@@ -399,9 +387,9 @@ public class ModelTrainMngCoreService {
@Transactional
public void markSuccess(Long modelId) {
ModelMasterEntity master =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
// 모델 상태 완료 처리
master.setStatusCd(TrainStatusType.COMPLETED.getId());
@@ -429,9 +417,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep1StrtDttm(ZonedDateTime.now());
@@ -449,9 +437,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2InProgress(Long modelId, Long jobId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
entity.setStep2StrtDttm(ZonedDateTime.now());
@@ -469,9 +457,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep1Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep1State(TrainStatusType.COMPLETED.getId());
@@ -488,9 +476,9 @@ public class ModelTrainMngCoreService {
@Transactional(propagation = Propagation.REQUIRES_NEW)
public void markStep2Success(Long modelId) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
entity.setStep2State(TrainStatusType.COMPLETED.getId());
@@ -501,9 +489,9 @@ public class ModelTrainMngCoreService {
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
ModelMasterEntity entity =
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
modelMngRepository
.findById(modelId)
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
entity.setBestEpoch(epoch);
}

View File

@@ -316,7 +316,6 @@ public class ModelHyperParamEntity {
@Enumerated(EnumType.STRING)
private ModelType modelType;
@Column(name = "default_param")
private Boolean isDefault = false;
@@ -395,8 +394,7 @@ public class ModelHyperParamEntity {
// -------------------------
this.gpuCnt,
this.gpuIds,
this.masterPort
, this.isDefault
);
this.masterPort,
this.isDefault);
}
}

View File

@@ -24,4 +24,6 @@ public interface DatasetRepositoryCustom {
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
List<String> findDatasetUid(List<Long> datasetIds);
Long findDatasetByUidExistsCnt(String uid);
}

View File

@@ -247,4 +247,13 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
public List<String> findDatasetUid(List<Long> datasetIds) {
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
}
@Override
public Long findDatasetByUidExistsCnt(String uid) {
return queryFactory
.select(dataset.id.count())
.from(dataset)
.where(dataset.uid.eq(uid))
.fetchOne();
}
}

View File

@@ -185,10 +185,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
@Override
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
return Optional.ofNullable(
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType)))
.fetchOne());
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(
modelHyperParamEntity
.delYn
.isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType)))
.fetchOne());
}
}

View File

@@ -16,10 +16,10 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
@Override
public List<ModelDatasetMappEntity> findByModelUid(Long modelId) {
queryFactory
return queryFactory
.select(modelDatasetMappEntity)
.from(modelDatasetMappEntity)
.where(modelDatasetMappEntity.modelUid.eq(modelId));
return List.of();
.where(modelDatasetMappEntity.modelUid.eq(modelId))
.fetch();
}
}

View File

@@ -42,7 +42,10 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
ResponsePathDto.class,
modelMasterEntity.id,
modelMasterEntity.responsePath,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity.step2EndDttm.isNotNull(),

View File

@@ -29,7 +29,10 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
return queryFactory
.select(
Projections.constructor(
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
ResponsePathDto.class,
modelMasterEntity.id,
modelMasterEntity.responsePath,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity.step1EndDttm.isNotNull(),

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.train.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
@@ -17,5 +18,6 @@ public class ModelTrainMetricsDto {
private Long modelId;
private String responsePath;
private UUID uuid;
}
}

View File

@@ -169,7 +169,7 @@ public class DockerTrainService {
// 컨테이너 이름 지정
c.add("--name");
c.add(containerName + "-" + req.getUuid().substring(0, 8));
c.add(containerName);
// 실행 종료 시 자동 삭제
c.add("--rm");
@@ -206,7 +206,7 @@ public class DockerTrainService {
// 요청/결과 디렉토리 볼륨 마운트
c.add("-v");
c.add(requestDir + ":/data");
c.add(requestDir + "/tmp:/data");
c.add("-v");
c.add(responseDir + ":/checkpoints");
@@ -264,12 +264,12 @@ public class DockerTrainService {
// ===== Augmentation =====
addArg(c, "--rot-prob", req.getRotProb());
addArg(c, "--rot-degree", req.getRotDegree());
// addArg(c, "--rot-degree", req.getRotDegree()); // TODO AI 수정되면 주석 해제
addArg(c, "--flip-prob", req.getFlipProb());
addArg(c, "--exchange-prob", req.getExchangeProb());
addArg(c, "--brightness-delta", req.getBrightnessDelta());
addArg(c, "--contrast-range", req.getContrastRange());
addArg(c, "--saturation-range", req.getSaturationRange());
// addArg(c, "--contrast-range", req.getContrastRange()); // TODO AI 수정되면 주석 해제
// addArg(c, "--saturation-range", req.getSaturationRange()); // TODO AI 수정되면 주석 해제
addArg(c, "--hue-delta", req.getHueDelta());
addArg(c, "--resume-from", req.getResumeFrom());
@@ -377,7 +377,7 @@ public class DockerTrainService {
c.add("docker");
c.add("run");
c.add("--name");
c.add(containerName + "=" + req.getUuid().substring(0, 8));
c.add(containerName);
c.add("--rm");
c.add("--gpus");

View File

@@ -27,6 +27,10 @@ public class ModelTestMetricsJobService {
@Value("${spring.profiles.active}")
private String profile;
// 학습 결과가 저장될 호스트 디렉토리
@Value("${train.docker.responseDir}")
private String responseDir;
/**
* 실행중인 profile
*
@@ -51,7 +55,7 @@ public class ModelTestMetricsJobService {
for (ResponsePathDto modelInfo : modelIds) {
String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {

View File

@@ -55,7 +55,7 @@ public class ModelTrainMetricsJobService {
for (ResponsePathDto modelInfo : modelIds) {
String trainPath = responseDir + "{uuid}/metrics/train.csv"; // TODO
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
@@ -80,7 +80,7 @@ public class ModelTrainMetricsJobService {
throw new RuntimeException(e);
}
String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
try (BufferedReader reader =
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {

View File

@@ -226,9 +226,9 @@ public class TrainJobService {
try {
// 데이터셋 심볼링크 생성
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(path.toString());
updateReq.setRequestPath(pathUid);
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {
throw new RuntimeException(e);