feat/training_260202 #135

Merged
teddy merged 4 commits from feat/training_260202 into develop 2026-02-23 14:31:12 +09:00
16 changed files with 151 additions and 74 deletions

View File

@@ -12,13 +12,14 @@ import lombok.Setter;
@AllArgsConstructor
@NoArgsConstructor
public class HyperParam {
@Schema(description = "모델", example = "G1")
private ModelType model; // G1, G2, G3
// -------------------------
// Important
// -------------------------
@Schema(description = "모델", example = "large")
private ModelType model; // backbone
@Schema(description = "백본 네트워크", example = "large")
private String backbone; // backbone

View File

@@ -115,10 +115,7 @@ public class HyperParamDto {
@JsonFormatDttm private ZonedDateTime createDttm;
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
private String memo;
private Long m1UseCnt;
private Long m2UseCnt;
private Long m3UseCnt;
private Long totalCnt;
private Long totalUseCnt;
}
@Getter

View File

@@ -104,7 +104,7 @@ public class HyperParamCoreService {
*/
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
ModelHyperParamEntity entity =
hyperParamRepository.getHyperparamByType(model).stream()
hyperParamRepository.getHyperParamByType(model).stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst()
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));

View File

@@ -57,6 +57,12 @@ public class ModelTrainDetailCoreService {
return modelDetailRepository.getModelDetailSummary(uuid);
}
/**
* 하이퍼 파리미터 요약정보
*
* @param uuid 모델마스터 uuid
* @return
*/
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
return modelDetailRepository.getByModelHyperParamSummary(uuid);
}

View File

@@ -52,7 +52,12 @@ public class ModelTrainJobCoreService {
/** 실행 시작 처리 */
@Transactional
public void markRunning(
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
Long jobId,
String containerName,
String logPath,
String lockedBy,
Integer totalEpoch,
String jobType) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
@@ -64,13 +69,19 @@ public class ModelTrainJobCoreService {
job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy);
job.setJobType(jobType);
if (totalEpoch != null) {
job.setTotalEpoch(totalEpoch);
}
}
/** 성공 처리 */
/**
* 성공 처리
*
* @param jobId
* @param exitCode
*/
@Transactional
public void markSuccess(Long jobId, int exitCode) {
ModelTrainJobEntity job =
@@ -83,7 +94,13 @@ public class ModelTrainJobCoreService {
job.setFinishedDttm(ZonedDateTime.now());
}
/** 실패 처리 */
/**
* 실패 처리
*
* @param jobId
* @param exitCode
* @param errorMessage
*/
@Transactional
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
ModelTrainJobEntity job =

View File

@@ -90,7 +90,7 @@ public class ModelTrainMngCoreService {
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
hyperParamEntity =
hyperParamRepository.getHyperparamByType(modelType).stream()
hyperParamRepository.getHyperParamByType(modelType).stream()
.filter(e -> e.getIsDefault() == Boolean.TRUE)
.findFirst()
.orElse(null);
@@ -384,7 +384,12 @@ public class ModelTrainMngCoreService {
master.setStatusCd(TrainStatusType.COMPLETED.getId());
}
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
/**
* step 1오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
@Transactional
public void markError(Long modelId, String errorMessage) {
ModelMasterEntity master =
@@ -399,7 +404,12 @@ public class ModelTrainMngCoreService {
master.setUpdatedDttm(ZonedDateTime.now());
}
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
/**
* step 2오류 처리(옵션) - Worker가 실패 시 호출
*
* @param modelId
* @param errorMessage
*/
@Transactional
public void markStep2Error(Long modelId, String errorMessage) {
ModelMasterEntity master =

View File

@@ -303,15 +303,6 @@ public class ModelHyperParamEntity {
@Column(name = "last_used_dttm")
private ZonedDateTime lastUsedDttm;
@Column(name = "m1_use_cnt")
private Long m1UseCnt = 0L;
@Column(name = "m2_use_cnt")
private Long m2UseCnt = 0L;
@Column(name = "m3_use_cnt")
private Long m3UseCnt = 0L;
@Column(name = "model_type")
@Enumerated(EnumType.STRING)
private ModelType modelType;

View File

@@ -83,6 +83,9 @@ public class ModelTrainJobEntity {
@Column(name = "current_epoch")
private Integer currentEpoch;
@Column(name = "job_type")
private String jobType;
public ModelTrainJobDto toDto() {
return new ModelTrainJobDto(
this.id,

View File

@@ -29,9 +29,28 @@ public interface HyperParamRepositoryCustom {
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
/**
* 하이퍼 파라미터 상세조회
*
* @param uuid
* @return
*/
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
/**
* 하이퍼 파라미터 목록 조회
*
* @param model
* @param req
* @return
*/
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
/**
* 하이퍼 파라미터 모델타입으로 조회
*
* @param modelType
* @return
*/
List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType);
}

View File

@@ -9,7 +9,6 @@ import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.NumberExpression;
import com.querydsl.jpa.impl.JPAQuery;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.time.ZoneId;
@@ -82,7 +81,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid)))
.where(modelHyperParamEntity.uuid.eq(uuid))
.fetchOne());
}
@@ -91,10 +90,12 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
Pageable pageable = req.toPageable();
BooleanBuilder builder = new BooleanBuilder();
builder.and(modelHyperParamEntity.delYn.isFalse());
if (model != null) {
builder.and(modelHyperParamEntity.modelType.eq(model));
}
builder.and(modelHyperParamEntity.delYn.isFalse());
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
// 버전
@@ -118,13 +119,6 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
}
}
NumberExpression<Long> totalUseCnt =
modelHyperParamEntity
.m1UseCnt
.coalesce(0L)
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
JPAQuery<HyperParamDto.List> query =
queryFactory
.select(
@@ -135,11 +129,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
modelHyperParamEntity.hyperVer,
modelHyperParamEntity.createdDttm,
modelHyperParamEntity.lastUsedDttm,
modelHyperParamEntity.memo,
modelHyperParamEntity.m1UseCnt,
modelHyperParamEntity.m2UseCnt,
modelHyperParamEntity.m3UseCnt,
totalUseCnt.as("totalUseCnt")))
modelHyperParamEntity.memo))
.from(modelHyperParamEntity)
.where(builder);
@@ -165,8 +155,6 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
? modelHyperParamEntity.lastUsedDttm.asc()
: modelHyperParamEntity.lastUsedDttm.desc());
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
}
}
@@ -187,7 +175,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
}
@Override
public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
public List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType) {
return queryFactory
.select(modelHyperParamEntity)
.from(modelHyperParamEntity)

View File

@@ -52,7 +52,14 @@ public class DockerTrainService {
private final ModelTrainJobCoreService modelTrainJobCoreService;
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
/**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param req
* @param containerName
* @return
* @throws Exception
*/
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
List<String> cmd = buildDockerRunCommand(containerName, req);
@@ -267,8 +274,7 @@ public class DockerTrainService {
addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize());
addArg(c, "--batch-size", req.getBatchSize());
addArg(c, "--gpu-ids", req.getGpuIds());
// addArg(c, "--gpus", req.getGpus());
addArg(c, "--gpu-ids", req.getGpuIds()); // null
addArg(c, "--lr", req.getLearningRate());
addArg(c, "--backbone", req.getBackbone());
addArg(c, "--epochs", req.getEpochs());
@@ -342,6 +348,14 @@ public class DockerTrainService {
}
}
/**
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
*
* @param containerName
* @param req
* @return
* @throws Exception
*/
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
List<String> cmd = buildDockerEvalCommand(containerName, req);

View File

@@ -48,20 +48,8 @@ public class ModelTestMetricsJobService {
@Value("${file.pt-path}")
private String ptPathDir;
/**
* 실행중인 profile
*
* @return
*/
private boolean isLocalProfile() {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
public void findTestValidMetricCsvFiles() throws IOException {
// if (isLocalProfile()) {
// return;
// }
/** 결과 csv 파일 정보 등록 */
public void findTestValidMetricCsvFiles() {
List<ResponsePathDto> modelIds =
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();

View File

@@ -36,20 +36,8 @@ public class ModelTrainMetricsJobService {
@Value("${train.docker.responseDir}")
private String responseDir;
/**
* 실행중인 profile
*
* @return
*/
private boolean isLocalProfile() {
return "local".equalsIgnoreCase(profile);
}
// @Scheduled(cron = "0 * * * * *")
/** 결과 csv 파일 정보 등록 */
public void findTrainValidMetricCsvFiles() {
// if (isLocalProfile()) {
// return;
// }
List<ResponsePathDto> modelIds =
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();

View File

@@ -23,6 +23,14 @@ public class TestJobService {
private final ApplicationEventPublisher eventPublisher;
private final DataSetCountersService dataSetCounters;
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @param uuid
* @param epoch
* @return
*/
@Transactional
public Long enqueue(Long modelId, UUID uuid, int epoch) {
@@ -58,6 +66,11 @@ public class TestJobService {
return jobId;
}
/**
* 취소
*
* @param modelId
*/
@Transactional
public void cancel(Long modelId) {

View File

@@ -47,7 +47,12 @@ public class TrainJobService {
return modelTrainMngCoreService.findModelIdByUuid(uuid);
}
/** 실행 예약 (QUEUE 등록) */
/**
* 실행 예약 (QUEUE 등록)
*
* @param modelId
* @return
*/
@Transactional
public Long enqueue(Long modelId) {
@@ -139,6 +144,13 @@ public class TrainJobService {
modelTrainMngCoreService.markStopped(modelId);
}
/**
* 학습 이어하기
*
* @param modelId 모델 id
* @param mode NONE 새로 시작, REQUIRE 이어하기
* @return
*/
private Long createNextAttempt(Long modelId, ResumeMode mode) {
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
@@ -189,6 +201,12 @@ public class TrainJobService {
REQUIRE // 이어하기
}
/**
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
*
* @param paramsJson
* @return
*/
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
if (paramsJson == null) return null;
@@ -230,6 +248,12 @@ public class TrainJobService {
}
}
/**
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
*
* @param modelUuid
* @return
*/
@Transactional
public UUID createTmpFile(UUID modelUuid) {
UUID tmpUuid = UUID.randomUUID();

View File

@@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
/** job 실행 */
@Log4j2
@Component
@RequiredArgsConstructor
@@ -54,6 +55,8 @@ public class TrainJobWorker {
String containerName =
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
String type = isEval ? "TEST" : "TRAIN";
Integer totalEpoch = null;
if (params.containsKey("totalEpoch")) {
if (params.get("totalEpoch") != null) {
@@ -61,12 +64,15 @@ public class TrainJobWorker {
}
}
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
// 실행 시작 처리
modelTrainJobCoreService.markRunning(
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
log.info("[JOB] markRunning done jobId={}", jobId);
try {
TrainRunResult result;
if (isEval) {
// step2 진행중 처리
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
String uuid = String.valueOf(params.get("uuid"));
int epoch = (int) params.get("epoch");
@@ -81,11 +87,14 @@ public class TrainJobWorker {
evalReq.setOutputFolder(outputFolder);
log.info("[JOB] selected test epoch={}", epoch);
// 도커 실행 후 로그 수집
result = dockerTrainService.runEvalSync(containerName, evalReq);
} else {
// step1 진행중 처리
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
TrainRunRequest trainReq = toTrainRunRequest(params);
// 도커 실행 후 로그 수집
result = dockerTrainService.runTrainSync(trainReq, containerName);
}
@@ -99,24 +108,31 @@ public class TrainJobWorker {
}
if (result.getExitCode() == 0) {
// 성공 처리
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
if (isEval) {
// step2 완료처리
modelTrainMngCoreService.markStep2Success(modelId);
// 결과 csv 파일 정보 등록
modelTestMetricsJobService.findTestValidMetricCsvFiles();
} else {
modelTrainMngCoreService.markStep1Success(modelId);
// 결과 csv 파일 정보 등록
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
}
} else {
String failMsg = result.getStatus() + "\n" + result.getLogs();
// 실패 처리
modelTrainJobCoreService.markFailed(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
if (isEval) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
} else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
}
}
@@ -125,8 +141,10 @@ public class TrainJobWorker {
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
if ("EVAL".equals(params.get("jobType"))) {
// 오류 정보 등록
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
} else {
// 오류 정보 등록
modelTrainMngCoreService.markError(modelId, e.getMessage());
}
}