hyperparam_with_modeltype
This commit is contained in:
@@ -17,7 +17,10 @@ public enum ModelType implements EnumType {
|
|||||||
private String desc;
|
private String desc;
|
||||||
|
|
||||||
public static ModelType getValueData(String modelNo) {
|
public static ModelType getValueData(String modelNo) {
|
||||||
return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1);
|
return Arrays.stream(ModelType.values())
|
||||||
|
.filter(m -> m.getId().equals(modelNo))
|
||||||
|
.findFirst()
|
||||||
|
.orElse(G1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -182,10 +182,8 @@ public class HyperParamApiController {
|
|||||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
})
|
})
|
||||||
@GetMapping("/init/{model}")
|
@GetMapping("/init/{model}")
|
||||||
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(
|
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(@PathVariable ModelType model) {
|
||||||
@PathVariable ModelType model
|
|
||||||
|
|
||||||
) {
|
|
||||||
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
|
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,7 +79,8 @@ public class ModelTrainMngApiController {
|
|||||||
@DeleteMapping("/{uuid}")
|
@DeleteMapping("/{uuid}")
|
||||||
public ApiResponseDto<Void> deleteModelTrain(
|
public ApiResponseDto<Void> deleteModelTrain(
|
||||||
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
||||||
@PathVariable UUID uuid) {
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
modelTrainMngService.deleteModelTrain(uuid);
|
modelTrainMngService.deleteModelTrain(uuid);
|
||||||
return ApiResponseDto.ok(null);
|
return ApiResponseDto.ok(null);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ public class HyperParamCoreService {
|
|||||||
return entity.getHyperVer();
|
return entity.getHyperVer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 하이퍼파라미터 삭제
|
* 하이퍼파라미터 삭제
|
||||||
*
|
*
|
||||||
@@ -84,11 +83,11 @@ public class HyperParamCoreService {
|
|||||||
.findHyperParamByUuid(uuid)
|
.findHyperParamByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
// if (entity.getHyperVer().equals("HPs_0001")) {
|
// if (entity.getHyperVer().equals("HPs_0001")) {
|
||||||
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
//디폴트면 삭제불가
|
// 디폴트면 삭제불가
|
||||||
if (entity.getIsDefault()) {
|
if (entity.getIsDefault()) {
|
||||||
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||||
}
|
}
|
||||||
@@ -214,5 +213,4 @@ public class HyperParamCoreService {
|
|||||||
// memo
|
// memo
|
||||||
entity.setMemo(src.getMemo());
|
entity.setMemo(src.getMemo());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,11 +83,11 @@ public class ModelTrainMngCoreService {
|
|||||||
ModelMasterEntity entity = new ModelMasterEntity();
|
ModelMasterEntity entity = new ModelMasterEntity();
|
||||||
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
||||||
|
|
||||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||||
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
||||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
hyperParamEntity =
|
hyperParamEntity =
|
||||||
@@ -307,9 +307,7 @@ public class ModelTrainMngCoreService {
|
|||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
|
||||||
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markInProgress(Long modelId, Long jobId) {
|
public void markInProgress(Long modelId, Long jobId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -323,9 +321,7 @@ public class ModelTrainMngCoreService {
|
|||||||
// 필요하면 시작시간도 여기서 찍어줌
|
// 필요하면 시작시간도 여기서 찍어줌
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
|
||||||
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void clearLastError(Long modelId) {
|
public void clearLastError(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -336,9 +332,7 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setLastError(null);
|
master.setLastError(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
|
||||||
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStopped(Long modelId) {
|
public void markStopped(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -349,9 +343,7 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
|
||||||
* 완료 처리(옵션) - Worker가 성공 시 호출
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markCompleted(Long modelId) {
|
public void markCompleted(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -362,9 +354,7 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||||
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markError(Long modelId, String errorMessage) {
|
public void markError(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
@@ -379,9 +369,7 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setUpdatedDttm(ZonedDateTime.now());
|
master.setUpdatedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||||
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStep2Error(Long modelId, String errorMessage) {
|
public void markStep2Error(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
|
|||||||
@@ -316,7 +316,6 @@ public class ModelHyperParamEntity {
|
|||||||
@Enumerated(EnumType.STRING)
|
@Enumerated(EnumType.STRING)
|
||||||
private ModelType modelType;
|
private ModelType modelType;
|
||||||
|
|
||||||
|
|
||||||
@Column(name = "default_param")
|
@Column(name = "default_param")
|
||||||
private Boolean isDefault = false;
|
private Boolean isDefault = false;
|
||||||
|
|
||||||
@@ -395,8 +394,7 @@ public class ModelHyperParamEntity {
|
|||||||
// -------------------------
|
// -------------------------
|
||||||
this.gpuCnt,
|
this.gpuCnt,
|
||||||
this.gpuIds,
|
this.gpuIds,
|
||||||
this.masterPort
|
this.masterPort,
|
||||||
, this.isDefault
|
this.isDefault);
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -188,7 +188,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
queryFactory
|
queryFactory
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType)))
|
.where(
|
||||||
|
modelHyperParamEntity
|
||||||
|
.delYn
|
||||||
|
.isFalse()
|
||||||
|
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||||
.fetchOne());
|
.fetchOne());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user