diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java index 4e7450f..6ba1e56 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTestMetricsJobCoreService.java @@ -1,6 +1,6 @@ package com.kamco.cd.training.postgres.core; -import com.kamco.cd.training.postgres.repository.schedule.ModelTestMetricsJobRepository; +import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository; import java.util.List; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java index 350a248..4a2ce5d 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainJobCoreService.java @@ -47,7 +47,8 @@ public class ModelTrainJobCoreService { /** 실행 시작 처리 */ @Transactional - public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) { + public void markRunning( + Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) { ModelTrainJobEntity job = modelTrainJobRepository .findById(jobId) @@ -59,6 +60,10 @@ public class ModelTrainJobCoreService { job.setStartedDttm(ZonedDateTime.now()); job.setLockedDttm(ZonedDateTime.now()); job.setLockedBy(lockedBy); + + if (totalEpoch != null) { + job.setTotalEpoch(totalEpoch); + } } /** 성공 처리 */ diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java index 3c6c1f2..5692017 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMetricsJobCoreService.java @@ -1,6 +1,6 @@ package com.kamco.cd.training.postgres.core; -import com.kamco.cd.training.postgres.repository.schedule.ModelTrainMetricsJobRepository; +import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository; import java.util.List; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; diff --git a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java index 23c11e0..4be89a8 100644 --- a/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java +++ b/src/main/java/com/kamco/cd/training/postgres/entity/ModelTrainJobEntity.java @@ -78,6 +78,12 @@ public class ModelTrainJobEntity { @Column(name = "locked_by", length = 100) private String lockedBy; + @Column(name = "total_epoch") + private Integer totalEpoch; + + @Column(name = "current_epoch") + private Integer currentEpoch; + public ModelTrainJobDto toDto() { return new ModelTrainJobDto( this.id, @@ -90,6 +96,8 @@ public class ModelTrainJobEntity { this.paramsJson, this.queuedDttm, this.startedDttm, - this.finishedDttm); + this.finishedDttm, + this.totalEpoch, + this.currentEpoch); } } diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java index 2fe22be..547d837 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/model/ModelMngRepositoryImpl.java @@ -103,7 +103,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom { modelHyperParamEntity.gpuCnt, modelHyperParamEntity.learningRate, modelHyperParamEntity.backbone, - modelHyperParamEntity.epochCnt, + modelConfigEntity.epochCount, modelHyperParamEntity.trainNumWorkers, modelHyperParamEntity.valNumWorkers, modelHyperParamEntity.testNumWorkers, @@ -135,7 +135,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom { modelHyperParamEntity.saturationRange, modelHyperParamEntity.hueDelta, Expressions.nullExpression(Integer.class), - Expressions.nullExpression(String.class))) + Expressions.nullExpression(String.class), + modelHyperParamEntity.uuid)) .from(modelMasterEntity) .leftJoin(modelHyperParamEntity) .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId)) diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepository.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepository.java similarity index 84% rename from src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepository.java rename to src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepository.java index 48f8e67..d0945fa 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepository.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepository.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.postgres.repository.schedule; +package com.kamco.cd.training.postgres.repository.train; import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity; import org.springframework.data.jpa.repository.JpaRepository; diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java similarity index 81% rename from src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryCustom.java rename to src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java index 5a34eca..bd993e1 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryCustom.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.postgres.repository.schedule; +package com.kamco.cd.training.postgres.repository.train; import java.util.List; diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java similarity index 97% rename from src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryImpl.java rename to src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java index d30179f..7804c52 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTestMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTestMetricsJobRepositoryImpl.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.postgres.repository.schedule; +package com.kamco.cd.training.postgres.repository.train; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepository.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepository.java similarity index 85% rename from src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepository.java rename to src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepository.java index 9397e15..2b58eab 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepository.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepository.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.postgres.repository.schedule; +package com.kamco.cd.training.postgres.repository.train; import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity; import org.springframework.data.jpa.repository.JpaRepository; diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java similarity index 84% rename from src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepositoryCustom.java rename to src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java index 7a8c681..a10caa8 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryCustom.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.postgres.repository.schedule; +package com.kamco.cd.training.postgres.repository.train; import java.util.List; diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java similarity index 97% rename from src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepositoryImpl.java rename to src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java index 78bbda6..c20bc73 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/schedule/ModelTrainMetricsJobRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/train/ModelTrainMetricsJobRepositoryImpl.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.postgres.repository.schedule; +package com.kamco.cd.training.postgres.repository.train; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; diff --git a/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobDto.java b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobDto.java index f9d0004..9545ec4 100644 --- a/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobDto.java +++ b/src/main/java/com/kamco/cd/training/train/dto/ModelTrainJobDto.java @@ -20,4 +20,6 @@ public class ModelTrainJobDto { private ZonedDateTime queuedDttm; private ZonedDateTime startedDttm; private ZonedDateTime finishedDttm; + private Integer totalEpoch; + private Integer currentEpoch; } diff --git a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java index 1e1974c..e294ce7 100644 --- a/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java +++ b/src/main/java/com/kamco/cd/training/train/dto/TrainRunRequest.java @@ -82,11 +82,17 @@ public class TrainRunRequest { private Integer timeoutSeconds; private String resumeFrom; + private UUID uuid; + public String getDatasetFolder() { - return String.valueOf(datasetFolder); + return String.valueOf(this.datasetFolder); } public String getOutputFolder() { - return String.valueOf(outputFolder); + return String.valueOf(this.outputFolder); + } + + public String getUuid() { + return String.valueOf(this.uuid); } } diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index 4949047..be9486b 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -9,6 +9,8 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import lombok.extern.log4j.Log4j2; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @@ -57,23 +59,59 @@ public class DockerTrainService { Process p = pb.start(); // 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게) - StringBuilder log = new StringBuilder(); + StringBuilder logBuilder = new StringBuilder(); + + Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b"); + Thread logThread = new Thread( () -> { try (BufferedReader br = new BufferedReader( new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { + String line; while ((line = br.readLine()) != null) { - synchronized (log) { - log.append(line).append('\n'); + + // 1) 로그 누적 + synchronized (logBuilder) { + logBuilder.append(line).append('\n'); + } + + // 2) epoch 감지 + DB 업데이트 + Matcher m = epochPattern.matcher(line); + if (m.find()) { + int currentEpoch = Integer.parseInt(m.group(1)); + int totalEpoch = Integer.parseInt(m.group(2)); + + log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch); + + // TODO 실행중인 에폭 저장 필요하면 만들어야함 + // TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..? + // modelTrainMngCoreService.updateCurrentEpoch(modelId, + // currentEpoch, totalEpoch); } } - } catch (Exception ignored) { + } catch (Exception e) { + log.warn("logThread error: {}", e.toString()); } }, "train-log-" + containerName); + // new Thread( + // () -> { + // try (BufferedReader br = + // new BufferedReader( + // new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { + // String line; + // while ((line = br.readLine()) != null) { + // synchronized (log) { + // log.append(line).append('\n'); + // } + // } + // } catch (Exception ignored) { + // } + // }, + // "train-log-" + containerName); logThread.setDaemon(true); logThread.start(); @@ -90,8 +128,8 @@ public class DockerTrainService { killContainer(containerName); String logs; - synchronized (log) { - logs = log.toString(); + synchronized (logBuilder) { + logs = logBuilder.toString(); } return new TrainRunResult( @@ -108,8 +146,8 @@ public class DockerTrainService { logThread.join(500); String logs; - synchronized (log) { - logs = log.toString(); + synchronized (logBuilder) { + logs = logBuilder.toString(); } return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs); @@ -131,7 +169,7 @@ public class DockerTrainService { // 컨테이너 이름 지정 c.add("--name"); - c.add(containerName + "-" + req.getOutputFolder().substring(0, 8)); + c.add(containerName + "-" + req.getUuid().substring(0, 8)); // 실행 종료 시 자동 삭제 c.add("--rm"); @@ -183,7 +221,7 @@ public class DockerTrainService { c.add("/workspace/change-detection-code/train_wrapper.py"); // ===== 기본 파라미터 ===== - addArg(c, "--dataset-folder", req.getDatasetFolder()); + addArg(c, "--dataset-folder", "4BDBBDF99D04477A927CC9EBA760B845" /*req.getDatasetFolder()*/); addArg(c, "--output-folder", req.getOutputFolder()); addArg(c, "--input-size", req.getInputSize()); addArg(c, "--crop-size", req.getCropSize()); diff --git a/src/main/java/com/kamco/cd/training/schedule/service/ModelTestMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java similarity index 98% rename from src/main/java/com/kamco/cd/training/schedule/service/ModelTestMetricsJobService.java rename to src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java index 15e5011..c5936bc 100644 --- a/src/main/java/com/kamco/cd/training/schedule/service/ModelTestMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTestMetricsJobService.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.schedule.service; +package com.kamco.cd.training.train.service; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; import java.io.BufferedReader; diff --git a/src/main/java/com/kamco/cd/training/schedule/service/ModelTrainMetricsJobService.java b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java similarity index 98% rename from src/main/java/com/kamco/cd/training/schedule/service/ModelTrainMetricsJobService.java rename to src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java index ecb7c27..319a5fd 100644 --- a/src/main/java/com/kamco/cd/training/schedule/service/ModelTrainMetricsJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/ModelTrainMetricsJobService.java @@ -1,4 +1,4 @@ -package com.kamco.cd.training.schedule.service; +package com.kamco.cd.training.train.service; import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService; import java.io.BufferedReader; diff --git a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java index e90cf1c..e7b5dfa 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TestJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TestJobService.java @@ -31,7 +31,7 @@ public class TestJobService { Map params = new java.util.LinkedHashMap<>(); params.put("jobType", "EVAL"); - params.put("uuid", uuid); + params.put("uuid", String.valueOf(uuid)); params.put("epoch", epoch); int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java index a5816b5..ab53f79 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobService.java @@ -57,6 +57,8 @@ public class TrainJobService { @SuppressWarnings("unchecked") Map paramsMap = objectMapper.convertValue(trainRunRequest, Map.class); paramsMap.put("jobType", "TRAIN"); + paramsMap.put("uuid", trainRunRequest.getUuid()); + paramsMap.put("totalEpoch", trainRunRequest.getEpochs()); Long jobId = modelTrainJobCoreService.createQueuedJob( diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java index 2e4a094..afa2268 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java @@ -47,9 +47,17 @@ public class TrainJobWorker { boolean isEval = "EVAL".equals(jobType); - String containerName = (isEval ? "eval-" : "train-") + jobId; + String containerName = + (isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8); - modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER"); + Integer totalEpoch = null; + if (params.containsKey("totalEpoch")) { + if (params.get("totalEpoch") != null) { + totalEpoch = Integer.parseInt(params.get("totalEpoch").toString()); + } + } + + modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch); try { TrainRunResult result;