Merge pull request 'feat/training_260202' (#45) from feat/training_260202 into develop

Reviewed-on: #45
This commit was merged in pull request #45.
This commit is contained in: develop (merged 2026-02-12 12:06:04 +09:00)
19 changed files with 99 additions and 29 deletions

View File

@@ -1,6 +1,6 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.schedule.ModelTestMetricsJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;

View File

@@ -47,7 +47,8 @@ public class ModelTrainJobCoreService {
/** 실행 시작 처리 */ /** 실행 시작 처리 */
@Transactional @Transactional
public void markRunning(Long jobId, String containerName, String logPath, String lockedBy) { public void markRunning(
Long jobId, String containerName, String logPath, String lockedBy, Integer totalEpoch) {
ModelTrainJobEntity job = ModelTrainJobEntity job =
modelTrainJobRepository modelTrainJobRepository
.findById(jobId) .findById(jobId)
@@ -59,6 +60,10 @@ public class ModelTrainJobCoreService {
job.setStartedDttm(ZonedDateTime.now()); job.setStartedDttm(ZonedDateTime.now());
job.setLockedDttm(ZonedDateTime.now()); job.setLockedDttm(ZonedDateTime.now());
job.setLockedBy(lockedBy); job.setLockedBy(lockedBy);
if (totalEpoch != null) {
job.setTotalEpoch(totalEpoch);
}
} }
/** 성공 처리 */ /** 성공 처리 */

View File

@@ -1,6 +1,6 @@
package com.kamco.cd.training.postgres.core; package com.kamco.cd.training.postgres.core;
import com.kamco.cd.training.postgres.repository.schedule.ModelTrainMetricsJobRepository; import com.kamco.cd.training.postgres.repository.train.ModelTrainMetricsJobRepository;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;

View File

@@ -78,6 +78,12 @@ public class ModelTrainJobEntity {
@Column(name = "locked_by", length = 100) @Column(name = "locked_by", length = 100)
private String lockedBy; private String lockedBy;
@Column(name = "total_epoch")
private Integer totalEpoch;
@Column(name = "current_epoch")
private Integer currentEpoch;
public ModelTrainJobDto toDto() { public ModelTrainJobDto toDto() {
return new ModelTrainJobDto( return new ModelTrainJobDto(
this.id, this.id,
@@ -90,6 +96,8 @@ public class ModelTrainJobEntity {
this.paramsJson, this.paramsJson,
this.queuedDttm, this.queuedDttm,
this.startedDttm, this.startedDttm,
this.finishedDttm); this.finishedDttm,
this.totalEpoch,
this.currentEpoch);
} }
} }

View File

@@ -103,7 +103,7 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelHyperParamEntity.gpuCnt, modelHyperParamEntity.gpuCnt,
modelHyperParamEntity.learningRate, modelHyperParamEntity.learningRate,
modelHyperParamEntity.backbone, modelHyperParamEntity.backbone,
modelHyperParamEntity.epochCnt, modelConfigEntity.epochCount,
modelHyperParamEntity.trainNumWorkers, modelHyperParamEntity.trainNumWorkers,
modelHyperParamEntity.valNumWorkers, modelHyperParamEntity.valNumWorkers,
modelHyperParamEntity.testNumWorkers, modelHyperParamEntity.testNumWorkers,
@@ -135,7 +135,8 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
modelHyperParamEntity.saturationRange, modelHyperParamEntity.saturationRange,
modelHyperParamEntity.hueDelta, modelHyperParamEntity.hueDelta,
Expressions.nullExpression(Integer.class), Expressions.nullExpression(Integer.class),
Expressions.nullExpression(String.class))) Expressions.nullExpression(String.class),
modelHyperParamEntity.uuid))
.from(modelMasterEntity) .from(modelMasterEntity)
.leftJoin(modelHyperParamEntity) .leftJoin(modelHyperParamEntity)
.on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId)) .on(modelHyperParamEntity.id.eq(modelMasterEntity.hyperParamId))

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity; import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.JpaRepository;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule; package com.kamco.cd.training.postgres.repository.train;
import java.util.List; import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule; package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule; package com.kamco.cd.training.postgres.repository.train;
import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity; import com.kamco.cd.training.postgres.entity.ModelMetricsTrainEntity;
import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.JpaRepository;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule; package com.kamco.cd.training.postgres.repository.train;
import java.util.List; import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.postgres.repository.schedule; package com.kamco.cd.training.postgres.repository.train;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;

View File

@@ -20,4 +20,6 @@ public class ModelTrainJobDto {
private ZonedDateTime queuedDttm; private ZonedDateTime queuedDttm;
private ZonedDateTime startedDttm; private ZonedDateTime startedDttm;
private ZonedDateTime finishedDttm; private ZonedDateTime finishedDttm;
private Integer totalEpoch;
private Integer currentEpoch;
} }

View File

@@ -82,11 +82,17 @@ public class TrainRunRequest {
private Integer timeoutSeconds; private Integer timeoutSeconds;
private String resumeFrom; private String resumeFrom;
private UUID uuid;
public String getDatasetFolder() { public String getDatasetFolder() {
return String.valueOf(datasetFolder); return String.valueOf(this.datasetFolder);
} }
public String getOutputFolder() { public String getOutputFolder() {
return String.valueOf(outputFolder); return String.valueOf(this.outputFolder);
}
public String getUuid() {
return String.valueOf(this.uuid);
} }
} }

View File

@@ -9,6 +9,8 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.log4j.Log4j2; import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -57,23 +59,59 @@ public class DockerTrainService {
Process p = pb.start(); Process p = pb.start();
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게) // 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
StringBuilder log = new StringBuilder(); StringBuilder logBuilder = new StringBuilder();
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
Thread logThread = Thread logThread =
new Thread( new Thread(
() -> { () -> {
try (BufferedReader br = try (BufferedReader br =
new BufferedReader( new BufferedReader(
new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
String line; String line;
while ((line = br.readLine()) != null) { while ((line = br.readLine()) != null) {
synchronized (log) {
log.append(line).append('\n'); // 1) 로그 누적
synchronized (logBuilder) {
logBuilder.append(line).append('\n');
}
// 2) epoch 감지 + DB 업데이트
Matcher m = epochPattern.matcher(line);
if (m.find()) {
int currentEpoch = Integer.parseInt(m.group(1));
int totalEpoch = Integer.parseInt(m.group(2));
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
// TODO 실행중인 에폭 저장 필요하면 만들어야함
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
// currentEpoch, totalEpoch);
} }
} }
} catch (Exception ignored) { } catch (Exception e) {
log.warn("logThread error: {}", e.toString());
} }
}, },
"train-log-" + containerName); "train-log-" + containerName);
// new Thread(
// () -> {
// try (BufferedReader br =
// new BufferedReader(
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
// String line;
// while ((line = br.readLine()) != null) {
// synchronized (log) {
// log.append(line).append('\n');
// }
// }
// } catch (Exception ignored) {
// }
// },
// "train-log-" + containerName);
logThread.setDaemon(true); logThread.setDaemon(true);
logThread.start(); logThread.start();
@@ -90,8 +128,8 @@ public class DockerTrainService {
killContainer(containerName); killContainer(containerName);
String logs; String logs;
synchronized (log) { synchronized (logBuilder) {
logs = log.toString(); logs = logBuilder.toString();
} }
return new TrainRunResult( return new TrainRunResult(
@@ -108,8 +146,8 @@ public class DockerTrainService {
logThread.join(500); logThread.join(500);
String logs; String logs;
synchronized (log) { synchronized (logBuilder) {
logs = log.toString(); logs = logBuilder.toString();
} }
return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs); return new TrainRunResult(null, containerName, exit, exit == 0 ? "SUCCESS" : "FAILED", logs);
@@ -131,7 +169,7 @@ public class DockerTrainService {
// 컨테이너 이름 지정 // 컨테이너 이름 지정
c.add("--name"); c.add("--name");
c.add(containerName + "-" + req.getOutputFolder().substring(0, 8)); c.add(containerName + "-" + req.getUuid().substring(0, 8));
// 실행 종료 시 자동 삭제 // 실행 종료 시 자동 삭제
c.add("--rm"); c.add("--rm");
@@ -183,7 +221,7 @@ public class DockerTrainService {
c.add("/workspace/change-detection-code/train_wrapper.py"); c.add("/workspace/change-detection-code/train_wrapper.py");
// ===== 기본 파라미터 ===== // ===== 기본 파라미터 =====
addArg(c, "--dataset-folder", req.getDatasetFolder()); addArg(c, "--dataset-folder", "4BDBBDF99D04477A927CC9EBA760B845" /*req.getDatasetFolder()*/);
addArg(c, "--output-folder", req.getOutputFolder()); addArg(c, "--output-folder", req.getOutputFolder());
addArg(c, "--input-size", req.getInputSize()); addArg(c, "--input-size", req.getInputSize());
addArg(c, "--crop-size", req.getCropSize()); addArg(c, "--crop-size", req.getCropSize());

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.schedule.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
import java.io.BufferedReader; import java.io.BufferedReader;

View File

@@ -1,4 +1,4 @@
package com.kamco.cd.training.schedule.service; package com.kamco.cd.training.train.service;
import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService; import com.kamco.cd.training.postgres.core.ModelTrainMetricsJobCoreService;
import java.io.BufferedReader; import java.io.BufferedReader;

View File

@@ -31,7 +31,7 @@ public class TestJobService {
Map<String, Object> params = new java.util.LinkedHashMap<>(); Map<String, Object> params = new java.util.LinkedHashMap<>();
params.put("jobType", "EVAL"); params.put("jobType", "EVAL");
params.put("uuid", uuid); params.put("uuid", String.valueOf(uuid));
params.put("epoch", epoch); params.put("epoch", epoch);
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1; int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;

View File

@@ -57,6 +57,8 @@ public class TrainJobService {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class); Map<String, Object> paramsMap = objectMapper.convertValue(trainRunRequest, Map.class);
paramsMap.put("jobType", "TRAIN"); paramsMap.put("jobType", "TRAIN");
paramsMap.put("uuid", trainRunRequest.getUuid());
paramsMap.put("totalEpoch", trainRunRequest.getEpochs());
Long jobId = Long jobId =
modelTrainJobCoreService.createQueuedJob( modelTrainJobCoreService.createQueuedJob(

View File

@@ -47,9 +47,17 @@ public class TrainJobWorker {
boolean isEval = "EVAL".equals(jobType); boolean isEval = "EVAL".equals(jobType);
String containerName = (isEval ? "eval-" : "train-") + jobId; String containerName =
(isEval ? "eval-" : "train-") + jobId + "-" + params.get("uuid").toString().substring(0, 8);
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER"); Integer totalEpoch = null;
if (params.containsKey("totalEpoch")) {
if (params.get("totalEpoch") != null) {
totalEpoch = Integer.parseInt(params.get("totalEpoch").toString());
}
}
modelTrainJobCoreService.markRunning(jobId, containerName, null, "TRAIN_WORKER", totalEpoch);
try { try {
TrainRunResult result; TrainRunResult result;