Merge pull request 'feat/training_260202' (#45) from feat/training_260202 into develop
Reviewed-on: #45
This commit was merged in pull request #45.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.repository.schedule.ModelTestMetricsJobRepository;
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -47,7 +47,8 @@ public class ModelTrainJobCoreService {
|
||||
|
||||
/** 실행 시작 처리 */
|
||||
@Transactional
|
||||
public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) {
|
||||
public void markRunning(
|
||||
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findById(jobId)
|
||||
@@ -59,6 +60,10 @@ public class ModelTrainJobCoreService {
|
||||
job.setStartedDttm(ZonedDateTime.now());
|
||||
job.setLockedDttm(ZonedDateTime.now());
|
||||
job.setLockedBy(lockedBy);
|
||||
|
||||
if (totalEpoch != null) {
|
||||
job.setTotalEpoch(totalEpoch);
|
||||
}
|
||||
}
|
||||
|
||||
/** 성공 처리 */
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.repository.schedule.ModelTrainMetricsJobRepository;
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -78,6 +78,12 @@ public class ModelTrainJobEntity {
|
||||
@Column(name = "locked_by", length = 100)
|
||||
private String lockedBy;
|
||||
|
||||
@Column(name = "total_epoch")
|
||||
private Integer totalEpoch;
|
||||
|
||||
@Column(name = "current_epoch")
|
||||
private Integer currentEpoch;
|
||||
|
||||
public ModelTrainJobDto toDto() {
|
||||
return new ModelTrainJobDto(
|
||||
this.id,
|
||||
@@ -90,6 +96,8 @@ public class ModelTrainJobEntity {
|
||||
this.paramsJson,
|
||||
this.queuedDttm,
|
||||
this.startedDttm,
|
||||
this.finishedDttm);
|
||||
this.finishedDttm,
|
||||
this.totalEpoch,
|
||||
this.currentEpoch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
modelHyperParamEntity.gpuCnt,
|
||||
modelHyperParamEntity.learningRate,
|
||||
modelHyperParamEntity.backbone,
|
||||
modelHyperParamEntity.epochCnt,
|
||||
modelConfigEntity.epochCount,
|
||||
modelHyperParamEntity.trainNumWorkers,
|
||||
modelHyperParamEntity.valNumWorkers,
|
||||
modelHyperParamEntity.testNumWorkers,
|
||||
@@ -135,7 +135,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
modelHyperParamEntity.saturationRange,
|
||||
modelHyperParamEntity.hueDelta,
|
||||
Expressions.nullExpression(Integer.class),
|
||||
Expressions.nullExpression(String.class)))
|
||||
Expressions.nullExpression(String.class),
|
||||
modelHyperParamEntity.uuid))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelHyperParamEntity)
|
||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.postgres.repository.schedule;
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.postgres.repository.schedule;
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.postgres.repository.schedule;
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.postgres.repository.schedule;
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.postgres.repository.schedule;
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.postgres.repository.schedule;
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
@@ -20,4 +20,6 @@ public class ModelTrainJobDto {
|
||||
private ZonedDateTime queuedDttm;
|
||||
private ZonedDateTime startedDttm;
|
||||
private ZonedDateTime finishedDttm;
|
||||
private Integer totalEpoch;
|
||||
private Integer currentEpoch;
|
||||
}
|
||||
|
||||
@@ -82,11 +82,17 @@ public class TrainRunRequest {
|
||||
private Integer timeoutSeconds;
|
||||
private String resumeFrom;
|
||||
|
||||
private UUID uuid;
|
||||
|
||||
public String getDatasetFolder() {
|
||||
return String.valueOf(datasetFolder);
|
||||
return String.valueOf(this.datasetFolder);
|
||||
}
|
||||
|
||||
public String getOutputFolder() {
|
||||
return String.valueOf(outputFolder);
|
||||
return String.valueOf(this.outputFolder);
|
||||
}
|
||||
|
||||
public String getUuid() {
|
||||
return String.valueOf(this.uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -57,23 +59,59 @@ public class DockerTrainService {
|
||||
Process p = pb.start();
|
||||
|
||||
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||
StringBuilder log = new StringBuilder();
|
||||
StringBuilder logBuilder = new StringBuilder();
|
||||
|
||||
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
|
||||
|
||||
Thread logThread =
|
||||
new Thread(
|
||||
() -> {
|
||||
try (BufferedReader br =
|
||||
new BufferedReader(
|
||||
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
synchronized (log) {
|
||||
log.append(line).append('\n');
|
||||
|
||||
// 1) 로그 누적
|
||||
synchronized (logBuilder) {
|
||||
logBuilder.append(line).append('\n');
|
||||
}
|
||||
|
||||
// 2) epoch 감지 + DB 업데이트
|
||||
Matcher m = epochPattern.matcher(line);
|
||||
if (m.find()) {
|
||||
int currentEpoch = Integer.parseInt(m.group(1));
|
||||
int totalEpoch = Integer.parseInt(m.group(2));
|
||||
|
||||
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
|
||||
|
||||
// TODO 실행중인 에폭 저장 필요하면 만들어야함
|
||||
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
|
||||
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
|
||||
// currentEpoch, totalEpoch);
|
||||
}
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
} catch (Exception e) {
|
||||
log.warn("logThread error: {}", e.toString());
|
||||
}
|
||||
},
|
||||
"train-log-" + containerName);
|
||||
// new Thread(
|
||||
// () -> {
|
||||
// try (BufferedReader br =
|
||||
// new BufferedReader(
|
||||
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
// String line;
|
||||
// while ((line = br.readLine()) != null) {
|
||||
// synchronized (log) {
|
||||
// log.append(line).append('\n');
|
||||
// }
|
||||
// }
|
||||
// } catch (Exception ignored) {
|
||||
// }
|
||||
// },
|
||||
// "train-log-" + containerName);
|
||||
|
||||
logThread.setDaemon(true);
|
||||
logThread.start();
|
||||
@@ -90,8 +128,8 @@ public class DockerTrainService {
|
||||
killContainer(containerName);
|
||||
|
||||
String logs;
|
||||
synchronized (log) {
|
||||
logs = log.toString();
|
||||
synchronized (logBuilder) {
|
||||
logs = logBuilder.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(
|
||||
@@ -108,8 +146,8 @@ public class DockerTrainService {
|
||||
logThread.join(500);
|
||||
|
||||
String logs;
|
||||
synchronized (log) {
|
||||
logs = log.toString();
|
||||
synchronized (logBuilder) {
|
||||
logs = logBuilder.toString();
|
||||
}
|
||||
|
||||
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
||||
@@ -131,7 +169,7 @@ public class DockerTrainService {
|
||||
|
||||
// 컨테이너 이름 지정
|
||||
c.add("--name");
|
||||
c.add(containerName + "-" + req.getOutputFolder().substring(0, 8));
|
||||
c.add(containerName + "-" + req.getUuid().substring(0, 8));
|
||||
|
||||
// 실행 종료 시 자동 삭제
|
||||
c.add("--rm");
|
||||
@@ -183,7 +221,7 @@ public class DockerTrainService {
|
||||
c.add("/workspace/change-detection-code/train_wrapper.py");
|
||||
|
||||
// ===== 기본 파라미터 =====
|
||||
addArg(c, "--dataset-folder", req.getDatasetFolder());
|
||||
addArg(c, "--dataset-folder", "4BDBBDF99D04477A927CC9EBA760B845" /*req.getDatasetFolder()*/);
|
||||
addArg(c, "--output-folder", req.getOutputFolder());
|
||||
addArg(c, "--input-size", req.getInputSize());
|
||||
addArg(c, "--crop-size", req.getCropSize());
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.schedule.service;
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
||||
import java.io.BufferedReader;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.kamco.cd.training.schedule.service;
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
||||
import java.io.BufferedReader;
|
||||
@@ -31,7 +31,7 @@ public class TestJobService {
|
||||
|
||||
Map<String, Object> params = new java.util.LinkedHashMap<>();
|
||||
params.put("jobType", "EVAL");
|
||||
params.put("uuid", uuid);
|
||||
params.put("uuid", String.valueOf(uuid));
|
||||
params.put("epoch", epoch);
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
@@ -57,6 +57,8 @@ public class TrainJobService {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
||||
paramsMap.put("jobType", "TRAIN");
|
||||
paramsMap.put("uuid", trainRunRequest.getUuid());
|
||||
paramsMap.put("totalEpoch", trainRunRequest.getEpochs());
|
||||
|
||||
Long jobId =
|
||||
modelTrainJobCoreService.createQueuedJob(
|
||||
|
||||
@@ -47,9 +47,17 @@ public class TrainJobWorker {
|
||||
|
||||
boolean isEval = "EVAL".equals(jobType);
|
||||
|
||||
String containerName = (isEval ? "eval-" : "train-") + jobId;
|
||||
String containerName =
|
||||
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
||||
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER");
|
||||
Integer totalEpoch = null;
|
||||
if (params.containsKey("totalEpoch")) {
|
||||
if (params.get("totalEpoch") != null) {
|
||||
totalEpoch = Integer.parseInt(params.get("totalEpoch").toString());
|
||||
}
|
||||
}
|
||||
|
||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
|
||||
|
||||
try {
|
||||
TrainRunResult result;
|
||||
|
||||
Reference in New Issue
Block a user