feat/training_260202 #17

Merged
gina merged 6 commits from feat/training_260202 into develop 2026-02-04 19:54:24 +09:00
7 changed files with 120 additions and 8 deletions
Showing only changes of commit 474a3c119e - Show all commits

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.model;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.service.ModelTrainMngService;
@@ -91,4 +92,22 @@ public class ModelTrainMngApiController {
modelTrainMngService.createModelTrain(req);
return ApiResponseDto.ok("ok");
}
@Operation(summary = "모델학습 config 정보 조회", description = "모델학습 config 정보 조회 API")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "검색 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = ModelConfigDto.Basic.class))),
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/config/{uuid}")
public ApiResponseDto<ModelConfigDto.Basic> updateModelTrain(@PathVariable UUID uuid) {
return ApiResponseDto.ok(modelTrainMngService.getModelConfigByModelId(uuid));
}
}

View File

@@ -0,0 +1,23 @@
package com.kamco.cd.training.model.dto;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
public class ModelConfigDto {
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public static class Basic {
private Long configId;
private Long modelId;
private Integer epochCount;
private Float trainPercent;
private Float validationPercent;
private Float testPercent;
private String memo;
}
}

View File

@@ -3,6 +3,7 @@ package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.dto.HyperParam;
import com.kamco.cd.training.common.enums.HyperParamSelectType;
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
@@ -38,6 +39,7 @@ public class ModelTrainMngService {
*
* @param uuid
*/
@Transactional
public void deleteModelTrain(UUID uuid) {
modelTrainMngCoreService.deleteModel(uuid);
}
@@ -73,4 +75,14 @@ public class ModelTrainMngService {
// 모델 config 저장
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
}
/**
* 모델학습 모델학습 uuid config정보 조회
*
* @param uuid 모델학습 uuid
* @return
*/
public ModelConfigDto.Basic getModelConfigByModelId(UUID uuid) {
return modelTrainMngCoreService.findModelConfigByModelId(uuid);
}
}

View File

@@ -5,8 +5,8 @@ import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.enums.TrainStatusType;
import com.kamco.cd.training.common.exception.BadRequestException;
import com.kamco.cd.training.common.exception.CustomApiException;
import com.kamco.cd.training.common.exception.NotFoundException;
import com.kamco.cd.training.common.utils.UserUtil;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
import com.kamco.cd.training.model.dto.ModelTrainMngDto.TrainingDataset;
@@ -162,7 +162,7 @@ public class ModelTrainMngCoreService {
* @param req 요청 파라미터
* @return
*/
public Long saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
ModelMasterEntity modelMasterEntity = new ModelMasterEntity();
ModelConfigEntity entity = new ModelConfigEntity();
modelMasterEntity.setId(modelId);
@@ -173,7 +173,7 @@ public class ModelTrainMngCoreService {
entity.setTestPercent(req.getTestCnt());
entity.setMemo(req.getMemo());
return modelConfigRepository.save(entity).getId();
modelConfigRepository.save(entity);
}
/**
@@ -198,14 +198,26 @@ public class ModelTrainMngCoreService {
* @param uuid UUID
* @return 모델 Entity
*/
public ModelMasterEntity findByUuid(String uuid) {
public ModelMasterEntity findByUuid(UUID uuid) {
try {
java.util.UUID uuidObj = java.util.UUID.fromString(uuid);
return modelMngRepository
.findByUuid(uuidObj)
.orElseThrow(() -> new NotFoundException("모델을 찾을 수 없습니다. UUID: " + uuid));
.findByUuid(uuid)
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
} catch (IllegalArgumentException e) {
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
}
}
/**
* 모델학습 아이디로 config정보 조회
*
* @param uuid
* @return
*/
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
ModelMasterEntity modelEntity = findByUuid(uuid);
return modelConfigRepository
.findModelConfigByModelId(modelEntity.getId())
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
}
}

View File

@@ -3,4 +3,5 @@ package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.postgres.entity.ModelConfigEntity;
import org.springframework.data.jpa.repository.JpaRepository;
public interface ModelConfigRepository extends JpaRepository<ModelConfigEntity, Long> {}
public interface ModelConfigRepository
extends JpaRepository<ModelConfigEntity, Long>, ModelConfigRepositoryCustom {}

View File

@@ -0,0 +1,8 @@
package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.model.dto.ModelConfigDto;
import java.util.Optional;
public interface ModelConfigRepositoryCustom {
Optional<ModelConfigDto.Basic> findModelConfigByModelId(Long modelId);
}

View File

@@ -0,0 +1,37 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QModelConfigEntity.modelConfigEntity;
import com.kamco.cd.training.model.dto.ModelConfigDto.Basic;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Repository;
@Repository
@RequiredArgsConstructor
public class ModelConfigRepositoryImpl implements ModelConfigRepositoryCustom {
private final JPAQueryFactory queryFactory;
@Override
public Optional<Basic> findModelConfigByModelId(Long modelId) {
return Optional.ofNullable(
queryFactory
.select(
Projections.constructor(
Basic.class,
modelConfigEntity.id,
modelConfigEntity.model.id,
modelConfigEntity.epochCount,
modelConfigEntity.trainPercent,
modelConfigEntity.validationPercent,
modelConfigEntity.testPercent,
modelConfigEntity.memo))
.from(modelConfigEntity)
.where(modelConfigEntity.model.id.eq(modelId))
.fetchOne());
}
}