하이퍼파라미터 , 모델관리 수정

This commit is contained in:
2026-02-03 18:24:49 +09:00
parent d66711e4f4
commit 6e99c209d6
18 changed files with 786 additions and 481 deletions

View File

@@ -1,10 +1,13 @@
package com.kamco.cd.training.model.dto;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.enums.TrainType;
import com.kamco.cd.training.common.utils.interfaces.JsonFormatDttm;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;
@@ -31,14 +34,72 @@ public class ModelMngDto {
private UUID uuid;
private String modelVer;
@JsonFormatDttm private ZonedDateTime startDttm;
@JsonFormatDttm private ZonedDateTime step1StrtDttm;
@JsonFormatDttm private ZonedDateTime step1EndDttm;
@JsonFormatDttm private ZonedDateTime step1Duration;
@JsonFormatDttm private ZonedDateTime step2StrtDttm;
@JsonFormatDttm private ZonedDateTime step2EndDttm;
@JsonFormatDttm private ZonedDateTime step2Duration;
private String step1Status;
private String step2Status;
private String transferStatus;
private String statusCd;
private String trainType;
public String getStatusName() {
if (this.statusCd == null || this.statusCd.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.statusCd).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.statusCd; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getStep1StatusName() {
if (this.step1Status == null || this.step1Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step1Status).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.step1Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getStep2StatusNAme() {
if (this.step2Status == null || this.step2Status.isBlank()) return null;
try {
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.step2Status; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
public String getTrainTypeName() {
if (this.trainType == null || this.trainType.isBlank()) return null;
try {
return TrainType.valueOf(this.trainType).getText(); // 또는 getName()
} catch (IllegalArgumentException e) {
return this.trainType; // 매핑 못하면 코드 그대로 반환(원하면 null 처리)
}
}
private String formatDuration(ZonedDateTime start, ZonedDateTime end) {
if (start == null || end == null) {
return null;
}
long totalSeconds = Math.abs(Duration.between(start, end).getSeconds());
long hours = totalSeconds / 3600;
long minutes = (totalSeconds % 3600) / 60;
long seconds = totalSeconds % 60;
return String.format("%d시간 %d분 %d초", hours, minutes, seconds);
}
public String getStep1Duration() {
return formatDuration(this.step1StrtDttm, this.step1EndDttm);
}
public String getStep2Duration() {
return formatDuration(this.step2StrtDttm, this.step2EndDttm);
}
}
@Schema(name = "searchReq", description = "모델 관리 목록조회 파라미터")

View File

@@ -29,7 +29,14 @@ public class ModelMngService {
return modelMngCoreService.findByModelList(searchReq);
}
public void deleteModelTrain(UUID uuid) {}
/**
* 학습모델 삭제
*
* @param uuid
*/
public void deleteModelTrain(UUID uuid) {
modelMngCoreService.deleteModel(uuid);
}
/**
* 모델 상세 조회
@@ -40,14 +47,4 @@ public class ModelMngService {
public ModelMngDto.Detail getModelDetail(Long modelUid) {
return modelMngCoreService.getModelDetail(modelUid);
}
/**
* 모델 상세 조회 (UUID 기반)
*
* @param uuid 모델 UUID
* @return 모델 상세 정보
*/
public ModelMngDto.Detail getModelDetailByUuid(String uuid) {
return modelMngCoreService.getModelDetailByUuid(uuid);
}
}

View File

@@ -4,10 +4,9 @@ import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.NotFoundException;
import com.kamco.cd.training.model.dto.ModelMngDto;
import com.kamco.cd.training.postgres.core.DatasetCoreService;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
import com.kamco.cd.training.postgres.core.ModelMngCoreService;
import com.kamco.cd.training.postgres.core.SystemMetricsCoreService;
import com.kamco.cd.training.postgres.entity.ModelTrainMasterEntity;
import com.kamco.cd.training.postgres.entity.ModelMasterEntity;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@@ -21,19 +20,9 @@ import org.springframework.transaction.annotation.Transactional;
public class ModelTrainService {
private final ModelMngCoreService modelMngCoreService;
private final HyperParamCoreService hyperParamCoreService;
private final DatasetCoreService datasetCoreService;
private final SystemMetricsCoreService systemMetricsCoreService;
/**
* 학습 모델 목록 조회
*
* @return 학습 모델 목록
*/
public List<ModelMngDto.TrainListRes> getTrainModelList() {
return modelMngCoreService.findAllTrainModels();
}
/**
* 학습 설정 통합 조회
*
@@ -99,7 +88,7 @@ public class ModelTrainService {
}
// 5. 학습 마스터 생성
ModelTrainMasterEntity entity = modelMngCoreService.createTrainMaster(trainReq);
ModelMasterEntity entity = modelMngCoreService.createTrainMaster(trainReq);
// 5. 데이터셋 매핑 생성
modelMngCoreService.createDatasetMappings(entity.getId(), trainReq.getDatasetIds());
@@ -178,39 +167,28 @@ public class ModelTrainService {
}
}
/**
* 학습 모델 삭제
*
* @param uuid 모델 UUID
*/
@Transactional
public void deleteTrainModel(String uuid) {
modelMngCoreService.deleteByUuid(uuid);
log.info("학습 모델 삭제 완료: uuid={}", uuid);
}
// ==================== Resume Training (학습 재시작) ====================
/**
* 학습 재시작 정보 조회
*
* @param uuid 모델 UUID
* @return 재시작 정보
*/
public ModelMngDto.ResumeInfo getResumeInfo(String uuid) {
ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
return ModelMngDto.ResumeInfo.builder()
.canResume(entity.getCanResume() != null && entity.getCanResume())
.lastEpoch(entity.getLastCheckpointEpoch())
.totalEpoch(entity.getEpochCnt())
.checkpointPath(entity.getCheckpointPath())
// .failedAt(
// entity.getStopDttm() != null
// ? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault())
// : null)
.build();
}
//
// /**
// * 학습 재시작 정보 조회
// *
// * @param uuid 모델 UUID
// * @return 재시작 정보
// */
// public ModelMngDto.ResumeInfo getResumeInfo(String uuid) {
// ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
//
// return ModelMngDto.ResumeInfo.builder()
// .canResume(entity.getCanResume() != null && entity.getCanResume())
// .lastEpoch(entity.getLastCheckpointEpoch())
// .totalEpoch(entity.getEpochCnt())
// .checkpointPath(entity.getCheckpointPath())
// // .failedAt(
// // entity.getStopDttm() != null
// // ? entity.getStopDttm().atZone(java.time.ZoneId.systemDefault())
// // : null)
// .build();
// }
/**
* 학습 재시작
@@ -222,41 +200,42 @@ public class ModelTrainService {
@Transactional
public ModelMngDto.ResumeResponse resumeTraining(
String uuid, ModelMngDto.ResumeRequest resumeReq) {
ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
// ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
//
// // 재시작 가능 여부 검증
// if (entity.getCanResume() == null || !entity.getCanResume()) {
// throw new IllegalStateException("학습 재시작이 불가능한 모델입니다: " + uuid);
// }
//
// if (entity.getLastCheckpointEpoch() == null) {
// throw new IllegalStateException("Checkpoint가 존재하지 않습니다: " + uuid);
// }
//
// // 상태 업데이트
// entity.setStatusCd("RUNNING");
// entity.setProgressRate(0);
//
// // 총 Epoch 수 변경 (선택사항)
// if (resumeReq.getNewTotalEpoch() != null) {
// entity.setEpochCnt(resumeReq.getNewTotalEpoch());
// }
//
// log.info(
// "학습 재시작: uuid={}, resumeFromEpoch={}, totalEpoch={}",
// uuid,
// resumeReq.getResumeFromEpoch(),
// entity.getEpochCnt());
//
// // TODO: 비동기 GPU 학습 재시작 프로세스 트리거 로직 추가
// // - Checkpoint 파일 로드
// // - 지정된 Epoch부터 학습 재개
// 재시작 가능 여부 검증
if (entity.getCanResume() == null || !entity.getCanResume()) {
throw new IllegalStateException("학습 재시작이 불가능한 모델입니다: " + uuid);
}
if (entity.getLastCheckpointEpoch() == null) {
throw new IllegalStateException("Checkpoint가 존재하지 않습니다: " + uuid);
}
// 상태 업데이트
entity.setStatusCd("RUNNING");
entity.setProgressRate(0);
// 총 Epoch 수 변경 (선택사항)
if (resumeReq.getNewTotalEpoch() != null) {
entity.setEpochCnt(resumeReq.getNewTotalEpoch());
}
log.info(
"학습 재시작: uuid={}, resumeFromEpoch={}, totalEpoch={}",
uuid,
resumeReq.getResumeFromEpoch(),
entity.getEpochCnt());
// TODO: 비동기 GPU 학습 재시작 프로세스 트리거 로직 추가
// - Checkpoint 파일 로드
// - 지정된 Epoch부터 학습 재개
return ModelMngDto.ResumeResponse.builder()
.uuid(uuid)
.status(entity.getStatusCd())
.resumedFromEpoch(resumeReq.getResumeFromEpoch())
.build();
return null;
// ModelMngDto.ResumeResponse.builder()
// .uuid(uuid)
// .status(entity.getStatusCd())
// .resumedFromEpoch(resumeReq.getResumeFromEpoch())
// .build();
}
// ==================== Best Epoch Setting (Best Epoch 설정) ====================
@@ -271,47 +250,49 @@ public class ModelTrainService {
@Transactional
public ModelMngDto.BestEpochResponse setBestEpoch(
String uuid, ModelMngDto.BestEpochRequest bestEpochReq) {
ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
// 1차 학습 완료 상태 검증
if (!"STEP1_COMPLETED".equals(entity.getStatusCd())
&& !"STEP1".equals(entity.getProcessStep())) {
log.warn(
"Best Epoch 설정 시도: 현재 상태={}, processStep={}",
entity.getStatusCd(),
entity.getProcessStep());
}
Integer previousBestEpoch = entity.getConfirmedBestEpoch();
// 사용자가 확정한 Best Epoch 설정
entity.setConfirmedBestEpoch(bestEpochReq.getBestEpoch());
// 2차 학습(Test) 단계로 상태 전이
entity.setProcessStep("STEP2");
entity.setStatusCd("STEP2_RUNNING");
entity.setProgressRate(0);
entity.setUpdatedDttm(java.time.ZonedDateTime.now());
log.info(
"Best Epoch 설정 및 2차 학습 시작: uuid={}, newBestEpoch={}, previousBestEpoch={}, reason={}, newStatus={}",
uuid,
bestEpochReq.getBestEpoch(),
previousBestEpoch,
bestEpochReq.getReason(),
entity.getStatusCd());
// ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
//
// // 1차 학습 완료 상태 검증
// if (!"STEP1_COMPLETED".equals(entity.getStatusCd())
// && !"STEP1".equals(entity.getProcessStep())) {
// log.warn(
// "Best Epoch 설정 시도: 현재 상태={}, processStep={}",
// entity.getStatusCd(),
// entity.getProcessStep());
// }
//
// Integer previousBestEpoch = entity.getConfirmedBestEpoch();
//
// // 사용자가 확정한 Best Epoch 설정
// entity.setConfirmedBestEpoch(bestEpochReq.getBestEpoch());
//
// // 2차 학습(Test) 단계로 상태 전이
// entity.setProcessStep("STEP2");
// entity.setStatusCd("STEP2_RUNNING");
// entity.setProgressRate(0);
// entity.setUpdatedDttm(java.time.ZonedDateTime.now());
//
// log.info(
// "Best Epoch 설정 및 2차 학습 시작: uuid={}, newBestEpoch={}, previousBestEpoch={}, reason={},
// newStatus={}",
// uuid,
// bestEpochReq.getBestEpoch(),
// previousBestEpoch,
// bestEpochReq.getReason(),
// entity.getStatusCd());
// TODO: 비동기 GPU 2차 학습(Test) 프로세스 트리거 로직 추가
// - Best Epoch 모델 로드
// - Test 데이터셋으로 성능 평가 실행
// - 완료 시 STEP2_COMPLETED 상태로 전환
return ModelMngDto.BestEpochResponse.builder()
.uuid(uuid)
.bestEpoch(entity.getBestEpoch()) // 자동 선택된 값
.confirmedBestEpoch(entity.getConfirmedBestEpoch()) // 사용자 확정
.previousBestEpoch(previousBestEpoch)
.build();
return null;
// ModelMngDto.BestEpochResponse.builder()
// .uuid(uuid)
// .bestEpoch(entity.getBestEpoch()) // 자동 선택된
// .confirmedBestEpoch(entity.getConfirmedBestEpoch()) // 사용자 확정 값
// .previousBestEpoch(previousBestEpoch)
// .build();
}
/**
@@ -321,33 +302,33 @@ public class ModelTrainService {
* @return Epoch별 성능 지표 목록
*/
public List<ModelMngDto.EpochMetric> getEpochMetrics(String uuid) {
ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
// ModelTrainMasterEntity entity = modelMngCoreService.findByUuid(uuid);
//
// // TODO: 실제 학습 로그 파일이나 DB에서 Epoch별 성능 지표 조회
// // 현재는 샘플 데이터 반환
// List<ModelMngDto.EpochMetric> metrics = new java.util.ArrayList<>();
//
// if (entity.getEpochCnt() != null && entity.getBestEpoch() != null) {
// // 샘플 데이터 생성 (실제로는 학습 로그 파일 파싱 또는 별도 테이블 조회)
// for (int i = 1; i <= Math.min(entity.getEpochCnt(), 10); i++) {
// int epoch = entity.getBestEpoch() - 5 + i;
// if (epoch <= 0 || epoch > entity.getEpochCnt()) {
// continue;
// }
//
// metrics.add(
// ModelMngDto.EpochMetric.builder()
// .epoch(epoch)
// .mIoU(0.80 + (Math.random() * 0.15)) // 샘플 데이터
// .mFscore(0.85 + (Math.random() * 0.10)) // 샘플 데이터
// .loss(0.3 - (Math.random() * 0.15)) // 샘플 데이터
// .isBest(entity.getBestEpoch() != null && epoch == entity.getBestEpoch())
// .build());
// }
// }
//
// log.info("Epoch별 성능 지표 조회: uuid={}, metricsCount={}", uuid, metrics.size());
// TODO: 실제 학습 로그 파일이나 DB에서 Epoch별 성능 지표 조회
// 현재는 샘플 데이터 반환
List<ModelMngDto.EpochMetric> metrics = new java.util.ArrayList<>();
if (entity.getEpochCnt() != null && entity.getBestEpoch() != null) {
// 샘플 데이터 생성 (실제로는 학습 로그 파일 파싱 또는 별도 테이블 조회)
for (int i = 1; i <= Math.min(entity.getEpochCnt(), 10); i++) {
int epoch = entity.getBestEpoch() - 5 + i;
if (epoch <= 0 || epoch > entity.getEpochCnt()) {
continue;
}
metrics.add(
ModelMngDto.EpochMetric.builder()
.epoch(epoch)
.mIoU(0.80 + (Math.random() * 0.15)) // 샘플 데이터
.mFscore(0.85 + (Math.random() * 0.10)) // 샘플 데이터
.loss(0.3 - (Math.random() * 0.15)) // 샘플 데이터
.isBest(entity.getBestEpoch() != null && epoch == entity.getBestEpoch())
.build());
}
}
log.info("Epoch별 성능 지표 조회: uuid={}, metricsCount={}", uuid, metrics.size());
return metrics;
return null; // metrics;
}
}