hyperparam_with_modeltype
This commit is contained in:
@@ -17,7 +17,10 @@ public enum ModelType implements EnumType {
|
|||||||
private String desc;
|
private String desc;
|
||||||
|
|
||||||
public static ModelType getValueData(String modelNo) {
|
public static ModelType getValueData(String modelNo) {
|
||||||
return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1);
|
return Arrays.stream(ModelType.values())
|
||||||
|
.filter(m -> m.getId().equals(modelNo))
|
||||||
|
.findFirst()
|
||||||
|
.orElse(G1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -182,10 +182,8 @@ public class HyperParamApiController {
|
|||||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||||
})
|
})
|
||||||
@GetMapping("/init/{model}")
|
@GetMapping("/init/{model}")
|
||||||
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(
|
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(@PathVariable ModelType model) {
|
||||||
@PathVariable ModelType model
|
|
||||||
|
|
||||||
) {
|
|
||||||
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
|
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,7 +79,8 @@ public class ModelTrainMngApiController {
|
|||||||
@DeleteMapping("/{uuid}")
|
@DeleteMapping("/{uuid}")
|
||||||
public ApiResponseDto<Void> deleteModelTrain(
|
public ApiResponseDto<Void> deleteModelTrain(
|
||||||
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
||||||
@PathVariable UUID uuid) {
|
@PathVariable
|
||||||
|
UUID uuid) {
|
||||||
modelTrainMngService.deleteModelTrain(uuid);
|
modelTrainMngService.deleteModelTrain(uuid);
|
||||||
return ApiResponseDto.ok(null);
|
return ApiResponseDto.ok(null);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,15 +50,15 @@ public class HyperParamCoreService {
|
|||||||
/**
|
/**
|
||||||
* 하이퍼파라미터 수정
|
* 하이퍼파라미터 수정
|
||||||
*
|
*
|
||||||
* @param uuid uuid
|
* @param uuid uuid
|
||||||
* @param createReq 등록 요청
|
* @param createReq 등록 요청
|
||||||
* @return ver
|
* @return ver
|
||||||
*/
|
*/
|
||||||
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.findHyperParamByUuid(uuid)
|
.findHyperParamByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
if (entity.getIsDefault()) {
|
if (entity.getIsDefault()) {
|
||||||
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
|
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||||
@@ -72,7 +72,6 @@ public class HyperParamCoreService {
|
|||||||
return entity.getHyperVer();
|
return entity.getHyperVer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 하이퍼파라미터 삭제
|
* 하이퍼파라미터 삭제
|
||||||
*
|
*
|
||||||
@@ -80,15 +79,15 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public void deleteHyperParam(UUID uuid) {
|
public void deleteHyperParam(UUID uuid) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.findHyperParamByUuid(uuid)
|
.findHyperParamByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
|
|
||||||
// if (entity.getHyperVer().equals("HPs_0001")) {
|
// if (entity.getHyperVer().equals("HPs_0001")) {
|
||||||
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
//디폴트면 삭제불가
|
// 디폴트면 삭제불가
|
||||||
if (entity.getIsDefault()) {
|
if (entity.getIsDefault()) {
|
||||||
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||||
}
|
}
|
||||||
@@ -105,9 +104,9 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.getHyperparamByType(model)
|
.getHyperparamByType(model)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,9 +117,9 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public HyperParamDto.Basic getHyperParam(UUID uuid) {
|
public HyperParamDto.Basic getHyperParam(UUID uuid) {
|
||||||
ModelHyperParamEntity entity =
|
ModelHyperParamEntity entity =
|
||||||
hyperParamRepository
|
hyperParamRepository
|
||||||
.findHyperParamByUuid(uuid)
|
.findHyperParamByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,16 +142,16 @@ public class HyperParamCoreService {
|
|||||||
*/
|
*/
|
||||||
public String getFirstHyperParamVersion(ModelType model) {
|
public String getFirstHyperParamVersion(ModelType model) {
|
||||||
return hyperParamRepository
|
return hyperParamRepository
|
||||||
.findHyperParamVerByModelType(model)
|
.findHyperParamVerByModelType(model)
|
||||||
.map(ModelHyperParamEntity::getHyperVer)
|
.map(ModelHyperParamEntity::getHyperVer)
|
||||||
.map(ver -> increase(ver, model))
|
.map(ver -> increase(ver, model))
|
||||||
.orElse(model.name() + "_000001");
|
.orElse(model.name() + "_000001");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 하이퍼 파라미터의 버전을 증가시킨다.
|
* 하이퍼 파라미터의 버전을 증가시킨다.
|
||||||
*
|
*
|
||||||
* @param hyperVer 현재 버전
|
* @param hyperVer 현재 버전
|
||||||
* @param modelType 모델 타입
|
* @param modelType 모델 타입
|
||||||
* @return 증가된 버전
|
* @return 증가된 버전
|
||||||
*/
|
*/
|
||||||
@@ -214,5 +213,4 @@ public class HyperParamCoreService {
|
|||||||
// memo
|
// memo
|
||||||
entity.setMemo(src.getMemo());
|
entity.setMemo(src.getMemo());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,9 +65,9 @@ public class ModelTrainMngCoreService {
|
|||||||
*/
|
*/
|
||||||
public void deleteModel(UUID uuid) {
|
public void deleteModel(UUID uuid) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findByUuid(uuid)
|
.findByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
entity.setDelYn(true);
|
entity.setDelYn(true);
|
||||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||||
entity.setUpdatedUid(userUtil.getId());
|
entity.setUpdatedUid(userUtil.getId());
|
||||||
@@ -83,15 +83,15 @@ public class ModelTrainMngCoreService {
|
|||||||
ModelMasterEntity entity = new ModelMasterEntity();
|
ModelMasterEntity entity = new ModelMasterEntity();
|
||||||
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
||||||
|
|
||||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||||
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
||||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
hyperParamEntity =
|
hyperParamEntity =
|
||||||
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
||||||
@@ -99,8 +99,8 @@ public class ModelTrainMngCoreService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
String modelVer =
|
String modelVer =
|
||||||
String.join(
|
String.join(
|
||||||
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
|
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
|
||||||
entity.setModelVer(modelVer);
|
entity.setModelVer(modelVer);
|
||||||
entity.setHyperParamId(hyperParamEntity.getId());
|
entity.setHyperParamId(hyperParamEntity.getId());
|
||||||
entity.setModelNo(addReq.getModelNo());
|
entity.setModelNo(addReq.getModelNo());
|
||||||
@@ -132,7 +132,7 @@ public class ModelTrainMngCoreService {
|
|||||||
* data set 저장
|
* data set 저장
|
||||||
*
|
*
|
||||||
* @param modelId 저장한 모델 학습 id
|
* @param modelId 저장한 모델 학습 id
|
||||||
* @param addReq 요청 파라미터
|
* @param addReq 요청 파라미터
|
||||||
*/
|
*/
|
||||||
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
|
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
|
||||||
TrainingDataset dataset = addReq.getTrainingDataset();
|
TrainingDataset dataset = addReq.getTrainingDataset();
|
||||||
@@ -165,9 +165,9 @@ public class ModelTrainMngCoreService {
|
|||||||
*/
|
*/
|
||||||
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
|
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
|
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
|
||||||
entity.setRequestPath(req.getRequestPath());
|
entity.setRequestPath(req.getRequestPath());
|
||||||
}
|
}
|
||||||
@@ -180,7 +180,7 @@ public class ModelTrainMngCoreService {
|
|||||||
/**
|
/**
|
||||||
* 모델 데이터셋 mapping 테이블 저장
|
* 모델 데이터셋 mapping 테이블 저장
|
||||||
*
|
*
|
||||||
* @param modelId 모델학습 id
|
* @param modelId 모델학습 id
|
||||||
* @param datasetList 선택한 data set
|
* @param datasetList 선택한 data set
|
||||||
*/
|
*/
|
||||||
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
|
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
|
||||||
@@ -197,7 +197,7 @@ public class ModelTrainMngCoreService {
|
|||||||
* 모델학습 config 저장
|
* 모델학습 config 저장
|
||||||
*
|
*
|
||||||
* @param modelId 모델학습 id
|
* @param modelId 모델학습 id
|
||||||
* @param req 요청 파라미터
|
* @param req 요청 파라미터
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
|
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
|
||||||
@@ -217,7 +217,7 @@ public class ModelTrainMngCoreService {
|
|||||||
/**
|
/**
|
||||||
* 데이터셋 매핑 생성
|
* 데이터셋 매핑 생성
|
||||||
*
|
*
|
||||||
* @param modelUid 모델 UID
|
* @param modelUid 모델 UID
|
||||||
* @param datasetIds 데이터셋 ID 목록
|
* @param datasetIds 데이터셋 ID 목록
|
||||||
*/
|
*/
|
||||||
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
|
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
|
||||||
@@ -239,8 +239,8 @@ public class ModelTrainMngCoreService {
|
|||||||
public ModelMasterEntity findByUuid(UUID uuid) {
|
public ModelMasterEntity findByUuid(UUID uuid) {
|
||||||
try {
|
try {
|
||||||
return modelMngRepository
|
return modelMngRepository
|
||||||
.findByUuid(uuid)
|
.findByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
|
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
|
||||||
}
|
}
|
||||||
@@ -254,9 +254,9 @@ public class ModelTrainMngCoreService {
|
|||||||
*/
|
*/
|
||||||
public Long findModelIdByUuid(UUID uuid) {
|
public Long findModelIdByUuid(UUID uuid) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findByUuid(uuid)
|
.findByUuid(uuid)
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
return entity.getId();
|
return entity.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,8 +269,8 @@ public class ModelTrainMngCoreService {
|
|||||||
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
|
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
|
||||||
ModelMasterEntity modelEntity = findByUuid(uuid);
|
ModelMasterEntity modelEntity = findByUuid(uuid);
|
||||||
return modelConfigRepository
|
return modelConfigRepository
|
||||||
.findModelConfigByModelId(modelEntity.getId())
|
.findModelConfigByModelId(modelEntity.getId())
|
||||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -301,21 +301,19 @@ public class ModelTrainMngCoreService {
|
|||||||
*/
|
*/
|
||||||
public ModelTrainMngDto.Basic findModelById(Long id) {
|
public ModelTrainMngDto.Basic findModelById(Long id) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(id)
|
.findById(id)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
||||||
return entity.toDto();
|
return entity.toDto();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
|
||||||
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markInProgress(Long modelId, Long jobId) {
|
public void markInProgress(Long modelId, Long jobId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
master.setCurrentAttemptId(jobId);
|
master.setCurrentAttemptId(jobId);
|
||||||
@@ -323,54 +321,46 @@ public class ModelTrainMngCoreService {
|
|||||||
// 필요하면 시작시간도 여기서 찍어줌
|
// 필요하면 시작시간도 여기서 찍어줌
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
|
||||||
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void clearLastError(Long modelId) {
|
public void clearLastError(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setLastError(null);
|
master.setLastError(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
|
||||||
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStopped(Long modelId) {
|
public void markStopped(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
|
||||||
* 완료 처리(옵션) - Worker가 성공 시 호출
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markCompleted(Long modelId) {
|
public void markCompleted(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||||
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markError(Long modelId, String errorMessage) {
|
public void markError(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||||
master.setStep1State(TrainStatusType.ERROR.getId());
|
master.setStep1State(TrainStatusType.ERROR.getId());
|
||||||
@@ -379,15 +369,13 @@ public class ModelTrainMngCoreService {
|
|||||||
master.setUpdatedDttm(ZonedDateTime.now());
|
master.setUpdatedDttm(ZonedDateTime.now());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||||
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
|
||||||
*/
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markStep2Error(Long modelId, String errorMessage) {
|
public void markStep2Error(Long modelId, String errorMessage) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||||
master.setStep2State(TrainStatusType.ERROR.getId());
|
master.setStep2State(TrainStatusType.ERROR.getId());
|
||||||
@@ -399,9 +387,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional
|
@Transactional
|
||||||
public void markSuccess(Long modelId) {
|
public void markSuccess(Long modelId) {
|
||||||
ModelMasterEntity master =
|
ModelMasterEntity master =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
// 모델 상태 완료 처리
|
// 모델 상태 완료 처리
|
||||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
@@ -429,9 +417,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||||
public void markStep1InProgress(Long modelId, Long jobId) {
|
public void markStep1InProgress(Long modelId, Long jobId) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
entity.setStep1StrtDttm(ZonedDateTime.now());
|
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||||
@@ -449,9 +437,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||||
public void markStep2InProgress(Long modelId, Long jobId) {
|
public void markStep2InProgress(Long modelId, Long jobId) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||||
entity.setStep2StrtDttm(ZonedDateTime.now());
|
entity.setStep2StrtDttm(ZonedDateTime.now());
|
||||||
@@ -469,9 +457,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||||
public void markStep1Success(Long modelId) {
|
public void markStep1Success(Long modelId) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
||||||
@@ -488,9 +476,9 @@ public class ModelTrainMngCoreService {
|
|||||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||||
public void markStep2Success(Long modelId) {
|
public void markStep2Success(Long modelId) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||||
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
||||||
@@ -501,9 +489,9 @@ public class ModelTrainMngCoreService {
|
|||||||
|
|
||||||
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
|
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
|
||||||
ModelMasterEntity entity =
|
ModelMasterEntity entity =
|
||||||
modelMngRepository
|
modelMngRepository
|
||||||
.findById(modelId)
|
.findById(modelId)
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||||
|
|
||||||
entity.setBestEpoch(epoch);
|
entity.setBestEpoch(epoch);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -316,7 +316,6 @@ public class ModelHyperParamEntity {
|
|||||||
@Enumerated(EnumType.STRING)
|
@Enumerated(EnumType.STRING)
|
||||||
private ModelType modelType;
|
private ModelType modelType;
|
||||||
|
|
||||||
|
|
||||||
@Column(name = "default_param")
|
@Column(name = "default_param")
|
||||||
private Boolean isDefault = false;
|
private Boolean isDefault = false;
|
||||||
|
|
||||||
@@ -395,8 +394,7 @@ public class ModelHyperParamEntity {
|
|||||||
// -------------------------
|
// -------------------------
|
||||||
this.gpuCnt,
|
this.gpuCnt,
|
||||||
this.gpuIds,
|
this.gpuIds,
|
||||||
this.masterPort
|
this.masterPort,
|
||||||
, this.isDefault
|
this.isDefault);
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,10 +185,14 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
|||||||
@Override
|
@Override
|
||||||
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
||||||
return Optional.ofNullable(
|
return Optional.ofNullable(
|
||||||
queryFactory
|
queryFactory
|
||||||
.select(modelHyperParamEntity)
|
.select(modelHyperParamEntity)
|
||||||
.from(modelHyperParamEntity)
|
.from(modelHyperParamEntity)
|
||||||
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType)))
|
.where(
|
||||||
.fetchOne());
|
modelHyperParamEntity
|
||||||
|
.delYn
|
||||||
|
.isFalse()
|
||||||
|
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||||
|
.fetchOne());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user