feat/training_260202 #125

Merged
teddy merged 2 commits from feat/training_260202 into develop 2026-02-20 12:23:37 +09:00
7 changed files with 304 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.model.service;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
import com.kamco.cd.training.model.dto.ModelConfigDto;
@@ -89,7 +90,7 @@ public class ModelTrainDetailService {
datasetReq.setIds(datasetIds);
datasetReq.setModelNo(modelInfo.getModelNo());
if (modelInfo.getModelNo().equals("G1")) {
if (modelInfo.getModelNo().equals(ModelType.G1.getId())) {
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
} else {
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq);

View File

@@ -23,6 +23,7 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.time.ZonedDateTime;
import java.util.ArrayList;
@@ -519,4 +520,34 @@ public class ModelTrainMngCoreService {
/**
 * Delegates to {@code modelMngRepository.findModelStep1InProgressCnt()} and returns its result
 * unchanged.
 *
 * @return the count reported by the repository
 */
public Long findModelStep1InProgressCnt() {
  Long inProgressCount = modelMngRepository.findModelStep1InProgressCnt();
  return inProgressCount;
}
/**
 * Looks up the file paths to link for the train split of a model.
 *
 * @param modelId id of the model whose train files are requested
 * @return the rows returned by {@code modelDatasetMapRepository.findDatasetTrainPath(modelId)}
 */
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
  List<ModelTrainLinkDto> trainLinks = modelDatasetMapRepository.findDatasetTrainPath(modelId);
  return trainLinks;
}
/**
 * Looks up the file paths to link for the validation split of a model.
 *
 * @param modelId id of the model whose validation files are requested
 * @return the rows returned by {@code modelDatasetMapRepository.findDatasetValPath(modelId)}
 */
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
  List<ModelTrainLinkDto> validationLinks = modelDatasetMapRepository.findDatasetValPath(modelId);
  return validationLinks;
}
/**
 * Looks up the file paths to link for the test split of a model.
 *
 * @param modelId id of the model whose test files are requested
 * @return the rows returned by {@code modelDatasetMapRepository.findDatasetTestPath(modelId)}
 */
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
  List<ModelTrainLinkDto> testLinks = modelDatasetMapRepository.findDatasetTestPath(modelId);
  return testLinks;
}
}

View File

@@ -1,8 +1,15 @@
package com.kamco.cd.training.postgres.repository.model;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import java.util.List;
/** Custom (QueryDSL-backed) queries over the model-to-dataset mapping. */
public interface ModelDatasetMappRepositoryCustom {
/** Returns the mapping rows whose {@code modelUid} equals {@code modelId}. */
List<ModelDatasetMappEntity> findByModelUid(Long modelId);
/** Returns the train-split file paths (compare/target/label) to link for the given model. */
List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId);
/** Returns the validation-split file paths (compare/target/label) to link for the given model. */
List<ModelTrainLinkDto> findDatasetValPath(Long modelId);
/** Returns the test-split file paths (compare/target/label) to link for the given model. */
List<ModelTrainLinkDto> findDatasetTestPath(Long modelId);
}

View File

@@ -1,8 +1,16 @@
package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QDatasetEntity.datasetEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetTestObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetValObjEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.querydsl.core.types.Projections;
import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List;
import lombok.RequiredArgsConstructor;
@@ -22,4 +30,133 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
.where(modelDatasetMappEntity.modelUid.eq(modelId))
.fetch();
}
/**
 * Selects the train-split object rows (compare/target/label paths) for one model.
 *
 * <p>The model is left-joined to its dataset mappings, the datasets, and the train object table,
 * so a model without matching objects still yields rows whose object columns are {@code null}.
 *
 * @param modelId id of the model master row to query
 * @return one {@link ModelTrainLinkDto} per joined row
 */
@Override
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
  QDatasetObjEntity trainObj = QDatasetObjEntity.datasetObjEntity;

  // Class filter by model type: G1 -> CONTAINER/BUILDING, G2 -> WASTE, G3 -> no class restriction.
  var g1Classes =
      modelMasterEntity
          .modelNo
          .eq(ModelType.G1.getId())
          .and(trainObj.targetClassCd.upper().in("CONTAINER", "BUILDING"));
  var g2Classes =
      modelMasterEntity
          .modelNo
          .eq(ModelType.G2.getId())
          .and(trainObj.targetClassCd.upper().eq("WASTE"));
  var g3AnyClass = modelMasterEntity.modelNo.eq(ModelType.G3.getId());

  // Object rows must belong to a mapped dataset AND pass the per-model class filter.
  var objectJoinCondition =
      trainObj
          .datasetUid
          .eq(modelDatasetMappEntity.datasetUid)
          .and(g1Classes.or(g2Classes).or(g3AnyClass));

  var linkProjection =
      Projections.constructor(
          ModelTrainLinkDto.class,
          modelMasterEntity.id,
          modelMasterEntity.trainType,
          modelMasterEntity.modelNo,
          modelDatasetMappEntity.datasetUid,
          trainObj.targetClassCd,
          trainObj.comparePath,
          trainObj.targetPath,
          trainObj.labelPath,
          datasetEntity.uid);

  return queryFactory
      .select(linkProjection)
      .from(modelMasterEntity)
      .leftJoin(modelDatasetMappEntity)
      .on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
      .leftJoin(datasetEntity)
      .on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
      .leftJoin(trainObj)
      .on(objectJoinCondition)
      .where(modelMasterEntity.id.eq(modelId))
      .fetch();
}
/**
 * Selects the validation-split object rows (compare/target/label paths) for one model.
 *
 * <p>The model is left-joined to its dataset mappings, the datasets, and the validation object
 * table, so a model without matching objects still yields rows whose object columns are
 * {@code null}.
 *
 * @param modelId id of the model master row to query
 * @return one {@link ModelTrainLinkDto} per joined row
 */
@Override
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
  QDatasetValObjEntity valObj = QDatasetValObjEntity.datasetValObjEntity;

  // Class filter by model type: G1 -> CONTAINER/BUILDING, G2 -> WASTE, G3 -> no class restriction.
  var g1Classes =
      modelMasterEntity
          .modelNo
          .eq(ModelType.G1.getId())
          .and(valObj.targetClassCd.upper().in("CONTAINER", "BUILDING"));
  var g2Classes =
      modelMasterEntity
          .modelNo
          .eq(ModelType.G2.getId())
          .and(valObj.targetClassCd.upper().eq("WASTE"));
  var g3AnyClass = modelMasterEntity.modelNo.eq(ModelType.G3.getId());

  // Object rows must belong to a mapped dataset AND pass the per-model class filter.
  var objectJoinCondition =
      valObj
          .datasetUid
          .eq(modelDatasetMappEntity.datasetUid)
          .and(g1Classes.or(g2Classes).or(g3AnyClass));

  var linkProjection =
      Projections.constructor(
          ModelTrainLinkDto.class,
          modelMasterEntity.id,
          modelMasterEntity.trainType,
          modelMasterEntity.modelNo,
          modelDatasetMappEntity.datasetUid,
          valObj.targetClassCd,
          valObj.comparePath,
          valObj.targetPath,
          valObj.labelPath,
          datasetEntity.uid);

  return queryFactory
      .select(linkProjection)
      .from(modelMasterEntity)
      .leftJoin(modelDatasetMappEntity)
      .on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
      .leftJoin(datasetEntity)
      .on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
      .leftJoin(valObj)
      .on(objectJoinCondition)
      .where(modelMasterEntity.id.eq(modelId))
      .fetch();
}
/**
 * Selects the test-split object rows (compare/target/label paths) for one model.
 *
 * <p>The model is left-joined to its dataset mappings, the datasets, and the test object table,
 * so a model without matching objects still yields rows whose object columns are {@code null}.
 *
 * @param modelId id of the model master row to query
 * @return one {@link ModelTrainLinkDto} per joined row
 */
@Override
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
  QDatasetTestObjEntity testObj = QDatasetTestObjEntity.datasetTestObjEntity;

  // Class filter by model type: G1 -> CONTAINER/BUILDING, G2 -> WASTE, G3 -> no class restriction.
  var g1Classes =
      modelMasterEntity
          .modelNo
          .eq(ModelType.G1.getId())
          .and(testObj.targetClassCd.upper().in("CONTAINER", "BUILDING"));
  var g2Classes =
      modelMasterEntity
          .modelNo
          .eq(ModelType.G2.getId())
          .and(testObj.targetClassCd.upper().eq("WASTE"));
  var g3AnyClass = modelMasterEntity.modelNo.eq(ModelType.G3.getId());

  // Object rows must belong to a mapped dataset AND pass the per-model class filter.
  var objectJoinCondition =
      testObj
          .datasetUid
          .eq(modelDatasetMappEntity.datasetUid)
          .and(g1Classes.or(g2Classes).or(g3AnyClass));

  var linkProjection =
      Projections.constructor(
          ModelTrainLinkDto.class,
          modelMasterEntity.id,
          modelMasterEntity.trainType,
          modelMasterEntity.modelNo,
          modelDatasetMappEntity.datasetUid,
          testObj.targetClassCd,
          testObj.comparePath,
          testObj.targetPath,
          testObj.labelPath,
          datasetEntity.uid);

  return queryFactory
      .select(linkProjection)
      .from(modelMasterEntity)
      .leftJoin(modelDatasetMappEntity)
      .on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
      .leftJoin(datasetEntity)
      .on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
      .leftJoin(testObj)
      .on(objectJoinCondition)
      .where(modelMasterEntity.id.eq(modelId))
      .fetch();
}
}

View File

@@ -0,0 +1,19 @@
package com.kamco.cd.training.train.dto;
import lombok.Getter;
import lombok.Setter;
/**
 * One file-link row for a training run: which model/dataset it belongs to and the three source
 * file paths (compare image, target image, label) to link into the temporary dataset folder.
 *
 * <p>Instances are created by QueryDSL via {@code Projections.constructor(ModelTrainLinkDto.class,
 * ...)}, which requires a public constructor whose parameters match the selected expressions in
 * order. Lombok's {@code @Getter}/{@code @Setter} do not generate one, so it is declared
 * explicitly below.
 */
@Getter
@Setter
public class ModelTrainLinkDto {
  private Long modelId;
  private String trainType;
  private String modelNo;
  private Long datasetId;
  private String targetClassCd;
  private String comparePath;
  private String targetPath;
  private String labelPath;
  private String datasetUid;

  /** No-arg constructor, kept so existing {@code new ModelTrainLinkDto()} + setter usage works. */
  public ModelTrainLinkDto() {}

  /**
   * Constructor used by the QueryDSL constructor projection.
   *
   * <p>The parameter order must stay in sync with the select list of the repository queries
   * (model id, train type, model no, dataset id, target class, compare path, target path, label
   * path, dataset uid). NOTE(review): parameter types mirror this DTO's fields; confirm they
   * match the entity column types.
   *
   * @param modelId id of the model master row
   * @param trainType train type of the model
   * @param modelNo model number (e.g. the id of a {@code ModelType})
   * @param datasetId id of the mapped dataset
   * @param targetClassCd target class code of the object row
   * @param comparePath path of the compare image file
   * @param targetPath path of the target image file
   * @param labelPath path of the label file
   * @param datasetUid uid of the dataset
   */
  public ModelTrainLinkDto(
      Long modelId,
      String trainType,
      String modelNo,
      Long datasetId,
      String targetClassCd,
      String comparePath,
      String targetPath,
      String labelPath,
      String datasetUid) {
    this.modelId = modelId;
    this.trainType = trainType;
    this.modelNo = modelNo;
    this.datasetId = datasetId;
    this.targetClassCd = targetClassCd;
    this.comparePath = comparePath;
    this.targetPath = targetPath;
    this.labelPath = labelPath;
    this.datasetUid = datasetUid;
  }
}

View File

@@ -1,5 +1,6 @@
package com.kamco.cd.training.train.service;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import java.io.IOException;
import java.nio.file.*;
import java.util.List;
@@ -19,6 +20,92 @@ public class TmpDatasetService {
@Value("${train.docker.basePath}")
private String trainBaseDir;
/**
 * Hard-links the files of one split (train, val or test) into the temporary dataset folder
 * {@code <trainBaseDir>/tmp/<uid>/<type>/{input1,input2,label}}.
 *
 * <p>For every row, the compare path is linked into {@code input1}, the target path into
 * {@code input2} and the label path into {@code label}. Rows whose path is blank or does not
 * point to a regular file are skipped.
 *
 * @param uid name (uuid) of the temporary folder
 * @param type split name: train, val or test
 * @param links rows holding the full source file paths to link
 * @return {@code uid}, unchanged
 * @throws IOException if {@code links} is null or empty, {@code type} is null, no hard link at
 *     all could be created, or a file-system operation fails
 */
public String buildTmpDatasetHardlink(String uid, String type, List<ModelTrainLinkDto> links)
    throws IOException {
  if (links == null || links.isEmpty()) {
    throw new IOException("links is empty");
  }
  // A null type can never produce a link; fail once up front instead of warning per row.
  if (type == null) {
    log.warn("SKIP - type null, uid={}", uid);
    throw new IOException("No hardlinks created. type is null");
  }

  Path tmp = Path.of(trainBaseDir, "tmp", uid);

  // The target directories do not depend on the row, so create them once before the loop.
  Path typeDir = tmp.resolve(type);
  Files.createDirectories(typeDir.resolve("input1"));
  Files.createDirectories(typeDir.resolve("input2"));
  Files.createDirectories(typeDir.resolve("label"));

  long hardlinksMade = 0;
  for (ModelTrainLinkDto dto : links) {
    // comparePath -> input1
    hardlinksMade += link(tmp, type, "input1", dto.getComparePath());
    // targetPath -> input2
    hardlinksMade += link(tmp, type, "input2", dto.getTargetPath());
    // labelPath -> label
    hardlinksMade += link(tmp, type, "label", dto.getLabelPath());
  }

  if (hardlinksMade == 0) {
    throw new IOException("No hardlinks created.");
  }
  log.info("tmp dataset created: {}, hardlinksMade={}", tmp, hardlinksMade);
  return uid;
}
/**
 * Creates a hard link to {@code fullPath} at {@code tmp/<type>/<part>/<file name>}.
 *
 * <p>An entry already present at the destination is replaced, so two sources that share a file
 * name end up as a single link (the later one wins).
 *
 * @param tmp root of the temporary dataset folder
 * @param type split sub-folder (train, val or test)
 * @param part sub-folder below the split (input1, input2 or label)
 * @param fullPath full path of the source file; may be null or blank
 * @return 1 if a link was created, 0 if the source was blank or not a regular file
 * @throws IOException if the existing destination cannot be removed or the link cannot be
 *     created
 */
private long link(Path tmp, String type, String part, String fullPath) throws IOException {
  if (fullPath == null || fullPath.isBlank()) {
    return 0;
  }

  Path src = Path.of(fullPath);
  if (!Files.isRegularFile(src)) {
    log.warn("SKIP (not file): {}", src);
    return 0;
  }

  Path dst = tmp.resolve(type).resolve(part).resolve(src.getFileName().toString());

  // Overwrite on collision. deleteIfExists avoids the exists()/delete() race and, because it
  // does not follow symbolic links, also clears a dangling symlink that exists() would miss.
  Files.deleteIfExists(dst);
  Files.createLink(dst, src);
  return 1;
}
/**
 * Normalizes an optional string.
 *
 * @param s value to normalize; may be null
 * @return {@code null} when {@code s} is null or blank, otherwise {@code s.trim()}
 */
private String safe(String s) {
  if (s == null || s.isBlank()) {
    return null;
  }
  return s.trim();
}
/**
* request 전체 폴더 link
*
* @param uid
* @param datasetUids
* @return
* @throws IOException
*/
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
log.info("========== buildTmpDatasetHardlink START ==========");
@@ -37,7 +124,7 @@ public class TmpDatasetService {
// tmp 디렉토리 준비
for (String type : List.of("train", "val", "test")) {
for (String part : List.of("input1", "input2", "label", "label-json")) {
for (String part : List.of("input1", "input2", "label")) {
Path dir = tmp.resolve(type).resolve(part);
Files.createDirectories(dir);
log.info("createDirectories: {}", dir);
@@ -70,7 +157,7 @@ public class TmpDatasetService {
log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot));
for (String type : List.of("train", "val", "test")) {
for (String part : List.of("input1", "input2", "label", "label-json")) {
for (String part : List.of("input1", "input2", "label")) {
Path srcDir = srcRoot.resolve(type).resolve(part);
if (!Files.isDirectory(srcDir)) {

View File

@@ -7,6 +7,7 @@ import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.kamco.cd.training.train.dto.TrainRunRequest;
import java.io.IOException;
import java.nio.file.Files;
@@ -234,16 +235,30 @@ public class TrainJobService {
UUID tmpUuid = UUID.randomUUID();
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
// MODELID 가져오기
// model id 가져오기
Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
// model 에 연결된 dataset id 가져오기
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
try {
// 데이터셋 심볼링크 생성
String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
// train path
List<ModelTrainLinkDto> trainList = modelTrainMngCoreService.findDatasetTrainPath(modelId);
// validation path
List<ModelTrainLinkDto> valList = modelTrainMngCoreService.findDatasetValPath(modelId);
// test path
List<ModelTrainLinkDto> testList = modelTrainMngCoreService.findDatasetTestPath(modelId);
// create hard links for the train dataset
tmpDatasetService.buildTmpDatasetHardlink(raw, "train", trainList);
// create hard links for the validation dataset
tmpDatasetService.buildTmpDatasetHardlink(raw, "val", valList);
// create hard links for the test dataset
tmpDatasetService.buildTmpDatasetHardlink(raw, "test", testList);
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
updateReq.setRequestPath(pathUid);
updateReq.setRequestPath(raw);
// 학습모델을 수정한다.
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
} catch (IOException e) {