학습 진행 현황 - 진행율, 진행중인 모델 API 추가
This commit is contained in:
@@ -236,4 +236,41 @@ public class ModelTrainMngApiController {
|
||||
public ApiResponseDto<MonitorDto> getSystem() throws IOException {
|
||||
return ApiResponseDto.ok(systemMonitorService.get());
|
||||
}
|
||||
|
||||
@Operation(summary = "모델학습 1단계/2단계 실행중인 것 id 정보", description = "모델학습 1단계/2단계 실행중인 것 id 정보")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "검색 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = Long.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/ing-model")
|
||||
public ApiResponseDto<ModelTrainMngDto.InProgressModel> findInprogressModel() {
|
||||
return ApiResponseDto.ok(modelTrainMngService.findInprogressModel());
|
||||
}
|
||||
|
||||
@Operation(summary = "모델학습 진행율 퍼센트", description = "모델학습 진행율 퍼센트")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "검색 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = Long.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/progress-percent/{uuid}")
|
||||
public ApiResponseDto<ModelTrainMngDto.ProgressPercent> findTrainProgressPercent(
|
||||
@PathVariable UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainMngService.findTrainProgressPercent(uuid));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -378,4 +378,26 @@ public class ModelTrainMngDto {
|
||||
// 삭제 될 파일
|
||||
private List<String> deleteTargets;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
public static class InProgressModel {
|
||||
|
||||
private String modelNo;
|
||||
private UUID uuid;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
public static class ProgressPercent {
|
||||
|
||||
private Long modelId;
|
||||
private String jobType;
|
||||
private String statusCd;
|
||||
private Integer totalEpoch;
|
||||
private Integer currentEpoch;
|
||||
private Double percent;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -350,4 +350,12 @@ public class ModelTrainMngService {
|
||||
public Long findModelStep1InProgressCnt() {
|
||||
return modelTrainMngCoreService.findModelStep1InProgressCnt();
|
||||
}
|
||||
|
||||
public ModelTrainMngDto.InProgressModel findInprogressModel() {
|
||||
return modelTrainMngCoreService.findInprogressModel();
|
||||
}
|
||||
|
||||
public ModelTrainMngDto.ProgressPercent findTrainProgressPercent(UUID uuid) {
|
||||
return modelTrainMngCoreService.findTrainProgressPercent(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -713,4 +713,17 @@ public class ModelTrainMngCoreService {
|
||||
entity.setTmpFileStatus("FAIL");
|
||||
entity.setTmpFileErrMessage(message);
|
||||
}
|
||||
|
||||
public ModelTrainMngDto.InProgressModel findInprogressModel() {
|
||||
return modelMngRepository.findInprogressModel();
|
||||
}
|
||||
|
||||
public ModelTrainMngDto.ProgressPercent findTrainProgressPercent(UUID uuid) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
return modelMngRepository.findTrainProgressPercent(entity.getId());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,4 +47,13 @@ public interface ModelMngRepositoryCustom {
|
||||
* @return 모델 목록
|
||||
*/
|
||||
List<ModelMasterEntity> findByHyperParamId(Long hyperParamId);
|
||||
|
||||
/**
|
||||
* 학습 진행중인 모델 type, uuid 조회
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
ModelTrainMngDto.InProgressModel findInprogressModel();
|
||||
|
||||
ModelTrainMngDto.ProgressPercent findTrainProgressPercent(Long id);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import static com.kamco.cd.training.postgres.entity.QMemberEntity.memberEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelHyperParamEntity.modelHyperParamEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
@@ -14,7 +15,9 @@ import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.querydsl.core.BooleanBuilder;
|
||||
import com.querydsl.core.types.Expression;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.core.types.dsl.CaseBuilder;
|
||||
import com.querydsl.core.types.dsl.Expressions;
|
||||
import com.querydsl.core.types.dsl.NumberExpression;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -231,4 +234,55 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
.orderBy(modelMasterEntity.createdDttm.desc())
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelTrainMngDto.InProgressModel findInprogressModel() {
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelTrainMngDto.InProgressModel.class,
|
||||
modelMasterEntity.modelNo,
|
||||
modelMasterEntity.uuid))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity
|
||||
.step1State
|
||||
.eq(TrainStatusType.IN_PROGRESS.getId())
|
||||
.or(modelMasterEntity.step2State.eq(TrainStatusType.IN_PROGRESS.getId())))
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelTrainMngDto.ProgressPercent findTrainProgressPercent(Long id) {
|
||||
NumberExpression<Integer> currentEpoch =
|
||||
new CaseBuilder()
|
||||
.when(
|
||||
modelTrainJobEntity
|
||||
.jobType
|
||||
.eq("TEST")
|
||||
.and(modelTrainJobEntity.statusCd.eq("SUCCESS")))
|
||||
.then(1)
|
||||
.otherwise(modelTrainJobEntity.currentEpoch.coalesce(0));
|
||||
|
||||
NumberExpression<Integer> totalEpoch = modelTrainJobEntity.totalEpoch.coalesce(1);
|
||||
|
||||
// per 계산
|
||||
NumberExpression<Double> per =
|
||||
currentEpoch.divide(totalEpoch).multiply(100).castToNum(Double.class);
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelTrainMngDto.ProgressPercent.class,
|
||||
modelTrainJobEntity.id,
|
||||
modelTrainJobEntity.jobType,
|
||||
modelTrainJobEntity.statusCd,
|
||||
totalEpoch.as("totalEpoch"),
|
||||
currentEpoch.as("currentEpoch"),
|
||||
per.as("per")))
|
||||
.from(modelTrainJobEntity)
|
||||
.where(modelTrainJobEntity.id.eq(id))
|
||||
.orderBy(modelTrainJobEntity.attemptNo.desc())
|
||||
.fetchFirst();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user