Merge pull request 'feat/training_260202' (#45) from feat/training_260202 into develop

Reviewed-on: #45
This commit was merged in pull request #45.
This commit is contained in:
2026-02-12 12:06:04 +09:00
19 changed files with 99 additions and 29 deletions

View File

@@ -1,6 +1,6 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.schedule.ModelTestMetricsJobRepository;
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;

View File

@@ -47,7 +47,8 @@ public class ModelTrainJobCoreService {
/** 실행 시작 처리 */
@Transactional
public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) {
public void markRunning(
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
ModelTrainJobEntity job =
modelTrainJobRepository
.findById(jobId)
@@ -59,6 +60,10 @@ public class ModelTrainJobCoreService {
job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy);
if (totalEpoch != null) {
job.setTotalEpoch(totalEpoch);
}
}
/** 성공 처리 */

View File

@@ -1,6 +1,6 @@
package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.schedule.ModelTrainMetricsJobRepository;
import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;

View File

@@ -78,6 +78,12 @@ public class ModelTrainJobEntity {
@Column(name = "locked_by", length = 100)
private String lockedBy;
@Column(name = "total_epoch")
private Integer totalEpoch;
@Column(name = "current_epoch")
private Integer currentEpoch;
public ModelTrainJobDto toDto() {
return new ModelTrainJobDto(
this.id,
@@ -90,6 +96,8 @@ public class ModelTrainJobEntity {
this.paramsJson,
this.queuedDttm,
this.startedDttm,
this.finishedDttm);
this.finishedDttm,
this.totalEpoch,
this.currentEpoch);
}
}

View File

@@ -103,7 +103,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelHyperParamEntity.gpuCnt,
modelHyperParamEntity.learningRate,
modelHyperParamEntity.backbone,
modelHyperParamEntity.epochCnt,
modelConfigEntity.epochCount,
modelHyperParamEntity.trainNumWorkers,
modelHyperParamEntity.valNumWorkers,
modelHyperParamEntity.testNumWorkers,
@@ -135,7 +135,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelHyperParamEntity.saturationRange,
modelHyperParamEntity.hueDelta,
Expressions.nullExpression(Integer.class),
Expressions.nullExpression(String.class)))
Expressions.nullExpression(String.class),
modelHyperParamEntity.uuid))
.from(modelMasterEntity)
.leftJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule;
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import org.springframework.data.jpa.repository.JpaRepository;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule;
package com.kamco.cd.training.postgres.repository.train;
import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule;
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule;
package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
import org.springframework.data.jpa.repository.JpaRepository;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule;
package com.kamco.cd.training.postgres.repository.train;
import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule;
package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;

View File

@@ -20,4 +20,6 @@ public class ModelTrainJobDto {
private ZonedDateTime queuedDttm;
private ZonedDateTime startedDttm;
private ZonedDateTime finishedDttm;
private Integer totalEpoch;
private Integer currentEpoch;
}

View File

@@ -82,11 +82,17 @@ public class TrainRunRequest {
private Integer timeoutSeconds;
private String resumeFrom;
private UUID uuid;
public String getDatasetFolder() {
return String.valueOf(datasetFolder);
return String.valueOf(this.datasetFolder);
}
public String getOutputFolder() {
return String.valueOf(outputFolder);
return String.valueOf(this.outputFolder);
}
public String getUuid() {
return String.valueOf(this.uuid);
}
}

View File

@@ -9,6 +9,8 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@@ -57,23 +59,59 @@ public class DockerTrainService {
Process p = pb.start();
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
StringBuilder log = new StringBuilder();
StringBuilder logBuilder = new StringBuilder();
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
Thread logThread =
new Thread(
() -> {
try (BufferedReader br =
new BufferedReader(
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
synchronized (log) {
log.append(line).append('\n');
// 1) 로그 누적
synchronized (logBuilder) {
logBuilder.append(line).append('\n');
}
// 2) epoch 감지 + DB 업데이트
Matcher m = epochPattern.matcher(line);
if (m.find()) {
int currentEpoch = Integer.parseInt(m.group(1));
int totalEpoch = Integer.parseInt(m.group(2));
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
// TODO 실행중인 에폭 저장 필요하면 만들어야함
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
// currentEpoch, totalEpoch);
}
}
} catch (Exception ignored) {
} catch (Exception e) {
log.warn("logThread error: {}", e.toString());
}
},
"train-log-" + containerName);
// new Thread(
// () -> {
// try (BufferedReader br =
// new BufferedReader(
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
// String line;
// while ((line = br.readLine()) != null) {
// synchronized (log) {
// log.append(line).append('\n');
// }
// }
// } catch (Exception ignored) {
// }
// },
// "train-log-" + containerName);
logThread.setDaemon(true);
logThread.start();
@@ -90,8 +128,8 @@ public class DockerTrainService {
killContainer(containerName);
String logs;
synchronized (log) {
logs = log.toString();
synchronized (logBuilder) {
logs = logBuilder.toString();
}
return new TrainRunResult(
@@ -108,8 +146,8 @@ public class DockerTrainService {
logThread.join(500);
String logs;
synchronized (log) {
logs = log.toString();
synchronized (logBuilder) {
logs = logBuilder.toString();
}
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
@@ -131,7 +169,7 @@ public class DockerTrainService {
// 컨테이너 이름 지정
c.add("--name");
c.add(containerName + "-" + req.getOutputFolder().substring(0, 8));
c.add(containerName + "-" + req.getUuid().substring(0, 8));
// 실행 종료 시 자동 삭제
c.add("--rm");
@@ -183,7 +221,7 @@ public class DockerTrainService {
c.add("/workspace/change-detection-code/train_wrapper.py");
// ===== 기본 파라미터 =====
addArg(c, "--dataset-folder", req.getDatasetFolder());
addArg(c, "--dataset-folder", "4BDBBDF99D04477A927CC9EBA760B845" /*req.getDatasetFolder()*/);
addArg(c, "--output-folder", req.getOutputFolder());
addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize());

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.schedule.service;
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import java.io.BufferedReader;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.schedule.service;
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
import java.io.BufferedReader;

View File

@@ -31,7 +31,7 @@ public class TestJobService {
Map<String, Object> params = new java.util.LinkedHashMap<>();
params.put("jobType", "EVAL");
params.put("uuid", uuid);
params.put("uuid", String.valueOf(uuid));
params.put("epoch", epoch);
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;

View File

@@ -57,6 +57,8 @@ public class TrainJobService {
@SuppressWarnings("unchecked")
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
paramsMap.put("jobType", "TRAIN");
paramsMap.put("uuid", trainRunRequest.getUuid());
paramsMap.put("totalEpoch", trainRunRequest.getEpochs());
Long jobId =
modelTrainJobCoreService.createQueuedJob(

View File

@@ -47,9 +47,17 @@ public class TrainJobWorker {
boolean isEval = "EVAL".equals(jobType);
String containerName = (isEval ? "eval-" : "train-") + jobId;
String containerName =
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER");
Integer totalEpoch = null;
if (params.containsKey("totalEpoch")) {
if (params.get("totalEpoch") != null) {
totalEpoch = Integer.parseInt(params.get("totalEpoch").toString());
}
}
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
try {
TrainRunResult result;