학습 진행 현황 - 진행율, 진행중인 모델 API 추가

This commit is contained in:
2026-06-04 12:38:23 +09:00
parent b85f920f40
commit 5f4640ea60
6 changed files with 143 additions and 0 deletions

View File

@@ -236,4 +236,41 @@ public class ModelTrainMngApiController {
public ApiResponseDto<MonitorDto> getSystem() throws IOException {
return ApiResponseDto.ok(systemMonitorService.get());
}
@Operation(summary = "모델학습 1단계/2단계 실행중인 것 id 정보", description = "모델학습 1단계/2단계 실행중인 것 id 정보")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Long.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/ing-model")
public ApiResponseDto<ModelTrainMngDto.InProgressModel> findInprogressModel() {
return ApiResponseDto.ok(modelTrainMngService.findInprogressModel());
}
@Operation(summary = "모델학습 진행율 퍼센트", description = "모델학습 진행율 퍼센트")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = Long.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/progress-percent/{uuid}")
public ApiResponseDto<ModelTrainMngDto.ProgressPercent> findTrainProgressPercent(
@PathVariable UUID uuid) {
return ApiResponseDto.ok(modelTrainMngService.findTrainProgressPercent(uuid));
}
}

View File

@@ -378,4 +378,26 @@ public class ModelTrainMngDto {
// 삭제 될 파일
private List<String> deleteTargets;
}
@Getter
@Builder
@AllArgsConstructor
public static class InProgressModel {
private String modelNo;
private UUID uuid;
}
@Getter
@Builder
@AllArgsConstructor
public static class ProgressPercent {
private Long modelId;
private String jobType;
private String statusCd;
private Integer totalEpoch;
private Integer currentEpoch;
private Double percent;
}
}

View File

@@ -350,4 +350,12 @@ public class ModelTrainMngService {
public Long findModelStep1InProgressCnt() {
return modelTrainMngCoreService.findModelStep1InProgressCnt();
}
public ModelTrainMngDto.InProgressModel findInprogressModel() {
return modelTrainMngCoreService.findInprogressModel();
}
public ModelTrainMngDto.ProgressPercent findTrainProgressPercent(UUID uuid) {
return modelTrainMngCoreService.findTrainProgressPercent(uuid);
}
}

View File

@@ -713,4 +713,17 @@ public class ModelTrainMngCoreService {
entity.setTmpFileStatus("FAIL");
entity.setTmpFileErrMessage(message);
}
public ModelTrainMngDto.InProgressModel findInprogressModel() {
return modelMngRepository.findInprogressModel();
}
public ModelTrainMngDto.ProgressPercent findTrainProgressPercent(UUID uuid) {
ModelMasterEntity entity =
modelMngRepository
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return modelMngRepository.findTrainProgressPercent(entity.getId());
}
}

View File

@@ -47,4 +47,13 @@ public interface ModelMngRepositoryCustom {
* @return 모델 목록
*/
List<ModelMasterEntity> findByHyperParamId(Long hyperParamId);
/**
* 학습 진행중인 모델 type, uuid 조회
*
* @return
*/
ModelTrainMngDto.InProgressModel findInprogressModel();
ModelTrainMngDto.ProgressPercent findTrainProgressPercent(Long id);
}

View File

@@ -4,6 +4,7 @@ import static com.kamco.cd.training.postgres.entity.QMemberEntity.memberEntity;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
@@ -14,7 +15,9 @@ import com.kamco.cd.training.train.dto.TrainRunRequest;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.CaseBuilder;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.NumberExpression;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import java.util.Optional;
@@ -231,4 +234,55 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
.orderBy(modelMasterEntity.createdDttm.desc())
.fetch();
}
@Override
public ModelTrainMngDto.InProgressModel findInprogressModel() {
return queryFactory
.select(
Projections.constructor(
ModelTrainMngDto.InProgressModel.class,
modelMasterEntity.modelNo,
modelMasterEntity.uuid))
.from(modelMasterEntity)
.where(
modelMasterEntity
.step1State
.eq(TrainStatusType.IN_PROGRESS.getId())
.or(modelMasterEntity.step2State.eq(TrainStatusType.IN_PROGRESS.getId())))
.fetchOne();
}
@Override
public ModelTrainMngDto.ProgressPercent findTrainProgressPercent(Long id) {
NumberExpression<Integer> currentEpoch =
new CaseBuilder()
.when(
modelTrainJobEntity
.jobType
.eq("TEST")
.and(modelTrainJobEntity.statusCd.eq("SUCCESS")))
.then(1)
.otherwise(modelTrainJobEntity.currentEpoch.coalesce(0));
NumberExpression<Integer> totalEpoch = modelTrainJobEntity.totalEpoch.coalesce(1);
// per 계산
NumberExpression<Double> per =
currentEpoch.divide(totalEpoch).multiply(100).castToNum(Double.class);
return queryFactory
.select(
Projections.constructor(
ModelTrainMngDto.ProgressPercent.class,
modelTrainJobEntity.id,
modelTrainJobEntity.jobType,
modelTrainJobEntity.statusCd,
totalEpoch.as("totalEpoch"),
currentEpoch.as("currentEpoch"),
per.as("per")))
.from(modelTrainJobEntity)
.where(modelTrainJobEntity.id.eq(id))
.orderBy(modelTrainJobEntity.attemptNo.desc())
.fetchFirst();
}
}