hyperparam_with_modeltype
This commit is contained in:
@@ -88,9 +88,8 @@ public class HyperParamApiController {
|
|||||||
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
||||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
})
|
})
|
||||||
@GetMapping("{model}/list")
|
@GetMapping("/list")
|
||||||
public ApiResponseDto<Page<List>> getHyperParam(
|
public ApiResponseDto<Page<List>> getHyperParam(
|
||||||
@PathVariable ModelType model,
|
|
||||||
@Parameter(
|
@Parameter(
|
||||||
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
|
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
|
||||||
example = "CREATE_DATE")
|
example = "CREATE_DATE")
|
||||||
@@ -102,7 +101,8 @@ public class HyperParamApiController {
|
|||||||
LocalDate endDate,
|
LocalDate endDate,
|
||||||
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
|
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
|
||||||
String hyperVer,
|
String hyperVer,
|
||||||
@Parameter(
|
@Parameter(description = "버전명", example = "G1,G2,G3") @RequestParam(required = false) ModelType model
|
||||||
|
, @Parameter(
|
||||||
description = "정렬",
|
description = "정렬",
|
||||||
example = "createdDttm desc",
|
example = "createdDttm desc",
|
||||||
schema =
|
schema =
|
||||||
|
|||||||
@@ -110,9 +110,11 @@ public class HyperParamDto {
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public static class List {
|
public static class List {
|
||||||
private UUID uuid;
|
private UUID uuid;
|
||||||
|
private ModelType model;
|
||||||
private String hyperVer;
|
private String hyperVer;
|
||||||
@JsonFormatDttm private ZonedDateTime createDttm;
|
@JsonFormatDttm private ZonedDateTime createDttm;
|
||||||
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
||||||
|
private String memo;
|
||||||
private Long m1UseCnt;
|
private Long m1UseCnt;
|
||||||
private Long m2UseCnt;
|
private Long m2UseCnt;
|
||||||
private Long m3UseCnt;
|
private Long m3UseCnt;
|
||||||
|
|||||||
@@ -50,15 +50,15 @@ public class HyperParamCoreService {
|
|||||||
/**
|
/**
|
||||||
* 하이퍼파라미터 수정
|
* 하이퍼파라미터 수정
|
||||||
*
|
*
|
||||||
* @param uuid uuid
|
* @param uuid uuid
|
||||||
* @param createReq 등록 요청
|
* @param createReq 등록 요청
|
||||||
* @return ver
|
* @return ver
|
||||||
*/
|
*/
|
||||||
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.findHyperParamByUuid(uuid)
|
.findHyperParamByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
if (entity.getIsDefault()) {
|
if (entity.getIsDefault()) {
|
||||||
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
|
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||||
@@ -79,9 +79,9 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public void deleteHyperParam(UUID uuid) {
|
public void deleteHyperParam(UUID uuid) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.findHyperParamByUuid(uuid)
|
.findHyperParamByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
// if (entity.getHyperVer().equals("HPs_0001")) {
|
// if (entity.getHyperVer().equals("HPs_0001")) {
|
||||||
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||||
@@ -104,9 +104,10 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.getHyperparamByType(model)
|
.getHyperparamByType(model)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.stream().filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||||
|
.findFirst().orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,9 +118,9 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public HyperParamDto.Basic getHyperParam(UUID uuid) {
|
public HyperParamDto.Basic getHyperParam(UUID uuid) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.findHyperParamByUuid(uuid)
|
.findHyperParamByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,16 +143,16 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public String getFirstHyperParamVersion(ModelType model) {
|
public String getFirstHyperParamVersion(ModelType model) {
|
||||||
return hyperParamRepository
|
return hyperParamRepository
|
||||||
.findHyperParamVerByModelType(model)
|
.findHyperParamVerByModelType(model)
|
||||||
.map(ModelHyperParamEntity::getHyperVer)
|
.map(ModelHyperParamEntity::getHyperVer)
|
||||||
.map(ver -> increase(ver, model))
|
.map(ver -> increase(ver, model))
|
||||||
.orElse(model.name() + "_000001");
|
.orElse(model.name() + "_000001");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 하이퍼 파라미터의 버전을 증가시킨다.
|
* 하이퍼 파라미터의 버전을 증가시킨다.
|
||||||
*
|
*
|
||||||
* @param hyperVer 현재 버전
|
* @param hyperVer 현재 버전
|
||||||
* @param modelType 모델 타입
|
* @param modelType 모델 타입
|
||||||
* @return 증가된 버전
|
* @return 증가된 버전
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -65,9 +65,9 @@ public class ModelTrainMngCoreService {
|
|||||||
*/
|
*/
|
||||||
public void deleteModel(UUID uuid) {
|
public void deleteModel(UUID uuid) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findByUuid(uuid)
|
.findByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
entity.setDelYn(true);
|
entity.setDelYn(true);
|
||||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||||
entity.setUpdatedUid(userUtil.getId());
|
entity.setUpdatedUid(userUtil.getId());
|
||||||
@@ -86,12 +86,15 @@ public class ModelTrainMngCoreService {
|
|||||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||||
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType)
|
||||||
|
.stream()
|
||||||
|
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||||
|
.findFirst().orElse(null);
|
||||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
hyperParamEntity =
|
hyperParamEntity =
|
||||||
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
||||||
@@ -99,8 +102,8 @@ public class ModelTrainMngCoreService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
String modelVer =
|
String modelVer =
|
||||||
String.join(
|
String.join(
|
||||||
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
|
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
|
||||||
entity.setModelVer(modelVer);
|
entity.setModelVer(modelVer);
|
||||||
entity.setHyperParamId(hyperParamEntity.getId());
|
entity.setHyperParamId(hyperParamEntity.getId());
|
||||||
entity.setModelNo(addReq.getModelNo());
|
entity.setModelNo(addReq.getModelNo());
|
||||||
@@ -132,7 +135,7 @@ public class ModelTrainMngCoreService {
|
|||||||
* data set 저장
|
* data set 저장
|
||||||
*
|
*
|
||||||
* @param modelId 저장한 모델 학습 id
|
* @param modelId 저장한 모델 학습 id
|
||||||
* @param addReq 요청 파라미터
|
* @param addReq 요청 파라미터
|
||||||
*/
|
*/
|
||||||
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
|
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
|
||||||
TrainingDataset dataset = addReq.getTrainingDataset();
|
TrainingDataset dataset = addReq.getTrainingDataset();
|
||||||
@@ -165,9 +168,9 @@ public class ModelTrainMngCoreService {
|
|||||||
*/
|
*/
|
||||||
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
|
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
|
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
|
||||||
entity.setRequestPath(req.getRequestPath());
|
entity.setRequestPath(req.getRequestPath());
|
||||||
}
|
}
|
||||||
@@ -180,7 +183,7 @@ public class ModelTrainMngCoreService {
|
|||||||
/**
|
/**
|
||||||
* 모델 데이터셋 mapping 테이블 저장
|
* 모델 데이터셋 mapping 테이블 저장
|
||||||
*
|
*
|
||||||
* @param modelId 모델학습 id
|
* @param modelId 모델학습 id
|
||||||
* @param datasetList 선택한 data set
|
* @param datasetList 선택한 data set
|
||||||
*/
|
*/
|
||||||
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
|
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
|
||||||
@@ -197,7 +200,7 @@ public class ModelTrainMngCoreService {
|
|||||||
* 모델학습 config 저장
|
* 모델학습 config 저장
|
||||||
*
|
*
|
||||||
* @param modelId 모델학습 id
|
* @param modelId 모델학습 id
|
||||||
* @param req 요청 파라미터
|
* @param req 요청 파라미터
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
|
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
|
||||||
@@ -217,7 +220,7 @@ public class ModelTrainMngCoreService {
|
|||||||
/**
|
/**
|
||||||
* 데이터셋 매핑 생성
|
* 데이터셋 매핑 생성
|
||||||
*
|
*
|
||||||
* @param modelUid 모델 UID
|
* @param modelUid 모델 UID
|
||||||
* @param datasetIds 데이터셋 ID 목록
|
* @param datasetIds 데이터셋 ID 목록
|
||||||
*/
|
*/
|
||||||
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
|
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
|
||||||
@@ -239,8 +242,8 @@ public class ModelTrainMngCoreService {
|
|||||||
public ModelMasterEntity findByUuid(UUID uuid) {
|
public ModelMasterEntity findByUuid(UUID uuid) {
|
||||||
try {
|
try {
|
||||||
return modelMngRepository
|
return modelMngRepository
|
||||||
.findByUuid(uuid)
|
.findByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
|
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
|
||||||
}
|
}
|
||||||
@@ -254,9 +257,9 @@ public class ModelTrainMngCoreService {
|
|||||||
*/
|
*/
|
||||||
public Long findModelIdByUuid(UUID uuid) {
|
public Long findModelIdByUuid(UUID uuid) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findByUuid(uuid)
|
.findByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
return entity.getId();
|
return entity.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,8 +272,8 @@ public class ModelTrainMngCoreService {
|
|||||||
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
|
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
|
||||||
ModelMasterEntity modelEntity = findByUuid(uuid);
|
ModelMasterEntity modelEntity = findByUuid(uuid);
|
||||||
return modelConfigRepository
|
return modelConfigRepository
|
||||||
.findModelConfigByModelId(modelEntity.getId())
|
.findModelConfigByModelId(modelEntity.getId())
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -301,19 +304,21 @@ public class ModelTrainMngCoreService {
|
|||||||
*/
|
*/
|
||||||
public ModelTrainMngDto.Basic findModelById(Long id) {
|
public ModelTrainMngDto.Basic findModelById(Long id) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(id)
|
.findById(id)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
|
/**
|
||||||
|
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markInProgress(Long modelId, Long jobId) {
|
public void markInProgress(Long modelId, Long jobId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
master.setCurrentAttemptId(jobId);
|
master.setCurrentAttemptId(jobId);
|
||||||
@@ -321,46 +326,54 @@ public class ModelTrainMngCoreService {
|
|||||||
// 필요하면 시작시간도 여기서 찍어줌
|
// 필요하면 시작시간도 여기서 찍어줌
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
|
/**
|
||||||
|
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void clearLastError(Long modelId) {
|
public void clearLastError(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setLastError(null);
|
master.setLastError(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
|
/**
|
||||||
|
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStopped(Long modelId) {
|
public void markStopped(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
|
/**
|
||||||
|
* 완료 처리(옵션) - Worker가 성공 시 호출
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markCompleted(Long modelId) {
|
public void markCompleted(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
/**
|
||||||
|
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markError(Long modelId, String errorMessage) {
|
public void markError(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||||
master.setStep1State(TrainStatusType.ERROR.getId());
|
master.setStep1State(TrainStatusType.ERROR.getId());
|
||||||
@@ -369,13 +382,15 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setUpdatedDttm(ZonedDateTime.now());
|
master.setUpdatedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
/**
|
||||||
|
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||||
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStep2Error(Long modelId, String errorMessage) {
|
public void markStep2Error(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||||
master.setStep2State(TrainStatusType.ERROR.getId());
|
master.setStep2State(TrainStatusType.ERROR.getId());
|
||||||
@@ -387,9 +402,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional
|
@Transactional
|
||||||
public void markSuccess(Long modelId) {
|
public void markSuccess(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
// 모델 상태 완료 처리
|
// 모델 상태 완료 처리
|
||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
@@ -417,9 +432,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||||
public void markStep1InProgress(Long modelId, Long jobId) {
|
public void markStep1InProgress(Long modelId, Long jobId) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
entity.setStep1StrtDttm(ZonedDateTime.now());
|
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||||
@@ -437,9 +452,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||||
public void markStep2InProgress(Long modelId, Long jobId) {
|
public void markStep2InProgress(Long modelId, Long jobId) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
entity.setStep2StrtDttm(ZonedDateTime.now());
|
entity.setStep2StrtDttm(ZonedDateTime.now());
|
||||||
@@ -457,9 +472,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||||
public void markStep1Success(Long modelId) {
|
public void markStep1Success(Long modelId) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
||||||
@@ -476,9 +491,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||||
public void markStep2Success(Long modelId) {
|
public void markStep2Success(Long modelId) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
||||||
@@ -489,9 +504,9 @@ public class ModelTrainMngCoreService {
|
|||||||
|
|
||||||
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
|
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setBestEpoch(epoch);
|
entity.setBestEpoch(epoch);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.kamco.cd.training.common.enums.ModelType;
|
|||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
||||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import org.springframework.data.domain.Page;
|
import org.springframework.data.domain.Page;
|
||||||
@@ -32,5 +33,5 @@ public interface HyperParamRepositoryCustom {
|
|||||||
|
|
||||||
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
||||||
|
|
||||||
Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,56 +34,56 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
public Optional<ModelHyperParamEntity> findHyperParamVer() {
|
public Optional<ModelHyperParamEntity> findHyperParamVer() {
|
||||||
|
|
||||||
return Optional.ofNullable(
|
return Optional.ofNullable(
|
||||||
queryFactory
|
queryFactory
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(modelHyperParamEntity.delYn.isFalse())
|
.where(modelHyperParamEntity.delYn.isFalse())
|
||||||
.orderBy(modelHyperParamEntity.hyperVer.desc())
|
.orderBy(modelHyperParamEntity.hyperVer.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.fetchOne());
|
.fetchOne());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType) {
|
public Optional<ModelHyperParamEntity> findHyperParamVerByModelType(ModelType modelType) {
|
||||||
|
|
||||||
return Optional.ofNullable(
|
return Optional.ofNullable(
|
||||||
queryFactory
|
queryFactory
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(
|
.where(
|
||||||
modelHyperParamEntity
|
modelHyperParamEntity
|
||||||
.delYn
|
.delYn
|
||||||
.isFalse()
|
.isFalse()
|
||||||
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||||
.orderBy(modelHyperParamEntity.hyperVer.desc())
|
.orderBy(modelHyperParamEntity.hyperVer.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.fetchOne());
|
.fetchOne());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) {
|
public Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer) {
|
||||||
|
|
||||||
return Optional.ofNullable(
|
return Optional.ofNullable(
|
||||||
queryFactory
|
queryFactory
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(
|
.where(
|
||||||
modelHyperParamEntity
|
modelHyperParamEntity
|
||||||
.delYn
|
.delYn
|
||||||
.isFalse()
|
.isFalse()
|
||||||
.and(modelHyperParamEntity.hyperVer.eq(hyperVer)))
|
.and(modelHyperParamEntity.hyperVer.eq(hyperVer)))
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.fetchOne());
|
.fetchOne());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid) {
|
public Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid) {
|
||||||
return Optional.ofNullable(
|
return Optional.ofNullable(
|
||||||
queryFactory
|
queryFactory
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid)))
|
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid)))
|
||||||
.fetchOne());
|
.fetchOne());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -91,7 +91,9 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
Pageable pageable = req.toPageable();
|
Pageable pageable = req.toPageable();
|
||||||
|
|
||||||
BooleanBuilder builder = new BooleanBuilder();
|
BooleanBuilder builder = new BooleanBuilder();
|
||||||
builder.and(modelHyperParamEntity.modelType.eq(model));
|
if (model != null) {
|
||||||
|
builder.and(modelHyperParamEntity.modelType.eq(model));
|
||||||
|
}
|
||||||
builder.and(modelHyperParamEntity.delYn.isFalse());
|
builder.and(modelHyperParamEntity.delYn.isFalse());
|
||||||
|
|
||||||
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
||||||
@@ -117,27 +119,28 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
}
|
}
|
||||||
|
|
||||||
NumberExpression<Long> totalUseCnt =
|
NumberExpression<Long> totalUseCnt =
|
||||||
modelHyperParamEntity
|
modelHyperParamEntity
|
||||||
.m1UseCnt
|
.m1UseCnt
|
||||||
.coalesce(0L)
|
.coalesce(0L)
|
||||||
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
|
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
|
||||||
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
|
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
|
||||||
|
|
||||||
JPAQuery<HyperParamDto.List> query =
|
JPAQuery<HyperParamDto.List> query =
|
||||||
queryFactory
|
queryFactory
|
||||||
.select(
|
.select(
|
||||||
Projections.constructor(
|
Projections.constructor(
|
||||||
HyperParamDto.List.class,
|
HyperParamDto.List.class,
|
||||||
modelHyperParamEntity.uuid,
|
modelHyperParamEntity.uuid,
|
||||||
modelHyperParamEntity.hyperVer,
|
modelHyperParamEntity.modelType.as("model"),
|
||||||
modelHyperParamEntity.createdDttm,
|
modelHyperParamEntity.hyperVer,
|
||||||
modelHyperParamEntity.lastUsedDttm,
|
modelHyperParamEntity.createdDttm,
|
||||||
modelHyperParamEntity.m1UseCnt,
|
modelHyperParamEntity.lastUsedDttm,
|
||||||
modelHyperParamEntity.m2UseCnt,
|
modelHyperParamEntity.m1UseCnt,
|
||||||
modelHyperParamEntity.m3UseCnt,
|
modelHyperParamEntity.m2UseCnt,
|
||||||
totalUseCnt.as("totalUseCnt")))
|
modelHyperParamEntity.m3UseCnt,
|
||||||
.from(modelHyperParamEntity)
|
totalUseCnt.as("totalUseCnt")))
|
||||||
.where(builder);
|
.from(modelHyperParamEntity)
|
||||||
|
.where(builder);
|
||||||
|
|
||||||
Sort.Order sortOrder = pageable.getSort().stream().findFirst().orElse(null);
|
Sort.Order sortOrder = pageable.getSort().stream().findFirst().orElse(null);
|
||||||
|
|
||||||
@@ -149,17 +152,15 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
boolean asc = sortOrder.isAscending();
|
boolean asc = sortOrder.isAscending();
|
||||||
|
|
||||||
switch (property) {
|
switch (property) {
|
||||||
case "createdDttm" ->
|
case "createdDttm" -> query.orderBy(
|
||||||
query.orderBy(
|
asc
|
||||||
asc
|
? modelHyperParamEntity.createdDttm.asc()
|
||||||
? modelHyperParamEntity.createdDttm.asc()
|
: modelHyperParamEntity.createdDttm.desc());
|
||||||
: modelHyperParamEntity.createdDttm.desc());
|
|
||||||
|
|
||||||
case "lastUsedDttm" ->
|
case "lastUsedDttm" -> query.orderBy(
|
||||||
query.orderBy(
|
asc
|
||||||
asc
|
? modelHyperParamEntity.lastUsedDttm.asc()
|
||||||
? modelHyperParamEntity.lastUsedDttm.asc()
|
: modelHyperParamEntity.lastUsedDttm.desc());
|
||||||
: modelHyperParamEntity.lastUsedDttm.desc());
|
|
||||||
|
|
||||||
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
|
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
|
||||||
|
|
||||||
@@ -168,14 +169,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
}
|
}
|
||||||
|
|
||||||
List<HyperParamDto.List> content =
|
List<HyperParamDto.List> content =
|
||||||
query.offset(pageable.getOffset()).limit(pageable.getPageSize()).fetch();
|
query.offset(pageable.getOffset()).limit(pageable.getPageSize()).fetch();
|
||||||
|
|
||||||
Long total =
|
Long total =
|
||||||
queryFactory
|
queryFactory
|
||||||
.select(modelHyperParamEntity.count())
|
.select(modelHyperParamEntity.count())
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(builder)
|
.where(builder)
|
||||||
.fetchOne();
|
.fetchOne();
|
||||||
|
|
||||||
long totalCount = (total != null) ? total : 0L;
|
long totalCount = (total != null) ? total : 0L;
|
||||||
|
|
||||||
@@ -183,16 +184,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
||||||
return Optional.ofNullable(
|
return queryFactory
|
||||||
queryFactory
|
.select(modelHyperParamEntity)
|
||||||
.select(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.where(
|
||||||
.where(
|
modelHyperParamEntity
|
||||||
modelHyperParamEntity
|
.delYn.isFalse()
|
||||||
.delYn
|
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||||
.isFalse()
|
.fetch();
|
||||||
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
|
||||||
.fetchOne());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user