hyperparam_with_modeltype
This commit is contained in:
@@ -88,9 +88,8 @@ public class HyperParamApiController {
|
|||||||
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
||||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
})
|
})
|
||||||
@GetMapping("{model}/list")
|
@GetMapping("/list")
|
||||||
public ApiResponseDto<Page<List>> getHyperParam(
|
public ApiResponseDto<Page<List>> getHyperParam(
|
||||||
@PathVariable ModelType model,
|
|
||||||
@Parameter(
|
@Parameter(
|
||||||
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
|
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
|
||||||
example = "CREATE_DATE")
|
example = "CREATE_DATE")
|
||||||
@@ -102,7 +101,8 @@ public class HyperParamApiController {
|
|||||||
LocalDate endDate,
|
LocalDate endDate,
|
||||||
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
|
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
|
||||||
String hyperVer,
|
String hyperVer,
|
||||||
@Parameter(
|
@Parameter(description = "버전명", example = "G1,G2,G3") @RequestParam(required = false) ModelType model
|
||||||
|
, @Parameter(
|
||||||
description = "정렬",
|
description = "정렬",
|
||||||
example = "createdDttm desc",
|
example = "createdDttm desc",
|
||||||
schema =
|
schema =
|
||||||
|
|||||||
@@ -110,9 +110,11 @@ public class HyperParamDto {
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public static class List {
|
public static class List {
|
||||||
private UUID uuid;
|
private UUID uuid;
|
||||||
|
private ModelType model;
|
||||||
private String hyperVer;
|
private String hyperVer;
|
||||||
@JsonFormatDttm private ZonedDateTime createDttm;
|
@JsonFormatDttm private ZonedDateTime createDttm;
|
||||||
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
||||||
|
private String memo;
|
||||||
private Long m1UseCnt;
|
private Long m1UseCnt;
|
||||||
private Long m2UseCnt;
|
private Long m2UseCnt;
|
||||||
private Long m3UseCnt;
|
private Long m3UseCnt;
|
||||||
|
|||||||
@@ -106,7 +106,8 @@ public class HyperParamCoreService {
|
|||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.getHyperparamByType(model)
|
.getHyperparamByType(model)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.stream().filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||||
|
.findFirst().orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,10 @@ public class ModelTrainMngCoreService {
|
|||||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||||
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType)
|
||||||
|
.stream()
|
||||||
|
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||||
|
.findFirst().orElse(null);
|
||||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
@@ -307,7 +310,9 @@ public class ModelTrainMngCoreService {
|
|||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
|
/**
|
||||||
|
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markInProgress(Long modelId, Long jobId) {
|
public void markInProgress(Long modelId, Long jobId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -321,7 +326,9 @@ public class ModelTrainMngCoreService {
|
|||||||
// 필요하면 시작시간도 여기서 찍어줌
|
// 필요하면 시작시간도 여기서 찍어줌
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
|
/**
|
||||||
|
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void clearLastError(Long modelId) {
|
public void clearLastError(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -332,7 +339,9 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setLastError(null);
|
master.setLastError(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
|
/**
|
||||||
|
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStopped(Long modelId) {
|
public void markStopped(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -343,7 +352,9 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
|
/**
|
||||||
|
* 완료 처리(옵션) - Worker가 성공 시 호출
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markCompleted(Long modelId) {
|
public void markCompleted(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -354,7 +365,9 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
/**
|
||||||
|
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markError(Long modelId, String errorMessage) {
|
public void markError(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -369,7 +382,9 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setUpdatedDttm(ZonedDateTime.now());
|
master.setUpdatedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
/**
|
||||||
|
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStep2Error(Long modelId, String errorMessage) {
|
public void markStep2Error(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.kamco.cd.training.common.enums.ModelType;
|
|||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import org.springframework.data.domain.Page;
|
import org.springframework.data.domain.Page;
|
||||||
@@ -32,5 +33,5 @@ public interface HyperParamRepositoryCustom {
|
|||||||
|
|
||||||
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
||||||
|
|
||||||
Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -91,7 +91,9 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
Pageable pageable = req.toPageable();
|
Pageable pageable = req.toPageable();
|
||||||
|
|
||||||
BooleanBuilder builder = new BooleanBuilder();
|
BooleanBuilder builder = new BooleanBuilder();
|
||||||
|
if (model != null) {
|
||||||
builder.and(modelHyperParamEntity.modelType.eq(model));
|
builder.and(modelHyperParamEntity.modelType.eq(model));
|
||||||
|
}
|
||||||
builder.and(modelHyperParamEntity.delYn.isFalse());
|
builder.and(modelHyperParamEntity.delYn.isFalse());
|
||||||
|
|
||||||
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
||||||
@@ -129,6 +131,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
Projections.constructor(
|
Projections.constructor(
|
||||||
HyperParamDto.List.class,
|
HyperParamDto.List.class,
|
||||||
modelHyperParamEntity.uuid,
|
modelHyperParamEntity.uuid,
|
||||||
|
modelHyperParamEntity.modelType.as("model"),
|
||||||
modelHyperParamEntity.hyperVer,
|
modelHyperParamEntity.hyperVer,
|
||||||
modelHyperParamEntity.createdDttm,
|
modelHyperParamEntity.createdDttm,
|
||||||
modelHyperParamEntity.lastUsedDttm,
|
modelHyperParamEntity.lastUsedDttm,
|
||||||
@@ -149,14 +152,12 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
boolean asc = sortOrder.isAscending();
|
boolean asc = sortOrder.isAscending();
|
||||||
|
|
||||||
switch (property) {
|
switch (property) {
|
||||||
case "createdDttm" ->
|
case "createdDttm" -> query.orderBy(
|
||||||
query.orderBy(
|
|
||||||
asc
|
asc
|
||||||
? modelHyperParamEntity.createdDttm.asc()
|
? modelHyperParamEntity.createdDttm.asc()
|
||||||
: modelHyperParamEntity.createdDttm.desc());
|
: modelHyperParamEntity.createdDttm.desc());
|
||||||
|
|
||||||
case "lastUsedDttm" ->
|
case "lastUsedDttm" -> query.orderBy(
|
||||||
query.orderBy(
|
|
||||||
asc
|
asc
|
||||||
? modelHyperParamEntity.lastUsedDttm.asc()
|
? modelHyperParamEntity.lastUsedDttm.asc()
|
||||||
: modelHyperParamEntity.lastUsedDttm.desc());
|
: modelHyperParamEntity.lastUsedDttm.desc());
|
||||||
@@ -183,16 +184,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
||||||
return Optional.ofNullable(
|
return queryFactory
|
||||||
queryFactory
|
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(
|
.where(
|
||||||
modelHyperParamEntity
|
modelHyperParamEntity
|
||||||
.delYn
|
.delYn.isFalse()
|
||||||
.isFalse()
|
|
||||||
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||||
.fetchOne());
|
.fetch();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user