feat/training_260202 #125
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.model.service;
|
||||
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
|
||||
import com.kamco.cd.training.dataset.dto.DatasetDto.SelectDataSet;
|
||||
import com.kamco.cd.training.model.dto.ModelConfigDto;
|
||||
@@ -89,7 +90,7 @@ public class ModelTrainDetailService {
|
||||
datasetReq.setIds(datasetIds);
|
||||
datasetReq.setModelNo(modelInfo.getModelNo());
|
||||
|
||||
if (modelInfo.getModelNo().equals("G1")) {
|
||||
if (modelInfo.getModelNo().equals(ModelType.G1.getId())) {
|
||||
dataSets = mngCoreService.getDatasetSelectG1List(datasetReq);
|
||||
} else {
|
||||
dataSets = mngCoreService.getDatasetSelectG2G3List(datasetReq);
|
||||
|
||||
@@ -23,6 +23,7 @@ import com.kamco.cd.training.postgres.repository.model.ModelConfigRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetMappRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelDatasetRepository;
|
||||
import com.kamco.cd.training.postgres.repository.model.ModelMngRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.ArrayList;
|
||||
@@ -519,4 +520,34 @@ public class ModelTrainMngCoreService {
|
||||
public Long findModelStep1InProgressCnt() {
|
||||
return modelMngRepository.findModelStep1InProgressCnt();
|
||||
}
|
||||
|
||||
/**
|
||||
* train 링크할 파일 경로
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
|
||||
return modelDatasetMapRepository.findDatasetTrainPath(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* validation 링크할 파일 경로
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
|
||||
return modelDatasetMapRepository.findDatasetValPath(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* test 링크할 파일 경로
|
||||
*
|
||||
* @param modelId
|
||||
* @return
|
||||
*/
|
||||
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
|
||||
return modelDatasetMapRepository.findDatasetTestPath(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
|
||||
import java.util.List;
|
||||
|
||||
public interface ModelDatasetMappRepositoryCustom {
|
||||
List<ModelDatasetMappEntity> findByModelUid(Long modelId);
|
||||
|
||||
List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId);
|
||||
|
||||
List<ModelTrainLinkDto> findDatasetValPath(Long modelId);
|
||||
|
||||
List<ModelTrainLinkDto> findDatasetTestPath(Long modelId);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
package com.kamco.cd.training.postgres.repository.model;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QDatasetEntity.datasetEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QDatasetTestObjEntity;
|
||||
import com.kamco.cd.training.postgres.entity.QDatasetValObjEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -22,4 +30,133 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
|
||||
.where(modelDatasetMappEntity.modelUid.eq(modelId))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
|
||||
|
||||
QDatasetObjEntity datasetObjEntity = QDatasetObjEntity.datasetObjEntity;
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelTrainLinkDto.class,
|
||||
modelMasterEntity.id,
|
||||
modelMasterEntity.trainType,
|
||||
modelMasterEntity.modelNo,
|
||||
modelDatasetMappEntity.datasetUid,
|
||||
datasetObjEntity.targetClassCd,
|
||||
datasetObjEntity.comparePath,
|
||||
datasetObjEntity.targetPath,
|
||||
datasetObjEntity.labelPath,
|
||||
datasetEntity.uid))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelDatasetMappEntity)
|
||||
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
|
||||
.leftJoin(datasetEntity)
|
||||
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
|
||||
.leftJoin(datasetObjEntity)
|
||||
.on(
|
||||
datasetObjEntity
|
||||
.datasetUid
|
||||
.eq(modelDatasetMappEntity.datasetUid)
|
||||
.and(
|
||||
modelMasterEntity
|
||||
.modelNo
|
||||
.eq(ModelType.G1.getId())
|
||||
.and(datasetObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
|
||||
.or(
|
||||
modelMasterEntity
|
||||
.modelNo
|
||||
.eq(ModelType.G2.getId())
|
||||
.and(datasetObjEntity.targetClassCd.upper().eq("WASTE")))
|
||||
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
|
||||
.where(modelMasterEntity.id.eq(modelId))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelTrainLinkDto> findDatasetValPath(Long modelId) {
|
||||
|
||||
QDatasetValObjEntity datasetValObjEntity = QDatasetValObjEntity.datasetValObjEntity;
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelTrainLinkDto.class,
|
||||
modelMasterEntity.id,
|
||||
modelMasterEntity.trainType,
|
||||
modelMasterEntity.modelNo,
|
||||
modelDatasetMappEntity.datasetUid,
|
||||
datasetValObjEntity.targetClassCd,
|
||||
datasetValObjEntity.comparePath,
|
||||
datasetValObjEntity.targetPath,
|
||||
datasetValObjEntity.labelPath,
|
||||
datasetEntity.uid))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelDatasetMappEntity)
|
||||
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
|
||||
.leftJoin(datasetEntity)
|
||||
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
|
||||
.leftJoin(datasetValObjEntity)
|
||||
.on(
|
||||
datasetValObjEntity
|
||||
.datasetUid
|
||||
.eq(modelDatasetMappEntity.datasetUid)
|
||||
.and(
|
||||
modelMasterEntity
|
||||
.modelNo
|
||||
.eq(ModelType.G1.getId())
|
||||
.and(datasetValObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
|
||||
.or(
|
||||
modelMasterEntity
|
||||
.modelNo
|
||||
.eq(ModelType.G2.getId())
|
||||
.and(datasetValObjEntity.targetClassCd.upper().eq("WASTE")))
|
||||
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
|
||||
.where(modelMasterEntity.id.eq(modelId))
|
||||
.fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelTrainLinkDto> findDatasetTestPath(Long modelId) {
|
||||
|
||||
QDatasetTestObjEntity datasetTestObjEntity = QDatasetTestObjEntity.datasetTestObjEntity;
|
||||
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelTrainLinkDto.class,
|
||||
modelMasterEntity.id,
|
||||
modelMasterEntity.trainType,
|
||||
modelMasterEntity.modelNo,
|
||||
modelDatasetMappEntity.datasetUid,
|
||||
datasetTestObjEntity.targetClassCd,
|
||||
datasetTestObjEntity.comparePath,
|
||||
datasetTestObjEntity.targetPath,
|
||||
datasetTestObjEntity.labelPath,
|
||||
datasetEntity.uid))
|
||||
.from(modelMasterEntity)
|
||||
.leftJoin(modelDatasetMappEntity)
|
||||
.on(modelDatasetMappEntity.modelUid.eq(modelMasterEntity.id))
|
||||
.leftJoin(datasetEntity)
|
||||
.on(datasetEntity.id.eq(modelDatasetMappEntity.datasetUid))
|
||||
.leftJoin(datasetTestObjEntity)
|
||||
.on(
|
||||
datasetTestObjEntity
|
||||
.datasetUid
|
||||
.eq(modelDatasetMappEntity.datasetUid)
|
||||
.and(
|
||||
modelMasterEntity
|
||||
.modelNo
|
||||
.eq(ModelType.G1.getId())
|
||||
.and(datasetTestObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
|
||||
.or(
|
||||
modelMasterEntity
|
||||
.modelNo
|
||||
.eq(ModelType.G2.getId())
|
||||
.and(datasetTestObjEntity.targetClassCd.upper().eq("WASTE")))
|
||||
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
|
||||
.where(modelMasterEntity.id.eq(modelId))
|
||||
.fetch();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public class ModelTrainLinkDto {
|
||||
|
||||
private Long modelId;
|
||||
private String trainType;
|
||||
private String modelNo;
|
||||
private Long datasetId;
|
||||
private String targetClassCd;
|
||||
private String comparePath;
|
||||
private String targetPath;
|
||||
private String labelPath;
|
||||
private String datasetUid;
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.*;
|
||||
import java.util.List;
|
||||
@@ -19,6 +20,92 @@ public class TmpDatasetService {
|
||||
@Value("${train.docker.basePath}")
|
||||
private String trainBaseDir;
|
||||
|
||||
/**
|
||||
* train, val, test 폴더별로 link
|
||||
*
|
||||
* @param uid 임시폴더 uuid
|
||||
* @param type train, val, test
|
||||
* @param links tif pull path
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
public String buildTmpDatasetHardlink(String uid, String type, List<ModelTrainLinkDto> links)
|
||||
throws IOException {
|
||||
|
||||
if (links == null || links.isEmpty()) {
|
||||
throw new IOException("links is empty");
|
||||
}
|
||||
|
||||
Path tmp = Path.of(trainBaseDir, "tmp", uid);
|
||||
|
||||
long hardlinksMade = 0;
|
||||
|
||||
for (ModelTrainLinkDto dto : links) {
|
||||
|
||||
if (type == null) {
|
||||
log.warn("SKIP - trainType null: {}", dto);
|
||||
continue;
|
||||
}
|
||||
|
||||
// type별 디렉토리 생성
|
||||
Files.createDirectories(tmp.resolve(type).resolve("input1"));
|
||||
Files.createDirectories(tmp.resolve(type).resolve("input2"));
|
||||
Files.createDirectories(tmp.resolve(type).resolve("label"));
|
||||
|
||||
// comparePath → input1
|
||||
hardlinksMade += link(tmp, type, "input1", dto.getComparePath());
|
||||
|
||||
// targetPath → input2
|
||||
hardlinksMade += link(tmp, type, "input2", dto.getTargetPath());
|
||||
|
||||
// labelPath → label
|
||||
hardlinksMade += link(tmp, type, "label", dto.getLabelPath());
|
||||
}
|
||||
|
||||
if (hardlinksMade == 0) {
|
||||
throw new IOException("No hardlinks created.");
|
||||
}
|
||||
|
||||
log.info("tmp dataset created: {}, hardlinksMade={}", tmp, hardlinksMade);
|
||||
return uid;
|
||||
}
|
||||
|
||||
private long link(Path tmp, String type, String part, String fullPath) throws IOException {
|
||||
|
||||
if (fullPath == null || fullPath.isBlank()) return 0;
|
||||
|
||||
Path src = Path.of(fullPath);
|
||||
|
||||
if (!Files.isRegularFile(src)) {
|
||||
log.warn("SKIP (not file): {}", src);
|
||||
return 0;
|
||||
}
|
||||
|
||||
String fileName = src.getFileName().toString();
|
||||
Path dst = tmp.resolve(type).resolve(part).resolve(fileName);
|
||||
|
||||
// 충돌 시 덮어쓰기
|
||||
if (Files.exists(dst)) {
|
||||
Files.delete(dst);
|
||||
}
|
||||
|
||||
Files.createLink(dst, src);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
private String safe(String s) {
|
||||
return (s == null || s.isBlank()) ? null : s.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* request 전체 폴더 link
|
||||
*
|
||||
* @param uid
|
||||
* @param datasetUids
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
||||
|
||||
log.info("========== buildTmpDatasetHardlink START ==========");
|
||||
@@ -37,7 +124,7 @@ public class TmpDatasetService {
|
||||
|
||||
// tmp 디렉토리 준비
|
||||
for (String type : List.of("train", "val", "test")) {
|
||||
for (String part : List.of("input1", "input2", "label", "label-json")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
Path dir = tmp.resolve(type).resolve(part);
|
||||
Files.createDirectories(dir);
|
||||
log.info("createDirectories: {}", dir);
|
||||
@@ -70,7 +157,7 @@ public class TmpDatasetService {
|
||||
log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot));
|
||||
|
||||
for (String type : List.of("train", "val", "test")) {
|
||||
for (String part : List.of("input1", "input2", "label", "label-json")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
|
||||
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||
if (!Files.isDirectory(srcDir)) {
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
@@ -234,16 +235,30 @@ public class TrainJobService {
|
||||
UUID tmpUuid = UUID.randomUUID();
|
||||
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
|
||||
|
||||
// MODELID 가져오기
|
||||
// model id 가져오기
|
||||
Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
|
||||
// model 에 연결된 dataset id 가져오기
|
||||
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
|
||||
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
|
||||
|
||||
try {
|
||||
// 데이터셋 심볼링크 생성
|
||||
String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
// train path
|
||||
List<ModelTrainLinkDto> trainList = modelTrainMngCoreService.findDatasetTrainPath(modelId);
|
||||
// validation path
|
||||
List<ModelTrainLinkDto> valList = modelTrainMngCoreService.findDatasetValPath(modelId);
|
||||
// test path
|
||||
List<ModelTrainLinkDto> testList = modelTrainMngCoreService.findDatasetTestPath(modelId);
|
||||
|
||||
// train 데이터셋 심볼링크 생성
|
||||
tmpDatasetService.buildTmpDatasetHardlink(raw, "train", trainList);
|
||||
// val 데이터셋 심볼링크 생성
|
||||
tmpDatasetService.buildTmpDatasetHardlink(raw, "val", valList);
|
||||
// test 데이터셋 심볼링크 생성
|
||||
tmpDatasetService.buildTmpDatasetHardlink(raw, "test", testList);
|
||||
|
||||
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||
updateReq.setRequestPath(pathUid);
|
||||
updateReq.setRequestPath(raw);
|
||||
|
||||
// 학습모델을 수정한다.
|
||||
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||
} catch (IOException e) {
|
||||
|
||||
Reference in New Issue
Block a user