hyperparam_with_modeltype

This commit is contained in:
2026-02-12 19:14:01 +09:00
parent 0bc4453c9c
commit d5b2b8ecec
6 changed files with 186 additions and 168 deletions

View File

@@ -88,9 +88,8 @@ public class HyperParamApiController {
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content), @ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content) @ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
}) })
@GetMapping("{model}/list") @GetMapping("/list")
public ApiResponseDto<Page<List>> getHyperParam( public ApiResponseDto<Page<List>> getHyperParam(
@PathVariable ModelType model,
@Parameter( @Parameter(
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)", description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
example = "CREATE_DATE") example = "CREATE_DATE")
@@ -102,7 +101,8 @@ public class HyperParamApiController {
LocalDate endDate, LocalDate endDate,
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false) @Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
String hyperVer, String hyperVer,
@Parameter( @Parameter(description = "버전명", example = "G1,G2,G3") @RequestParam(required = false) ModelType model
, @Parameter(
description = "정렬", description = "정렬",
example = "createdDttm desc", example = "createdDttm desc",
schema = schema =

View File

@@ -110,9 +110,11 @@ public class HyperParamDto {
@AllArgsConstructor @AllArgsConstructor
public static class List { public static class List {
private UUID uuid; private UUID uuid;
private ModelType model;
private String hyperVer; private String hyperVer;
@JsonFormatDttm private ZonedDateTime createDttm; @JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm; @JsonFormatDttm private ZonedDateTime lastUsedDttm;
private String memo;
private Long m1UseCnt; private Long m1UseCnt;
private Long m2UseCnt; private Long m2UseCnt;
private Long m3UseCnt; private Long m3UseCnt;

View File

@@ -106,7 +106,8 @@ public class HyperParamCoreService {
ModelHyperParamEntity entity = ModelHyperParamEntity entity =
hyperParamRepository hyperParamRepository
.getHyperparamByType(model) .getHyperparamByType(model)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND)); .stream().filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst().orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
return entity.toDto(); return entity.toDto();
} }

View File

@@ -86,7 +86,10 @@ public class ModelTrainMngCoreService {
// 최적화 파라미터는 모델 type의 디폴트사용 // 최적화 파라미터는 모델 type의 디폴트사용
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) { if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
ModelType modelType = ModelType.getValueData(addReq.getModelNo()); ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null); hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType)
.stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst().orElse(null);
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null); // hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
} else { } else {
@@ -307,7 +310,9 @@ public class ModelTrainMngCoreService {
return entity.toDto(); return entity.toDto();
} }
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */ /**
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
*/
@Transactional @Transactional
public void markInProgress(Long modelId, Long jobId) { public void markInProgress(Long modelId, Long jobId) {
ModelMasterEntity master = ModelMasterEntity master =
@@ -321,7 +326,9 @@ public class ModelTrainMngCoreService {
// 필요하면 시작시간도 여기서 찍어줌 // 필요하면 시작시간도 여기서 찍어줌
} }
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */ /**
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
*/
@Transactional @Transactional
public void clearLastError(Long modelId) { public void clearLastError(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
@@ -332,7 +339,9 @@ public class ModelTrainMngCoreService {
master.setLastError(null); master.setLastError(null);
} }
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */ /**
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
*/
@Transactional @Transactional
public void markStopped(Long modelId) { public void markStopped(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
@@ -343,7 +352,9 @@ public class ModelTrainMngCoreService {
master.setStatusCd(TrainStatusType.STOPPED.getId()); master.setStatusCd(TrainStatusType.STOPPED.getId());
} }
/** 완료 처리(옵션) - Worker가 성공 시 호출 */ /**
* 완료 처리(옵션) - Worker가 성공 시 호출
*/
@Transactional @Transactional
public void markCompleted(Long modelId) { public void markCompleted(Long modelId) {
ModelMasterEntity master = ModelMasterEntity master =
@@ -354,7 +365,9 @@ public class ModelTrainMngCoreService {
master.setStatusCd(TrainStatusType.COMPLETED.getId()); master.setStatusCd(TrainStatusType.COMPLETED.getId());
} }
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */ /**
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*/
@Transactional @Transactional
public void markError(Long modelId, String errorMessage) { public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =
@@ -369,7 +382,9 @@ public class ModelTrainMngCoreService {
master.setUpdatedDttm(ZonedDateTime.now()); master.setUpdatedDttm(ZonedDateTime.now());
} }
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */ /**
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*/
@Transactional @Transactional
public void markStep2Error(Long modelId, String errorMessage) { public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master = ModelMasterEntity master =

View File

@@ -4,6 +4,7 @@ import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto; import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq; import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity; import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
@@ -32,5 +33,5 @@ public interface HyperParamRepositoryCustom {
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req); Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType); List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
} }

View File

@@ -91,7 +91,9 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
Pageable pageable = req.toPageable(); Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
if (model != null) {
builder.and(modelHyperParamEntity.modelType.eq(model)); builder.and(modelHyperParamEntity.modelType.eq(model));
}
builder.and(modelHyperParamEntity.delYn.isFalse()); builder.and(modelHyperParamEntity.delYn.isFalse());
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) { if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
@@ -129,6 +131,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
Projections.constructor( Projections.constructor(
HyperParamDto.List.class, HyperParamDto.List.class,
modelHyperParamEntity.uuid, modelHyperParamEntity.uuid,
modelHyperParamEntity.modelType.as("model"),
modelHyperParamEntity.hyperVer, modelHyperParamEntity.hyperVer,
modelHyperParamEntity.createdDttm, modelHyperParamEntity.createdDttm,
modelHyperParamEntity.lastUsedDttm, modelHyperParamEntity.lastUsedDttm,
@@ -149,14 +152,12 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
boolean asc = sortOrder.isAscending(); boolean asc = sortOrder.isAscending();
switch (property) { switch (property) {
case "createdDttm" -> case "createdDttm" -> query.orderBy(
query.orderBy(
asc asc
? modelHyperParamEntity.createdDttm.asc() ? modelHyperParamEntity.createdDttm.asc()
: modelHyperParamEntity.createdDttm.desc()); : modelHyperParamEntity.createdDttm.desc());
case "lastUsedDttm" -> case "lastUsedDttm" -> query.orderBy(
query.orderBy(
asc asc
? modelHyperParamEntity.lastUsedDttm.asc() ? modelHyperParamEntity.lastUsedDttm.asc()
: modelHyperParamEntity.lastUsedDttm.desc()); : modelHyperParamEntity.lastUsedDttm.desc());
@@ -183,16 +184,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
} }
@Override @Override
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) { public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
return Optional.ofNullable( return queryFactory
queryFactory
.select(modelHyperParamEntity) .select(modelHyperParamEntity)
.from(modelHyperParamEntity) .from(modelHyperParamEntity)
.where( .where(
modelHyperParamEntity modelHyperParamEntity
.delYn .delYn.isFalse()
.isFalse()
.and(modelHyperParamEntity.modelType.eq(modelType))) .and(modelHyperParamEntity.modelType.eq(modelType)))
.fetchOne()); .fetch();
} }
} }