feat/training_260202 #45
@@ -1,6 +1,6 @@
|
|||||||
package com.kamco.cd.training.postgres.core;
|
package com.kamco.cd.training.postgres.core;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.repository.schedule.ModelTestMetricsJobRepository;
|
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|||||||
@@ -47,7 +47,8 @@ public class ModelTrainJobCoreService {
|
|||||||
|
|
||||||
/** 실행 시작 처리 */
|
/** 실행 시작 처리 */
|
||||||
@Transactional
|
@Transactional
|
||||||
public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) {
|
public void markRunning(
|
||||||
|
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
|
||||||
ModelTrainJobEntity job =
|
ModelTrainJobEntity job =
|
||||||
modelTrainJobRepository
|
modelTrainJobRepository
|
||||||
.findById(jobId)
|
.findById(jobId)
|
||||||
@@ -59,6 +60,10 @@ public class ModelTrainJobCoreService {
|
|||||||
job.setStartedDttm(ZonedDateTime.now());
|
job.setStartedDttm(ZonedDateTime.now());
|
||||||
job.setLockedDttm(ZonedDateTime.now());
|
job.setLockedDttm(ZonedDateTime.now());
|
||||||
job.setLockedBy(lockedBy);
|
job.setLockedBy(lockedBy);
|
||||||
|
|
||||||
|
if (totalEpoch != null) {
|
||||||
|
job.setTotalEpoch(totalEpoch);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 성공 처리 */
|
/** 성공 처리 */
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.kamco.cd.training.postgres.core;
|
package com.kamco.cd.training.postgres.core;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.repository.schedule.ModelTrainMetricsJobRepository;
|
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|||||||
@@ -78,6 +78,12 @@ public class ModelTrainJobEntity {
|
|||||||
@Column(name = "locked_by", length = 100)
|
@Column(name = "locked_by", length = 100)
|
||||||
private String lockedBy;
|
private String lockedBy;
|
||||||
|
|
||||||
|
@Column(name = "total_epoch")
|
||||||
|
private Integer totalEpoch;
|
||||||
|
|
||||||
|
@Column(name = "current_epoch")
|
||||||
|
private Integer currentEpoch;
|
||||||
|
|
||||||
public ModelTrainJobDto toDto() {
|
public ModelTrainJobDto toDto() {
|
||||||
return new ModelTrainJobDto(
|
return new ModelTrainJobDto(
|
||||||
this.id,
|
this.id,
|
||||||
@@ -90,6 +96,8 @@ public class ModelTrainJobEntity {
|
|||||||
this.paramsJson,
|
this.paramsJson,
|
||||||
this.queuedDttm,
|
this.queuedDttm,
|
||||||
this.startedDttm,
|
this.startedDttm,
|
||||||
this.finishedDttm);
|
this.finishedDttm,
|
||||||
|
this.totalEpoch,
|
||||||
|
this.currentEpoch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
|||||||
modelHyperParamEntity.gpuCnt,
|
modelHyperParamEntity.gpuCnt,
|
||||||
modelHyperParamEntity.learningRate,
|
modelHyperParamEntity.learningRate,
|
||||||
modelHyperParamEntity.backbone,
|
modelHyperParamEntity.backbone,
|
||||||
modelHyperParamEntity.epochCnt,
|
modelConfigEntity.epochCount,
|
||||||
modelHyperParamEntity.trainNumWorkers,
|
modelHyperParamEntity.trainNumWorkers,
|
||||||
modelHyperParamEntity.valNumWorkers,
|
modelHyperParamEntity.valNumWorkers,
|
||||||
modelHyperParamEntity.testNumWorkers,
|
modelHyperParamEntity.testNumWorkers,
|
||||||
@@ -135,7 +135,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
|||||||
modelHyperParamEntity.saturationRange,
|
modelHyperParamEntity.saturationRange,
|
||||||
modelHyperParamEntity.hueDelta,
|
modelHyperParamEntity.hueDelta,
|
||||||
Expressions.nullExpression(Integer.class),
|
Expressions.nullExpression(Integer.class),
|
||||||
Expressions.nullExpression(String.class)))
|
Expressions.nullExpression(String.class),
|
||||||
|
modelHyperParamEntity.uuid))
|
||||||
.from(modelMasterEntity)
|
.from(modelMasterEntity)
|
||||||
.leftJoin(modelHyperParamEntity)
|
.leftJoin(modelHyperParamEntity)
|
||||||
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.schedule;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
||||||
import org.springframework.data.jpa.repository.JpaRepository;
|
import org.springframework.data.jpa.repository.JpaRepository;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.schedule;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.schedule;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.schedule;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
|
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
|
||||||
import org.springframework.data.jpa.repository.JpaRepository;
|
import org.springframework.data.jpa.repository.JpaRepository;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.schedule;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.postgres.repository.schedule;
|
package com.kamco.cd.training.postgres.repository.train;
|
||||||
|
|
||||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||||
|
|
||||||
@@ -20,4 +20,6 @@ public class ModelTrainJobDto {
|
|||||||
private ZonedDateTime queuedDttm;
|
private ZonedDateTime queuedDttm;
|
||||||
private ZonedDateTime startedDttm;
|
private ZonedDateTime startedDttm;
|
||||||
private ZonedDateTime finishedDttm;
|
private ZonedDateTime finishedDttm;
|
||||||
|
private Integer totalEpoch;
|
||||||
|
private Integer currentEpoch;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,11 +82,17 @@ public class TrainRunRequest {
|
|||||||
private Integer timeoutSeconds;
|
private Integer timeoutSeconds;
|
||||||
private String resumeFrom;
|
private String resumeFrom;
|
||||||
|
|
||||||
|
private UUID uuid;
|
||||||
|
|
||||||
public String getDatasetFolder() {
|
public String getDatasetFolder() {
|
||||||
return String.valueOf(datasetFolder);
|
return String.valueOf(this.datasetFolder);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getOutputFolder() {
|
public String getOutputFolder() {
|
||||||
return String.valueOf(outputFolder);
|
return String.valueOf(this.outputFolder);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getUuid() {
|
||||||
|
return String.valueOf(this.uuid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import java.nio.charset.StandardCharsets;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
import lombok.extern.log4j.Log4j2;
|
import lombok.extern.log4j.Log4j2;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -57,23 +59,59 @@ public class DockerTrainService {
|
|||||||
Process p = pb.start();
|
Process p = pb.start();
|
||||||
|
|
||||||
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||||
StringBuilder log = new StringBuilder();
|
StringBuilder logBuilder = new StringBuilder();
|
||||||
|
|
||||||
|
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
|
||||||
|
|
||||||
Thread logThread =
|
Thread logThread =
|
||||||
new Thread(
|
new Thread(
|
||||||
() -> {
|
() -> {
|
||||||
try (BufferedReader br =
|
try (BufferedReader br =
|
||||||
new BufferedReader(
|
new BufferedReader(
|
||||||
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||||
|
|
||||||
String line;
|
String line;
|
||||||
while ((line = br.readLine()) != null) {
|
while ((line = br.readLine()) != null) {
|
||||||
synchronized (log) {
|
|
||||||
log.append(line).append('\n');
|
// 1) 로그 누적
|
||||||
|
synchronized (logBuilder) {
|
||||||
|
logBuilder.append(line).append('\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) epoch 감지 + DB 업데이트
|
||||||
|
Matcher m = epochPattern.matcher(line);
|
||||||
|
if (m.find()) {
|
||||||
|
int currentEpoch = Integer.parseInt(m.group(1));
|
||||||
|
int totalEpoch = Integer.parseInt(m.group(2));
|
||||||
|
|
||||||
|
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
|
||||||
|
|
||||||
|
// TODO 실행중인 에폭 저장 필요하면 만들어야함
|
||||||
|
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
|
||||||
|
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
|
||||||
|
// currentEpoch, totalEpoch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (Exception ignored) {
|
} catch (Exception e) {
|
||||||
|
log.warn("logThread error: {}", e.toString());
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"train-log-" + containerName);
|
"train-log-" + containerName);
|
||||||
|
// new Thread(
|
||||||
|
// () -> {
|
||||||
|
// try (BufferedReader br =
|
||||||
|
// new BufferedReader(
|
||||||
|
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||||
|
// String line;
|
||||||
|
// while ((line = br.readLine()) != null) {
|
||||||
|
// synchronized (log) {
|
||||||
|
// log.append(line).append('\n');
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// } catch (Exception ignored) {
|
||||||
|
// }
|
||||||
|
// },
|
||||||
|
// "train-log-" + containerName);
|
||||||
|
|
||||||
logThread.setDaemon(true);
|
logThread.setDaemon(true);
|
||||||
logThread.start();
|
logThread.start();
|
||||||
@@ -90,8 +128,8 @@ public class DockerTrainService {
|
|||||||
killContainer(containerName);
|
killContainer(containerName);
|
||||||
|
|
||||||
String logs;
|
String logs;
|
||||||
synchronized (log) {
|
synchronized (logBuilder) {
|
||||||
logs = log.toString();
|
logs = logBuilder.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
return new TrainRunResult(
|
return new TrainRunResult(
|
||||||
@@ -108,8 +146,8 @@ public class DockerTrainService {
|
|||||||
logThread.join(500);
|
logThread.join(500);
|
||||||
|
|
||||||
String logs;
|
String logs;
|
||||||
synchronized (log) {
|
synchronized (logBuilder) {
|
||||||
logs = log.toString();
|
logs = logBuilder.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
|
||||||
@@ -131,7 +169,7 @@ public class DockerTrainService {
|
|||||||
|
|
||||||
// 컨테이너 이름 지정
|
// 컨테이너 이름 지정
|
||||||
c.add("--name");
|
c.add("--name");
|
||||||
c.add(containerName + "-" + req.getOutputFolder().substring(0, 8));
|
c.add(containerName + "-" + req.getUuid().substring(0, 8));
|
||||||
|
|
||||||
// 실행 종료 시 자동 삭제
|
// 실행 종료 시 자동 삭제
|
||||||
c.add("--rm");
|
c.add("--rm");
|
||||||
@@ -183,7 +221,7 @@ public class DockerTrainService {
|
|||||||
c.add("/workspace/change-detection-code/train_wrapper.py");
|
c.add("/workspace/change-detection-code/train_wrapper.py");
|
||||||
|
|
||||||
// ===== 기본 파라미터 =====
|
// ===== 기본 파라미터 =====
|
||||||
addArg(c, "--dataset-folder", req.getDatasetFolder());
|
addArg(c, "--dataset-folder", "4BDBBDF99D04477A927CC9EBA760B845" /*req.getDatasetFolder()*/);
|
||||||
addArg(c, "--output-folder", req.getOutputFolder());
|
addArg(c, "--output-folder", req.getOutputFolder());
|
||||||
addArg(c, "--input-size", req.getInputSize());
|
addArg(c, "--input-size", req.getInputSize());
|
||||||
addArg(c, "--crop-size", req.getCropSize());
|
addArg(c, "--crop-size", req.getCropSize());
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.schedule.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.kamco.cd.training.schedule.service;
|
package com.kamco.cd.training.train.service;
|
||||||
|
|
||||||
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
@@ -31,7 +31,7 @@ public class TestJobService {
|
|||||||
|
|
||||||
Map<String, Object> params = new java.util.LinkedHashMap<>();
|
Map<String, Object> params = new java.util.LinkedHashMap<>();
|
||||||
params.put("jobType", "EVAL");
|
params.put("jobType", "EVAL");
|
||||||
params.put("uuid", uuid);
|
params.put("uuid", String.valueOf(uuid));
|
||||||
params.put("epoch", epoch);
|
params.put("epoch", epoch);
|
||||||
|
|
||||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ public class TrainJobService {
|
|||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
|
||||||
paramsMap.put("jobType", "TRAIN");
|
paramsMap.put("jobType", "TRAIN");
|
||||||
|
paramsMap.put("uuid", trainRunRequest.getUuid());
|
||||||
|
paramsMap.put("totalEpoch", trainRunRequest.getEpochs());
|
||||||
|
|
||||||
Long jobId =
|
Long jobId =
|
||||||
modelTrainJobCoreService.createQueuedJob(
|
modelTrainJobCoreService.createQueuedJob(
|
||||||
|
|||||||
@@ -47,9 +47,17 @@ public class TrainJobWorker {
|
|||||||
|
|
||||||
boolean isEval = "EVAL".equals(jobType);
|
boolean isEval = "EVAL".equals(jobType);
|
||||||
|
|
||||||
String containerName = (isEval ? "eval-" : "train-") + jobId;
|
String containerName =
|
||||||
|
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
|
||||||
|
|
||||||
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER");
|
Integer totalEpoch = null;
|
||||||
|
if (params.containsKey("totalEpoch")) {
|
||||||
|
if (params.get("totalEpoch") != null) {
|
||||||
|
totalEpoch = Integer.parseInt(params.get("totalEpoch").toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
TrainRunResult result;
|
TrainRunResult result;
|
||||||
|
|||||||
Reference in New Issue
Block a user