hyperparam_with_modeltype
This commit is contained in:
@@ -17,7 +17,10 @@ public enum ModelType implements EnumType {
|
||||
private String desc;
|
||||
|
||||
public static ModelType getValueData(String modelNo) {
|
||||
return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1);
|
||||
return Arrays.stream(ModelType.values())
|
||||
.filter(m -> m.getId().equals(modelNo))
|
||||
.findFirst()
|
||||
.orElse(G1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -182,10 +182,8 @@ public class HyperParamApiController {
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/init/{model}")
|
||||
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(
|
||||
@PathVariable ModelType model
|
||||
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(@PathVariable ModelType model) {
|
||||
|
||||
) {
|
||||
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,7 +79,8 @@ public class ModelTrainMngApiController {
|
||||
@DeleteMapping("/{uuid}")
|
||||
public ApiResponseDto<Void> deleteModelTrain(
|
||||
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
||||
@PathVariable UUID uuid) {
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
modelTrainMngService.deleteModelTrain(uuid);
|
||||
return ApiResponseDto.ok(null);
|
||||
}
|
||||
|
||||
@@ -72,7 +72,6 @@ public class HyperParamCoreService {
|
||||
return entity.getHyperVer();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 삭제
|
||||
*
|
||||
@@ -84,11 +83,11 @@ public class HyperParamCoreService {
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
// if (entity.getHyperVer().equals("HPs_0001")) {
|
||||
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
// }
|
||||
// if (entity.getHyperVer().equals("HPs_0001")) {
|
||||
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
// }
|
||||
|
||||
//디폴트면 삭제불가
|
||||
// 디폴트면 삭제불가
|
||||
if (entity.getIsDefault()) {
|
||||
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
@@ -214,5 +213,4 @@ public class HyperParamCoreService {
|
||||
// memo
|
||||
entity.setMemo(src.getMemo());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -83,11 +83,11 @@ public class ModelTrainMngCoreService {
|
||||
ModelMasterEntity entity = new ModelMasterEntity();
|
||||
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
||||
|
||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||
|
||||
} else {
|
||||
hyperParamEntity =
|
||||
@@ -307,9 +307,7 @@ public class ModelTrainMngCoreService {
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/**
|
||||
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
|
||||
*/
|
||||
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
|
||||
@Transactional
|
||||
public void markInProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity master =
|
||||
@@ -323,9 +321,7 @@ public class ModelTrainMngCoreService {
|
||||
// 필요하면 시작시간도 여기서 찍어줌
|
||||
}
|
||||
|
||||
/**
|
||||
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
|
||||
*/
|
||||
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
|
||||
@Transactional
|
||||
public void clearLastError(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
@@ -336,9 +332,7 @@ public class ModelTrainMngCoreService {
|
||||
master.setLastError(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
|
||||
*/
|
||||
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
|
||||
@Transactional
|
||||
public void markStopped(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
@@ -349,9 +343,7 @@ public class ModelTrainMngCoreService {
|
||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* 완료 처리(옵션) - Worker가 성공 시 호출
|
||||
*/
|
||||
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
|
||||
@Transactional
|
||||
public void markCompleted(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
@@ -362,9 +354,7 @@ public class ModelTrainMngCoreService {
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*/
|
||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
@@ -379,9 +369,7 @@ public class ModelTrainMngCoreService {
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/**
|
||||
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*/
|
||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markStep2Error(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
|
||||
@@ -316,7 +316,6 @@ public class ModelHyperParamEntity {
|
||||
@Enumerated(EnumType.STRING)
|
||||
private ModelType modelType;
|
||||
|
||||
|
||||
@Column(name = "default_param")
|
||||
private Boolean isDefault = false;
|
||||
|
||||
@@ -395,8 +394,7 @@ public class ModelHyperParamEntity {
|
||||
// -------------------------
|
||||
this.gpuCnt,
|
||||
this.gpuIds,
|
||||
this.masterPort
|
||||
, this.isDefault
|
||||
);
|
||||
this.masterPort,
|
||||
this.isDefault);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +188,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
queryFactory
|
||||
.select(modelHyperParamEntity)
|
||||
.from(modelHyperParamEntity)
|
||||
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||
.where(
|
||||
modelHyperParamEntity
|
||||
.delYn
|
||||
.isFalse()
|
||||
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||
.fetchOne());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user