Merge pull request 'feat/training_260202' (#135) from feat/training_260202 into develop
Reviewed-on: #135
This commit was merged in pull request #135.
This commit is contained in:
@@ -12,13 +12,14 @@ import lombok.Setter;
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class HyperParam {
|
||||
|
||||
@Schema(description = "모델", example = "G1")
|
||||
private ModelType model; // G1, G2, G3
|
||||
|
||||
// -------------------------
|
||||
// Important
|
||||
// -------------------------
|
||||
|
||||
@Schema(description = "모델", example = "large")
|
||||
private ModelType model; // backbone
|
||||
|
||||
@Schema(description = "백본 네트워크", example = "large")
|
||||
private String backbone; // backbone
|
||||
|
||||
|
||||
@@ -115,10 +115,7 @@ public class HyperParamDto {
|
||||
@JsonFormatDttm private ZonedDateTime createDttm;
|
||||
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
||||
private String memo;
|
||||
private Long m1UseCnt;
|
||||
private Long m2UseCnt;
|
||||
private Long m3UseCnt;
|
||||
private Long totalCnt;
|
||||
private Long totalUseCnt;
|
||||
}
|
||||
|
||||
@Getter
|
||||
|
||||
@@ -104,7 +104,7 @@ public class HyperParamCoreService {
|
||||
*/
|
||||
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository.getHyperparamByType(model).stream()
|
||||
hyperParamRepository.getHyperParamByType(model).stream()
|
||||
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
@@ -57,6 +57,12 @@ public class ModelTrainDetailCoreService {
|
||||
return modelDetailRepository.getModelDetailSummary(uuid);
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼 파리미터 요약정보
|
||||
*
|
||||
* @param uuid 모델마스터 uuid
|
||||
* @return
|
||||
*/
|
||||
public HyperSummary getByModelHyperParamSummary(UUID uuid) {
|
||||
return modelDetailRepository.getByModelHyperParamSummary(uuid);
|
||||
}
|
||||
|
||||
@@ -52,7 +52,12 @@ public class ModelTrainJobCoreService {
|
||||
/** 실행 시작 처리 */
|
||||
@Transactional
|
||||
public void markRunning(
|
||||
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
|
||||
Long jobId,
|
||||
String containerName,
|
||||
String logPath,
|
||||
String lockedBy,
|
||||
Integer totalEpoch,
|
||||
String jobType) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
@@ -64,13 +69,19 @@ public class ModelTrainJobCoreService {
|
||||
job.setStartedDttm(ZonedDateTime.now());
|
||||
job.setLockedDttm(ZonedDateTime.now());
|
||||
job.setLockedBy(lockedBy);
|
||||
job.setJobType(jobType);
|
||||
|
||||
if (totalEpoch != null) {
|
||||
job.setTotalEpoch(totalEpoch);
|
||||
}
|
||||
}
|
||||
|
||||
/** 성공 처리 */
|
||||
/**
|
||||
* 성공 처리
|
||||
*
|
||||
* @param jobId
|
||||
* @param exitCode
|
||||
*/
|
||||
@Transactional
|
||||
public void markSuccess(Long jobId, int exitCode) {
|
||||
ModelTrainJobEntity job =
|
||||
@@ -83,7 +94,13 @@ public class ModelTrainJobCoreService {
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** 실패 처리 */
|
||||
/**
|
||||
* 실패 처리
|
||||
*
|
||||
* @param jobId
|
||||
* @param exitCode
|
||||
* @param errorMessage
|
||||
*/
|
||||
@Transactional
|
||||
public void markFailed(Long jobId, Integer exitCode, String errorMessage) {
|
||||
ModelTrainJobEntity job =
|
||||
|
||||
@@ -90,7 +90,7 @@ public class ModelTrainMngCoreService {
|
||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||
hyperParamEntity =
|
||||
hyperParamRepository.getHyperparamByType(modelType).stream()
|
||||
hyperParamRepository.getHyperParamByType(modelType).stream()
|
||||
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
@@ -384,7 +384,12 @@ public class ModelTrainMngCoreService {
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
/**
|
||||
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*
|
||||
* @param modelId
|
||||
* @param errorMessage
|
||||
*/
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
@@ -399,7 +404,12 @@ public class ModelTrainMngCoreService {
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
/**
|
||||
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*
|
||||
* @param modelId
|
||||
* @param errorMessage
|
||||
*/
|
||||
@Transactional
|
||||
public void markStep2Error(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
|
||||
@@ -303,15 +303,6 @@ public class ModelHyperParamEntity {
|
||||
@Column(name = "last_used_dttm")
|
||||
private ZonedDateTime lastUsedDttm;
|
||||
|
||||
@Column(name = "m1_use_cnt")
|
||||
private Long m1UseCnt = 0L;
|
||||
|
||||
@Column(name = "m2_use_cnt")
|
||||
private Long m2UseCnt = 0L;
|
||||
|
||||
@Column(name = "m3_use_cnt")
|
||||
private Long m3UseCnt = 0L;
|
||||
|
||||
@Column(name = "model_type")
|
||||
@Enumerated(EnumType.STRING)
|
||||
private ModelType modelType;
|
||||
|
||||
@@ -83,6 +83,9 @@ public class ModelTrainJobEntity {
|
||||
@Column(name = "current_epoch")
|
||||
private Integer currentEpoch;
|
||||
|
||||
@Column(name = "job_type")
|
||||
private String jobType;
|
||||
|
||||
public ModelTrainJobDto toDto() {
|
||||
return new ModelTrainJobDto(
|
||||
this.id,
|
||||
|
||||
@@ -29,9 +29,28 @@ public interface HyperParamRepositoryCustom {
|
||||
|
||||
Optional<ModelHyperParamEntity> findHyperParamByHyperVer(String hyperVer);
|
||||
|
||||
/**
|
||||
* 하이퍼 파라미터 상세조회
|
||||
*
|
||||
* @param uuid
|
||||
* @return
|
||||
*/
|
||||
Optional<ModelHyperParamEntity> findHyperParamByUuid(UUID uuid);
|
||||
|
||||
/**
|
||||
* 하이퍼 파라미터 목록 조회
|
||||
*
|
||||
* @param model
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
||||
|
||||
List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
||||
/**
|
||||
* 하이퍼 파라미터 모델타입으로 조회
|
||||
*
|
||||
* @param modelType
|
||||
* @return
|
||||
*/
|
||||
List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType);
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import com.querydsl.core.BooleanBuilder;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.core.types.dsl.NumberExpression;
|
||||
import com.querydsl.jpa.impl.JPAQuery;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.time.ZoneId;
|
||||
@@ -82,7 +81,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
queryFactory
|
||||
.select(modelHyperParamEntity)
|
||||
.from(modelHyperParamEntity)
|
||||
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.uuid.eq(uuid)))
|
||||
.where(modelHyperParamEntity.uuid.eq(uuid))
|
||||
.fetchOne());
|
||||
}
|
||||
|
||||
@@ -91,10 +90,12 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
Pageable pageable = req.toPageable();
|
||||
|
||||
BooleanBuilder builder = new BooleanBuilder();
|
||||
|
||||
builder.and(modelHyperParamEntity.delYn.isFalse());
|
||||
|
||||
if (model != null) {
|
||||
builder.and(modelHyperParamEntity.modelType.eq(model));
|
||||
}
|
||||
builder.and(modelHyperParamEntity.delYn.isFalse());
|
||||
|
||||
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
||||
// 버전
|
||||
@@ -118,13 +119,6 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
}
|
||||
}
|
||||
|
||||
NumberExpression<Long> totalUseCnt =
|
||||
modelHyperParamEntity
|
||||
.m1UseCnt
|
||||
.coalesce(0L)
|
||||
.add(modelHyperParamEntity.m2UseCnt.coalesce(0L))
|
||||
.add(modelHyperParamEntity.m3UseCnt.coalesce(0L));
|
||||
|
||||
JPAQuery<HyperParamDto.List> query =
|
||||
queryFactory
|
||||
.select(
|
||||
@@ -135,11 +129,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
modelHyperParamEntity.hyperVer,
|
||||
modelHyperParamEntity.createdDttm,
|
||||
modelHyperParamEntity.lastUsedDttm,
|
||||
modelHyperParamEntity.memo,
|
||||
modelHyperParamEntity.m1UseCnt,
|
||||
modelHyperParamEntity.m2UseCnt,
|
||||
modelHyperParamEntity.m3UseCnt,
|
||||
totalUseCnt.as("totalUseCnt")))
|
||||
modelHyperParamEntity.memo))
|
||||
.from(modelHyperParamEntity)
|
||||
.where(builder);
|
||||
|
||||
@@ -165,8 +155,6 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
? modelHyperParamEntity.lastUsedDttm.asc()
|
||||
: modelHyperParamEntity.lastUsedDttm.desc());
|
||||
|
||||
case "totalUseCnt" -> query.orderBy(asc ? totalUseCnt.asc() : totalUseCnt.desc());
|
||||
|
||||
default -> query.orderBy(modelHyperParamEntity.createdDttm.desc());
|
||||
}
|
||||
}
|
||||
@@ -187,7 +175,7 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
||||
public List<ModelHyperParamEntity> getHyperParamByType(ModelType modelType) {
|
||||
return queryFactory
|
||||
.select(modelHyperParamEntity)
|
||||
.from(modelHyperParamEntity)
|
||||
|
||||
@@ -52,7 +52,14 @@ public class DockerTrainService {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
/**
|
||||
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
|
||||
*
|
||||
* @param req
|
||||
* @param containerName
|
||||
* @return
|
||||
* @throws Exception
|
||||
*/
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||
|
||||
List<String> cmd = buildDockerRunCommand(containerName, req);
|
||||
@@ -267,8 +274,7 @@ public class DockerTrainService {
|
||||
addArg(c, "--input-size", req.getInputSize());
|
||||
addArg(c, "--crop-size", req.getCropSize());
|
||||
addArg(c, "--batch-size", req.getBatchSize());
|
||||
addArg(c, "--gpu-ids", req.getGpuIds());
|
||||
// addArg(c, "--gpus", req.getGpus());
|
||||
addArg(c, "--gpu-ids", req.getGpuIds()); // null
|
||||
addArg(c, "--lr", req.getLearningRate());
|
||||
addArg(c, "--backbone", req.getBackbone());
|
||||
addArg(c, "--epochs", req.getEpochs());
|
||||
@@ -342,6 +348,14 @@ public class DockerTrainService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환
|
||||
*
|
||||
* @param containerName
|
||||
* @param req
|
||||
* @return
|
||||
* @throws Exception
|
||||
*/
|
||||
public TrainRunResult runEvalSync(String containerName, EvalRunRequest req) throws Exception {
|
||||
|
||||
List<String> cmd = buildDockerEvalCommand(containerName, req);
|
||||
|
||||
@@ -48,20 +48,8 @@ public class ModelTestMetricsJobService {
|
||||
@Value("${file.pt-path}")
|
||||
private String ptPathDir;
|
||||
|
||||
/**
|
||||
* 실행중인 profile
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private boolean isLocalProfile() {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
public void findTestValidMetricCsvFiles() throws IOException {
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
/** 결과 csv 파일 정보 등록 */
|
||||
public void findTestValidMetricCsvFiles() {
|
||||
|
||||
List<ResponsePathDto> modelIds =
|
||||
modelTestMetricsJobCoreService.getTestMetricSaveNotYetModelIds();
|
||||
|
||||
@@ -36,20 +36,8 @@ public class ModelTrainMetricsJobService {
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
/**
|
||||
* 실행중인 profile
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private boolean isLocalProfile() {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
/** 결과 csv 파일 정보 등록 */
|
||||
public void findTrainValidMetricCsvFiles() {
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
List<ResponsePathDto> modelIds =
|
||||
modelTrainMetricsJobCoreService.getTrainMetricSaveNotYetModelIds();
|
||||
|
||||
@@ -23,6 +23,14 @@ public class TestJobService {
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final DataSetCountersService dataSetCounters;
|
||||
|
||||
/**
|
||||
* 실행 예약 (QUEUE 등록)
|
||||
*
|
||||
* @param modelId
|
||||
* @param uuid
|
||||
* @param epoch
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||
|
||||
@@ -58,6 +66,11 @@ public class TestJobService {
|
||||
return jobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 취소
|
||||
*
|
||||
* @param modelId
|
||||
*/
|
||||
@Transactional
|
||||
public void cancel(Long modelId) {
|
||||
|
||||
|
||||
@@ -47,7 +47,12 @@ public class TrainJobService {
|
||||
return modelTrainMngCoreService.findModelIdByUuid(uuid);
|
||||
}
|
||||
|
||||
/** 실행 예약 (QUEUE 등록) */
|
||||
/**
|
||||
* 실행 예약 (QUEUE 등록)
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId) {
|
||||
|
||||
@@ -139,6 +144,13 @@ public class TrainJobService {
|
||||
modelTrainMngCoreService.markStopped(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 이어하기
|
||||
*
|
||||
* @param modelId 모델 id
|
||||
* @param mode NONE 새로 시작, REQUIRE 이어하기
|
||||
* @return
|
||||
*/
|
||||
private Long createNextAttempt(Long modelId, ResumeMode mode) {
|
||||
|
||||
ModelTrainMngDto.Basic master = modelTrainMngCoreService.findModelById(modelId);
|
||||
@@ -189,6 +201,12 @@ public class TrainJobService {
|
||||
REQUIRE // 이어하기
|
||||
}
|
||||
|
||||
/**
|
||||
* 이어하기 체크포인트 탐지해서 resumeFrom 세팅
|
||||
*
|
||||
* @param paramsJson
|
||||
* @return
|
||||
*/
|
||||
public String findResumeFromOrNull(Map<String, Object> paramsJson) {
|
||||
if (paramsJson == null) return null;
|
||||
|
||||
@@ -230,6 +248,12 @@ public class TrainJobService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습에 필요한 데이터셋 파일을 임시폴더 하나에 합치기
|
||||
*
|
||||
* @param modelUuid
|
||||
* @return
|
||||
*/
|
||||
@Transactional
|
||||
public UUID createTmpFile(UUID modelUuid) {
|
||||
UUID tmpUuid = UUID.randomUUID();
|
||||
|
||||
@@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
|
||||
import org.springframework.transaction.event.TransactionPhase;
|
||||
import org.springframework.transaction.event.TransactionalEventListener;
|
||||
|
||||
/** job 실행 */
|
||||
@Log4j2
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
@@ -54,6 +55,8 @@ public class TrainJobWorker {
|
||||
String containerName =
|
||||
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
||||
|
||||
String type = isEval ? "TEST" : "TRAIN";
|
||||
|
||||
Integer totalEpoch = null;
|
||||
if (params.containsKey("totalEpoch")) {
|
||||
if (params.get("totalEpoch") != null) {
|
||||
@@ -61,12 +64,15 @@ public class TrainJobWorker {
|
||||
}
|
||||
}
|
||||
log.info("[JOB] markRunning start jobId={}, containerName={}", jobId, containerName);
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
|
||||
// 실행 시작 처리
|
||||
modelTrainJobCoreService.markRunning(
|
||||
jobId, containerName, null, "TRAIN_WORKER", totalEpoch, type);
|
||||
log.info("[JOB] markRunning done jobId={}", jobId);
|
||||
try {
|
||||
TrainRunResult result;
|
||||
|
||||
if (isEval) {
|
||||
// step2 진행중 처리
|
||||
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||
String uuid = String.valueOf(params.get("uuid"));
|
||||
int epoch = (int) params.get("epoch");
|
||||
@@ -81,11 +87,14 @@ public class TrainJobWorker {
|
||||
evalReq.setOutputFolder(outputFolder);
|
||||
log.info("[JOB] selected test epoch={}", epoch);
|
||||
|
||||
// 도커 실행 후 로그 수집
|
||||
result = dockerTrainService.runEvalSync(containerName, evalReq);
|
||||
|
||||
} else {
|
||||
// step1 진행중 처리
|
||||
modelTrainMngCoreService.markStep1InProgress(modelId, jobId);
|
||||
TrainRunRequest trainReq = toTrainRunRequest(params);
|
||||
// 도커 실행 후 로그 수집
|
||||
result = dockerTrainService.runTrainSync(trainReq, containerName);
|
||||
}
|
||||
|
||||
@@ -99,24 +108,31 @@ public class TrainJobWorker {
|
||||
}
|
||||
|
||||
if (result.getExitCode() == 0) {
|
||||
// 성공 처리
|
||||
modelTrainJobCoreService.markSuccess(jobId, result.getExitCode());
|
||||
|
||||
if (isEval) {
|
||||
// step2 완료처리
|
||||
modelTrainMngCoreService.markStep2Success(modelId);
|
||||
// 결과 csv 파일 정보 등록
|
||||
modelTestMetricsJobService.findTestValidMetricCsvFiles();
|
||||
} else {
|
||||
modelTrainMngCoreService.markStep1Success(modelId);
|
||||
// 결과 csv 파일 정보 등록
|
||||
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
|
||||
}
|
||||
|
||||
} else {
|
||||
String failMsg = result.getStatus() + "\n" + result.getLogs();
|
||||
// 실패 처리
|
||||
modelTrainJobCoreService.markFailed(
|
||||
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());
|
||||
|
||||
if (isEval) {
|
||||
// 오류 정보 등록
|
||||
modelTrainMngCoreService.markStep2Error(modelId, "exit=" + result.getExitCode());
|
||||
} else {
|
||||
// 오류 정보 등록
|
||||
modelTrainMngCoreService.markError(modelId, "exit=" + result.getExitCode());
|
||||
}
|
||||
}
|
||||
@@ -125,8 +141,10 @@ public class TrainJobWorker {
|
||||
modelTrainJobCoreService.markFailed(jobId, null, e.getMessage());
|
||||
|
||||
if ("EVAL".equals(params.get("jobType"))) {
|
||||
// 오류 정보 등록
|
||||
modelTrainMngCoreService.markStep2Error(modelId, e.getMessage());
|
||||
} else {
|
||||
// 오류 정보 등록
|
||||
modelTrainMngCoreService.markError(modelId, e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user