25 Commits

Author SHA1 Message Date
62398846d3 파일관리 기능 api 커밋 2026-04-06 17:36:40 +09:00
260569225c 파일관리 기능 api 커밋 2026-04-06 17:36:19 +09:00
bd6fe924de 모델 삭제할때 임시파일 경로 Null 처리 2026-04-06 14:59:27 +09:00
92492ca879 하이퍼 파라미터 컬럼 사이즈 변경 2026-04-03 16:04:49 +09:00
a5d79b2504 하이퍼 파라미터 컬럼 사이즈 변경 2026-04-03 16:03:17 +09:00
348d3d0052 데이터셋 조회 class count integer -> Long 로 변경 2026-04-03 15:33:54 +09:00
e77eae8f8b 데이터셋 조회 class count integer -> Long 로 변경 2026-04-03 15:17:17 +09:00
26d34d88eb spotless 적용 2026-04-03 10:20:40 +09:00
91f022889b Save-best, Save-best-rule 컬럼 varchar100으로 변경 2026-04-03 09:18:26 +09:00
dean
f00296cf2c welcome 2026-04-02 21:17:01 +09:00
f98f6cb038 Merge pull request 'solar -> solarCnt 변경' (#185) from feat/training_260324 into develop
Reviewed-on: #185
2026-04-02 18:57:16 +09:00
d1593e57c3 solar -> solarCnt 변경 2026-04-02 18:56:52 +09:00
732dccf2e4 Merge pull request 'solar -> solarCnt 변경' (#184) from feat/training_260324 into develop
Reviewed-on: #184
2026-04-02 18:55:14 +09:00
f6cd553af8 solar -> solarCnt 변경 2026-04-02 18:54:52 +09:00
618dbe4047 Merge pull request '데이터셋 entity 수정, 데이터셋 저장 수정' (#183) from feat/training_260324 into develop
Reviewed-on: #183
2026-04-02 18:42:55 +09:00
5546e8ef89 데이터셋 entity 수정, 데이터셋 저장 수정 2026-04-02 18:42:25 +09:00
b952ec7b47 Merge pull request '임시 데이터셋 폴더 생성 G4 추가' (#182) from feat/training_260324 into develop
Reviewed-on: #182
2026-04-02 18:05:11 +09:00
e93f533c59 임시 데이터셋 폴더 생성 G4 추가 2026-04-02 18:04:41 +09:00
a5267d8065 Merge pull request 'select-dataset-list api solarPanelCnt 추가, spotless 적용' (#181) from feat/training_260324 into develop
Reviewed-on: #181
2026-04-02 17:45:16 +09:00
71d9835b03 select-dataset-list api solarPanelCnt 추가, spotless 적용 2026-04-02 17:44:27 +09:00
39f39a4f0c Merge pull request 'ModelType enum G4 추가' (#180) from feat/training_260324 into develop
Reviewed-on: #180
2026-04-02 16:55:14 +09:00
1df7142544 ModelType enum G4 추가 2026-04-02 16:54:25 +09:00
d99e18b38c val nan 일때 오류 수정, spotless 적용 2026-04-02 14:41:13 +09:00
d6aa612494 Merge pull request 'val nan 일때 오류 수정' (#179) from feat/training_260324 into develop
Reviewed-on: #179
2026-04-02 14:26:13 +09:00
8def356323 val nan 일때 오류 수정 2026-04-02 14:19:42 +09:00
24 changed files with 811 additions and 112 deletions

View File

@@ -14,7 +14,7 @@ import lombok.Setter;
public class HyperParam { public class HyperParam {
@Schema(description = "모델", example = "G1") @Schema(description = "모델", example = "G1")
private ModelType model; // G1, G2, G3 private ModelType model; // G1, G2, G3, G4
// ------------------------- // -------------------------
// Important // Important
@@ -104,7 +104,7 @@ public class HyperParam {
@Schema(description = "Best 모델 선정 규칙", example = "less") @Schema(description = "Best 모델 선정 규칙", example = "less")
private String saveBestRule; // save_best_rule private String saveBestRule; // save_best_rule
@Schema(description = "검증 수행 주기(Epoch)", example = "10") @Schema(description = "검증 수행 주기(Epoch)", example = "1")
private Integer valInterval; // val_interval private Integer valInterval; // val_interval
@Schema(description = "로그 기록 주기(Iteration)", example = "400") @Schema(description = "로그 기록 주기(Iteration)", example = "400")

View File

@@ -12,7 +12,8 @@ import lombok.Getter;
public enum ModelType implements EnumType { public enum ModelType implements EnumType {
G1("G1"), G1("G1"),
G2("G2"), G2("G2"),
G3("G3"); G3("G3"),
G4("G4");
private String desc; private String desc;

View File

@@ -137,6 +137,6 @@ public class SecurityConfig {
/** 완전 제외(필터 자체를 안 탐) */ /** 완전 제외(필터 자체를 안 탐) */
@Bean @Bean
public WebSecurityCustomizer webSecurityCustomizer() { public WebSecurityCustomizer webSecurityCustomizer() {
return (web) -> web.ignoring().requestMatchers("/api/mapsheet/**"); return (web) -> web.ignoring().requestMatchers("/api/mapsheet/**", "/api/file-manager/**");
} }
} }

View File

@@ -248,12 +248,13 @@ public class DatasetDto {
private Integer targetYyyy; private Integer targetYyyy;
private String memo; private String memo;
@JsonIgnore private Long classCount; @JsonIgnore private Long classCount;
private Integer buildingCnt; private Long buildingCnt;
private Integer containerCnt; private Long containerCnt;
private String dataTypeName; private String dataTypeName;
private Long wasteCnt; private Long wasteCnt;
private Long landCoverCnt; private Long landCoverCnt;
private Long solarPanelCnt;
public SelectDataSet( public SelectDataSet(
String modelNo, String modelNo,
@@ -266,6 +267,7 @@ public class DatasetDto {
Integer targetYyyy, Integer targetYyyy,
String memo, String memo,
Long classCount) { Long classCount) {
this.modelNo = modelNo;
this.datasetId = datasetId; this.datasetId = datasetId;
this.uuid = uuid; this.uuid = uuid;
this.dataType = dataType; this.dataType = dataType;
@@ -280,6 +282,8 @@ public class DatasetDto {
this.wasteCnt = classCount; this.wasteCnt = classCount;
} else if (modelNo.equals(ModelType.G3.getId())) { } else if (modelNo.equals(ModelType.G3.getId())) {
this.landCoverCnt = classCount; this.landCoverCnt = classCount;
} else if (modelNo.equals(ModelType.G4.getId())) {
this.solarPanelCnt = classCount;
} }
} }
@@ -293,8 +297,9 @@ public class DatasetDto {
Integer compareYyyy, Integer compareYyyy,
Integer targetYyyy, Integer targetYyyy,
String memo, String memo,
Integer buildingCnt, Long buildingCnt,
Integer containerCnt) { Long containerCnt) {
this.modelNo = modelNo;
this.datasetId = datasetId; this.datasetId = datasetId;
this.uuid = uuid; this.uuid = uuid;
this.dataType = dataType; this.dataType = dataType;

View File

@@ -0,0 +1,125 @@
package com.kamco.cd.training.filemanager;
import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.filemanager.dto.FileManagerDto;
import com.kamco.cd.training.filemanager.service.FileManagerService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.ExampleObject;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@Slf4j
@Tag(name = "파일 관리", description = "/data 디렉토리 파일 관리 API")
@RestController
@RequestMapping("/api/file-manager")
@RequiredArgsConstructor
public class FileManagerApiController {
private final FileManagerService fileManagerService;
@Operation(
summary = "파일 목록 조회",
description = "/data 디렉토리 내 파일 및 디렉토리 목록을 조회합니다. recursive=true로 설정하면 하위 디렉토리까지 조회합니다.")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "조회 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = FileManagerDto.ListFilesRes.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청 (유효하지 않은 경로)", content = @Content),
@ApiResponse(responseCode = "404", description = "디렉토리를 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@GetMapping("/files")
public ApiResponseDto<FileManagerDto.ListFilesRes> listFiles(
@Parameter(description = "조회할 디렉토리 경로 (기본값: /data)", example = "/data/request")
@RequestParam(required = false)
String directoryPath,
@Parameter(description = "하위 디렉토리 포함 여부", example = "false")
@RequestParam(required = false, defaultValue = "false")
Boolean recursive) {
FileManagerDto.ListFilesReq request =
FileManagerDto.ListFilesReq.builder()
.directoryPath(directoryPath)
.recursive(recursive)
.build();
FileManagerDto.ListFilesRes response = fileManagerService.listFiles(request);
return ApiResponseDto.ok(response);
}
@Operation(
summary = "파일/디렉토리 삭제",
description = "지정된 파일 또는 디렉토리를 삭제합니다. recursive=true로 설정하면 디렉토리 내 모든 파일을 삭제합니다.",
requestBody =
@io.swagger.v3.oas.annotations.parameters.RequestBody(
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = FileManagerDto.DeleteFileReq.class),
examples = {
@ExampleObject(
name = "단일 파일 삭제",
value =
"""
{
"filePaths": ["/data/request/old_file.zip"],
"recursive": false
}
"""),
@ExampleObject(
name = "여러 파일 삭제",
value =
"""
{
"filePaths": ["/data/file1.txt", "/data/file2.txt"],
"recursive": false
}
"""),
@ExampleObject(
name = "디렉토리 전체 삭제",
value =
"""
{
"filePaths": ["/data/old_folder"],
"recursive": true
}
""")
})))
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "삭제 성공",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = FileManagerDto.DeleteFileRes.class))),
@ApiResponse(responseCode = "400", description = "잘못된 요청 (유효하지 않은 경로)", content = @Content),
@ApiResponse(responseCode = "404", description = "파일을 찾을 수 없음", content = @Content),
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
})
@DeleteMapping("/files")
public ApiResponseDto<FileManagerDto.DeleteFileRes> deleteFiles(
@RequestBody FileManagerDto.DeleteFileReq request) {
FileManagerDto.DeleteFileRes response = fileManagerService.deleteFiles(request);
return ApiResponseDto.ok(response);
}
}

View File

@@ -0,0 +1,134 @@
package com.kamco.cd.training.filemanager.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import java.time.LocalDateTime;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Namespace class grouping the request and response DTOs of the file manager API.
 *
 * <p>Each nested class is a plain data holder: Lombok generates accessors, a builder and the
 * no-args / all-args constructors, and the {@code @Schema} annotations feed the OpenAPI
 * documentation. All fields are boxed types, so any field left unset by a builder is {@code null}.
 */
public class FileManagerDto {
/** Metadata describing a single file-system entry (file or directory). */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "파일 정보")
public static class FileInfo {
// File name (last path element).
@Schema(description = "파일명", example = "dataset.zip")
private String fileName;
// Full path of the entry.
@Schema(description = "파일 전체 경로", example = "/data/request/dataset.zip")
private String filePath;
// Size in bytes.
@Schema(description = "파일 크기 (bytes)", example = "1024000")
private Long fileSize;
// Whether the entry is a file.
@Schema(description = "파일인지 디렉토리인지 여부", example = "true")
private Boolean isFile;
// Whether the entry is a directory.
@Schema(description = "디렉토리인지 여부", example = "false")
private Boolean isDirectory;
// Last modification time.
@Schema(description = "마지막 수정 시간", example = "2026-04-06T15:30:00")
private LocalDateTime lastModified;
// Read permission flag.
@Schema(description = "읽기 권한", example = "true")
private Boolean readable;
// Write permission flag.
@Schema(description = "쓰기 권한", example = "true")
private Boolean writable;
}
/** Request for listing the contents of a directory. */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "디렉토리 목록 조회 요청")
public static class ListFilesReq {
// Directory to list; the schema documents /data as the default when omitted.
@Schema(description = "조회할 디렉토리 경로 (기본값: /data)", example = "/data/request")
private String directoryPath;
// Whether sub-directories are included.
@Schema(description = "하위 디렉토리 포함 여부", example = "false")
private Boolean recursive;
}
/** Response of a directory listing. */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "디렉토리 목록 조회 응답")
public static class ListFilesRes {
// Directory that was listed.
@Schema(description = "조회된 디렉토리 경로", example = "/data/request")
private String directoryPath;
// Entries found in the directory.
@Schema(description = "파일 목록")
private List<FileInfo> files;
// Number of entries in {@code files}.
@Schema(description = "총 파일 개수", example = "10")
private Integer totalCount;
// Sum of file sizes in bytes.
@Schema(description = "총 파일 크기 (bytes)", example = "10240000")
private Long totalSize;
}
/** Request for deleting one or more files or directories. */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "파일 삭제 요청")
public static class DeleteFileReq {
// Paths of the files or directories to delete.
@Schema(
description = "삭제할 파일 또는 디렉토리 경로 목록",
example = "[\"/data/request/old_file.zip\", \"/data/tmp/test_folder\"]")
private List<String> filePaths;
// For directories: whether contained files are deleted as well.
@Schema(description = "디렉토리일 경우 하위 파일 포함 삭제 여부", example = "true")
private Boolean recursive;
}
/** Result of a delete request, reporting successes and failures separately. */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "파일 삭제 응답")
public static class DeleteFileRes {
// Paths that were deleted successfully.
@Schema(description = "삭제 성공한 파일 경로 목록")
private List<String> deletedFiles;
// Paths that could not be deleted.
@Schema(description = "삭제 실패한 파일 경로 목록")
private List<String> failedFiles;
// Number of successful deletions.
@Schema(description = "삭제 성공 개수", example = "5")
private Integer successCount;
// Number of failed deletions.
@Schema(description = "삭제 실패 개수", example = "0")
private Integer failureCount;
// Human-readable summary message.
@Schema(description = "전체 메시지", example = "5개 파일 삭제 성공")
private String message;
}
/**
 * Disk usage figures for a directory.
 *
 * <p>NOTE(review): no producer of this type is visible alongside this class — confirm whether an
 * endpoint for it exists or is still planned.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Schema(description = "디스크 사용량 조회 응답")
public static class DiskUsageRes {
// Directory the figures refer to.
@Schema(description = "디렉토리 경로", example = "/data")
private String directoryPath;
// Total capacity in bytes.
@Schema(description = "총 용량 (bytes)", example = "1000000000000")
private Long totalSpace;
// Available capacity in bytes.
@Schema(description = "사용 가능 용량 (bytes)", example = "500000000000")
private Long usableSpace;
// Capacity in use, in bytes.
@Schema(description = "사용 중인 용량 (bytes)", example = "500000000000")
private Long usedSpace;
// Usage as a percentage.
@Schema(description = "사용률 (%)", example = "50.0")
private Double usagePercentage;
}
}

View File

@@ -0,0 +1,234 @@
package com.kamco.cd.training.filemanager.service;
import com.kamco.cd.training.filemanager.dto.FileManagerDto;
import java.io.IOException;
import java.nio.file.FileVisitResult;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.attribute.BasicFileAttributes;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
/**
 * File-system operations behind the file manager API: listing directory contents and deleting
 * files or directories.
 *
 * <p>Every caller-supplied path is converted to an absolute, normalized path and must lie inside
 * {@link #BASE_DATA_PATH}; anything else is rejected with an {@link IllegalArgumentException}. The
 * containment check is lexical only.
 *
 * <p>NOTE(review): symbolic links located inside the base directory are not resolved by the
 * containment check, so a link pointing outside the base directory can still be traversed. Confirm
 * whether that needs to be blocked as well (e.g. via {@code Path.toRealPath()}).
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class FileManagerService {

  private static final String BASE_DATA_PATH = "/data";
  private static final long MAX_PATH_LENGTH = 500;

  /** Absolute, normalized form of {@link #BASE_DATA_PATH}, used for the containment check. */
  private static final Path BASE_DATA_DIR = Paths.get(BASE_DATA_PATH).toAbsolutePath().normalize();

  /**
   * Lists the entries of a directory.
   *
   * @param request listing request; a null directory path means {@link #BASE_DATA_PATH}, a null
   *     recursive flag means non-recursive
   * @return the entries found, their count, and the summed size of the regular files among them
   * @throws IllegalArgumentException if the path is blank, too long, contains {@code ..}, lies
   *     outside the base directory, does not exist, or is not a directory
   * @throws RuntimeException if the directory cannot be read; the underlying {@link IOException}
   *     is attached as the cause
   */
  public FileManagerDto.ListFilesRes listFiles(FileManagerDto.ListFilesReq request) {
    String targetPath =
        request.getDirectoryPath() != null ? request.getDirectoryPath() : BASE_DATA_PATH;
    boolean recursive = Boolean.TRUE.equals(request.getRecursive());

    Path directory = resolveSafePath(targetPath);

    if (!Files.exists(directory)) {
      throw new IllegalArgumentException("디렉토리가 존재하지 않습니다: " + targetPath);
    }
    if (!Files.isDirectory(directory)) {
      throw new IllegalArgumentException("디렉토리 경로가 아닙니다: " + targetPath);
    }

    List<FileManagerDto.FileInfo> files = new ArrayList<>();
    try {
      if (recursive) {
        // Walk the whole tree; the start directory itself is not reported as an entry.
        Files.walkFileTree(
            directory,
            new SimpleFileVisitor<>() {
              @Override
              public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) {
                files.add(createFileInfo(file));
                return FileVisitResult.CONTINUE;
              }

              @Override
              public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) {
                if (!dir.equals(directory)) {
                  files.add(createFileInfo(dir));
                }
                return FileVisitResult.CONTINUE;
              }
            });
      } else {
        // Direct children only.
        try (Stream<Path> stream = Files.list(directory)) {
          stream.forEach(path -> files.add(createFileInfo(path)));
        }
      }
    } catch (IOException e) {
      log.error("파일 목록 조회 중 오류 발생: {}", targetPath, e);
      // Keep the original exception as the cause so the stack trace is not lost.
      throw new RuntimeException("파일 목록 조회에 실패했습니다: " + e.getMessage(), e);
    }

    // Sum regular-file sizes. isFile is null on entries whose attributes could not be read
    // (see createFileInfo), so it must not be auto-unboxed.
    long totalSize = 0;
    for (FileManagerDto.FileInfo file : files) {
      if (Boolean.TRUE.equals(file.getIsFile()) && file.getFileSize() != null) {
        totalSize += file.getFileSize();
      }
    }

    return FileManagerDto.ListFilesRes.builder()
        .directoryPath(targetPath)
        .files(files)
        .totalCount(files.size())
        .totalSize(totalSize)
        .build();
  }

  /**
   * Deletes the requested files and directories, one by one.
   *
   * <p>A failure on one path does not stop processing of the remaining paths; it is recorded in
   * the failed list together with a short reason. Directories are deleted only when empty unless
   * the request's recursive flag is set. The base directory itself is never deleted.
   *
   * @param request paths to delete and the recursive flag (null flag means non-recursive)
   * @return the deleted paths, the failed paths with reasons, both counts and a summary message
   * @throws IllegalArgumentException if the request or its path list is null
   */
  public FileManagerDto.DeleteFileRes deleteFiles(FileManagerDto.DeleteFileReq request) {
    if (request == null || request.getFilePaths() == null) {
      throw new IllegalArgumentException("삭제할 파일 경로 목록이 비어있습니다");
    }

    List<String> deletedFiles = new ArrayList<>();
    List<String> failedFiles = new ArrayList<>();
    boolean recursive = Boolean.TRUE.equals(request.getRecursive());

    for (String filePath : request.getFilePaths()) {
      try {
        Path path = resolveSafePath(filePath);

        // Refuse to remove the root of the managed tree.
        if (path.equals(BASE_DATA_DIR)) {
          log.warn("기본 디렉토리 삭제 요청 거부: {}", filePath);
          failedFiles.add(filePath + " (기본 디렉토리는 삭제할 수 없음)");
          continue;
        }

        if (!Files.exists(path)) {
          log.warn("삭제하려는 파일이 존재하지 않습니다: {}", filePath);
          failedFiles.add(filePath + " (파일이 존재하지 않음)");
          continue;
        }

        if (Files.isDirectory(path)) {
          if (recursive) {
            deleteDirectoryRecursively(path);
          } else if (isDirectoryEmpty(path)) {
            Files.delete(path);
          } else {
            // Non-empty directory without the recursive flag: report and move on.
            failedFiles.add(filePath + " (디렉토리가 비어있지 않음)");
            continue;
          }
        } else {
          Files.delete(path);
        }

        // Reached only when something was actually deleted.
        deletedFiles.add(filePath);
        log.info("파일 삭제 성공: {}", filePath);
      } catch (Exception e) {
        // Per-item boundary: any failure is recorded and the loop continues.
        log.error("파일 삭제 실패: {}", filePath, e);
        failedFiles.add(filePath + " (" + e.getMessage() + ")");
      }
    }

    String message =
        String.format("%d개 파일 삭제 성공, %d개 파일 삭제 실패", deletedFiles.size(), failedFiles.size());

    return FileManagerDto.DeleteFileRes.builder()
        .deletedFiles(deletedFiles)
        .failedFiles(failedFiles)
        .successCount(deletedFiles.size())
        .failureCount(failedFiles.size())
        .message(message)
        .build();
  }

  /**
   * Builds the metadata object for one path.
   *
   * <p>If the attributes cannot be read, a reduced object carrying only the name and the path is
   * returned (all other fields null) and the failure is logged.
   */
  private FileManagerDto.FileInfo createFileInfo(Path path) {
    try {
      BasicFileAttributes attrs = Files.readAttributes(path, BasicFileAttributes.class);
      return FileManagerDto.FileInfo.builder()
          .fileName(path.getFileName().toString())
          .filePath(path.toString())
          .fileSize(attrs.isRegularFile() ? attrs.size() : null)
          .isFile(attrs.isRegularFile())
          .isDirectory(attrs.isDirectory())
          .lastModified(
              LocalDateTime.ofInstant(
                  Instant.ofEpochMilli(attrs.lastModifiedTime().toMillis()),
                  ZoneId.systemDefault()))
          .readable(Files.isReadable(path))
          .writable(Files.isWritable(path))
          .build();
    } catch (IOException e) {
      log.warn("파일 정보 조회 실패: {}", path, e);
      return FileManagerDto.FileInfo.builder()
          .fileName(path.getFileName().toString())
          .filePath(path.toString())
          .build();
    }
  }

  /**
   * Deletes a directory together with everything below it (files first, then each directory once
   * its contents are gone).
   *
   * @throws IOException if any entry cannot be visited or deleted
   */
  private void deleteDirectoryRecursively(Path directory) throws IOException {
    Files.walkFileTree(
        directory,
        new SimpleFileVisitor<>() {
          @Override
          public FileVisitResult visitFile(Path file, BasicFileAttributes attrs)
              throws IOException {
            Files.delete(file);
            return FileVisitResult.CONTINUE;
          }

          @Override
          public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
            // A non-null exc means iterating this directory failed; surface that error instead
            // of masking it with a follow-up delete failure.
            if (exc != null) {
              throw exc;
            }
            Files.delete(dir);
            return FileVisitResult.CONTINUE;
          }
        });
  }

  /** Returns whether the directory has no entries. */
  private boolean isDirectoryEmpty(Path directory) throws IOException {
    try (Stream<Path> stream = Files.list(directory)) {
      return stream.findFirst().isEmpty();
    }
  }

  /**
   * Validates a caller-supplied path and converts it to an absolute, normalized {@link Path}.
   *
   * @param path raw path string from the request
   * @return the absolute, normalized path, guaranteed to be the base directory or below it
   * @throws IllegalArgumentException if the path is null/blank, longer than
   *     {@link #MAX_PATH_LENGTH}, contains {@code ..}, or resolves outside the base directory
   */
  private Path resolveSafePath(String path) {
    if (path == null || path.trim().isEmpty()) {
      throw new IllegalArgumentException("경로가 비어있습니다");
    }
    if (path.length() > MAX_PATH_LENGTH) {
      throw new IllegalArgumentException("경로가 너무 깁니다");
    }
    // Reject traversal sequences outright (pre-existing rule, kept for compatibility).
    if (path.contains("..")) {
      throw new IllegalArgumentException("상대 경로(..)는 사용할 수 없습니다");
    }

    // Rejecting ".." alone still lets any absolute path through (e.g. /etc), so additionally
    // confine the path to the base directory. Path.startsWith compares whole name elements,
    // which means "/database" is not treated as being inside "/data".
    Path resolved = Paths.get(path).toAbsolutePath().normalize();
    if (!resolved.startsWith(BASE_DATA_DIR)) {
      throw new IllegalArgumentException(
          "허용되지 않은 경로입니다 (" + BASE_DATA_PATH + " 하위 경로만 허용): " + path);
    }
    return resolved;
  }
}

View File

@@ -101,7 +101,7 @@ public class HyperParamApiController {
LocalDate endDate, LocalDate endDate,
@Parameter(description = "버전명", example = "G1_000019") @RequestParam(required = false) @Parameter(description = "버전명", example = "G1_000019") @RequestParam(required = false)
String hyperVer, String hyperVer,
@Parameter(description = "모델 타입 (G1, G2, G3 중 하나)", example = "G1") @Parameter(description = "모델 타입 (G1, G2, G3, G4 중 하나)", example = "G1")
@RequestParam(required = false) @RequestParam(required = false)
ModelType model, ModelType model,
@Parameter( @Parameter(

View File

@@ -1,6 +1,7 @@
package com.kamco.cd.training.model; package com.kamco.cd.training.model;
import com.kamco.cd.training.common.dto.MonitorDto; import com.kamco.cd.training.common.dto.MonitorDto;
import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.common.service.SystemMonitorService; import com.kamco.cd.training.common.service.SystemMonitorService;
import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.config.api.ApiResponseDto;
import com.kamco.cd.training.dataset.dto.DatasetDto; import com.kamco.cd.training.dataset.dto.DatasetDto;
@@ -68,7 +69,7 @@ public class ModelTrainMngApiController {
@Parameter( @Parameter(
description = "모델", description = "모델",
example = "G1", example = "G1",
schema = @Schema(allowableValues = {"G1", "G2", "G3"})) schema = @Schema(allowableValues = {"G1", "G2", "G3", "G4"}))
@RequestParam(required = false) @RequestParam(required = false)
String modelNo, String modelNo,
@Parameter(description = "페이지 번호") @RequestParam(defaultValue = "0") int page, @Parameter(description = "페이지 번호") @RequestParam(defaultValue = "0") int page,
@@ -143,9 +144,9 @@ public class ModelTrainMngApiController {
@Parameter( @Parameter(
description = "모델 구분", description = "모델 구분",
example = "", example = "",
schema = @Schema(allowableValues = {"G1", "G2", "G3"})) schema = @Schema(allowableValues = {"G1", "G2", "G3", "G4"}))
@RequestParam @RequestParam
String modelType, ModelType modelType,
@Parameter( @Parameter(
description = "선택 구분", description = "선택 구분",
example = "", example = "",
@@ -153,7 +154,7 @@ public class ModelTrainMngApiController {
@RequestParam @RequestParam
String selectType) { String selectType) {
DatasetReq req = new DatasetReq(); DatasetReq req = new DatasetReq();
req.setModelNo(modelType); req.setModelNo(modelType.getId());
req.setDataType(selectType); req.setDataType(selectType);
return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req)); return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req));
} }

View File

@@ -139,7 +139,7 @@ public class ModelTrainMngDto {
public static class AddReq { public static class AddReq {
@NotNull @NotNull
@Schema(description = "모델 종류 G1, G2, G3", example = "G1") @Schema(description = "모델 종류 G1, G2, G3, G4", example = "G1")
private String modelNo; private String modelNo;
@NotNull @NotNull
@@ -197,10 +197,11 @@ public class ModelTrainMngDto {
@Schema(description = "폐기물", example = "0") @Schema(description = "폐기물", example = "0")
private Long wasteCnt; private Long wasteCnt;
@Schema( @Schema(description = "도로, 비닐하우스, 밭, 과수원, 초지, 숲, 물, 모재/자갈, 토분(무덤), 일반토지, 기타", example = "0")
description = "도로, 비닐하우스, 밭, 과수원, 초지, 숲, 물, 모재/자갈, 토분(무덤), 일반토지, 태양광, 기타",
example = "0")
private Long LandCoverCnt; private Long LandCoverCnt;
@Schema(description = "태양광", example = "0")
private Long solarCnt;
} }
@Getter @Getter

View File

@@ -51,10 +51,10 @@ public class ModelTrainMngService {
@Value("${train.docker.response_dir}") @Value("${train.docker.response_dir}")
private String responseDir; private String responseDir;
@Value("${train.docker.symbolic_link_dir}") @Value("${train.docker.symbolic_link_dir}")
private String symbolicDir; private String symbolicDir;
/** /**
* 모델학습 조회 * 모델학습 조회
* *
@@ -82,6 +82,10 @@ public class ModelTrainMngService {
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델 없음"); throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "모델 없음");
} }
if (model.getRequestPath() == null) {
throw new CustomApiException("NOT_FOUND", HttpStatus.NOT_FOUND, "임시파일 경로 없음");
}
// ===== 2. 경로 생성 ===== // ===== 2. 경로 생성 =====
Path tmpBase = Path.of(symbolicDir).toAbsolutePath().normalize(); Path tmpBase = Path.of(symbolicDir).toAbsolutePath().normalize();
Path tmp = tmpBase.resolve(model.getRequestPath()).normalize(); Path tmp = tmpBase.resolve(model.getRequestPath()).normalize();
@@ -324,6 +328,8 @@ public class ModelTrainMngService {
public List<SelectDataSet> getDatasetSelectList(DatasetReq req) { public List<SelectDataSet> getDatasetSelectList(DatasetReq req) {
if (req.getModelNo().equals(ModelType.G1.getId())) { if (req.getModelNo().equals(ModelType.G1.getId())) {
return modelTrainMngCoreService.getDatasetSelectG1List(req); return modelTrainMngCoreService.getDatasetSelectG1List(req);
} else if (req.getModelNo().equals(ModelType.G4.getId())) {
return modelTrainMngCoreService.getDatasetSelectG4List(req);
} else { } else {
return modelTrainMngCoreService.getDatasetSelectG2G3List(req); return modelTrainMngCoreService.getDatasetSelectG2G3List(req);
} }

View File

@@ -154,6 +154,8 @@ public class ModelTrainMngCoreService {
datasetEntity.setWasteCnt(dataset.getSummary().getWasteCnt()); datasetEntity.setWasteCnt(dataset.getSummary().getWasteCnt());
} else if (addReq.getModelNo().equals(ModelType.G3.getId())) { } else if (addReq.getModelNo().equals(ModelType.G3.getId())) {
datasetEntity.setLandCoverCnt(dataset.getSummary().getLandCoverCnt()); datasetEntity.setLandCoverCnt(dataset.getSummary().getLandCoverCnt());
} else if (addReq.getModelNo().equals(ModelType.G4.getId())) {
datasetEntity.setSolarCnt(dataset.getSummary().getSolarCnt());
} }
datasetEntity.setCreatedUid(userUtil.getId()); datasetEntity.setCreatedUid(userUtil.getId());
@@ -337,6 +339,16 @@ public class ModelTrainMngCoreService {
return datasetRepository.getDatasetTransferSelectG2G3List(modelId, modelNo); return datasetRepository.getDatasetTransferSelectG2G3List(modelId, modelNo);
} }
/**
* 데이터셋 G4 목록
*
* @param req
* @return
*/
public List<SelectDataSet> getDatasetSelectG4List(DatasetReq req) {
return datasetRepository.getDatasetSelectG4List(req);
}
// TODO 미사용 끝 // TODO 미사용 끝
/** /**

View File

@@ -43,6 +43,9 @@ public class ModelDatasetEntity {
@Column(name = "land_cover_cnt") @Column(name = "land_cover_cnt")
private Long landCoverCnt; private Long landCoverCnt;
@Column(name = "solar_cnt")
private Long solarCnt;
@ColumnDefault("now()") @ColumnDefault("now()")
@Column(name = "created_dttm") @Column(name = "created_dttm")
private ZonedDateTime createdDttm = ZonedDateTime.now(); private ZonedDateTime createdDttm = ZonedDateTime.now();

View File

@@ -181,15 +181,15 @@ public class ModelHyperParamEntity {
private String metrics = "mFscore,mIoU"; private String metrics = "mFscore,mIoU";
/** Default: changed_fscore */ /** Default: changed_fscore */
@Size(max = 30) @Size(max = 50)
@NotNull @NotNull
@Column(name = "save_best", nullable = false, length = 30) @Column(name = "save_best", nullable = false, length = 50)
private String saveBest = "changed_fscore"; private String saveBest = "changed_fscore";
/** Default: greater */ /** Default: greater */
@Size(max = 10) @Size(max = 50)
@NotNull @NotNull
@Column(name = "save_best_rule", nullable = false, length = 10) @Column(name = "save_best_rule", nullable = false, length = 50)
private String saveBestRule = "greater"; private String saveBestRule = "greater";
/** Default: 1 */ /** Default: 1 */

View File

@@ -27,6 +27,8 @@ public interface DatasetRepositoryCustom {
List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req); List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req);
List<SelectDataSet> getDatasetSelectG4List(DatasetReq req);
Long getDatasetMaxStage(int compareYyyy, int targetYyyy); Long getDatasetMaxStage(int compareYyyy, int targetYyyy);
Long insertDatasetMngData(DatasetMngRegDto mngRegDto); Long insertDatasetMngData(DatasetMngRegDto mngRegDto);

View File

@@ -4,6 +4,7 @@ import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObj
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity; import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.DetectionClassification;
import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto;
import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq;
@@ -104,10 +105,6 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
builder.and(dataset.dataType.eq(req.getDataType())); builder.and(dataset.dataType.eq(req.getDataType()));
} }
if (StringUtils.isNotBlank(req.getDataType()) && !"CURRENT".equals(req.getDataType())) {
builder.and(dataset.dataType.eq(req.getDataType()));
}
if (req.getIds() != null) { if (req.getIds() != null) {
builder.and(dataset.id.in(req.getIds())); builder.and(dataset.id.in(req.getIds()));
} }
@@ -126,14 +123,17 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
dataset.targetYyyy, dataset.targetYyyy,
dataset.memo, dataset.memo,
new CaseBuilder() new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("building")) .when(
.then(1) datasetObjEntity.targetClassCd.eq(DetectionClassification.BUILDING.getId()))
.otherwise(0) .then(1L)
.otherwise(0L)
.sum(), .sum(),
new CaseBuilder() new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq("container")) .when(
.then(1) datasetObjEntity.targetClassCd.eq(
.otherwise(0) DetectionClassification.CONTAINER.getId()))
.then(1L)
.otherwise(0L)
.sum())) .sum()))
.from(dataset) .from(dataset)
.leftJoin(datasetObjEntity) .leftJoin(datasetObjEntity)
@@ -249,29 +249,40 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
} }
// TODO 미사용 끝 // TODO 미사용 끝
@Override @Override
public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) { public List<SelectDataSet> getDatasetSelectG2G3List(DatasetReq req) {
String building = DetectionClassification.BUILDING.getId();
String container = DetectionClassification.CONTAINER.getId();
String waste = DetectionClassification.WASTE.getId();
String solar = DetectionClassification.SOLAR.getId();
BooleanBuilder builder = new BooleanBuilder(); BooleanBuilder builder = new BooleanBuilder();
builder.and(dataset.deleted.isFalse()); builder.and(dataset.deleted.isFalse());
NumberExpression<Long> selectedCnt = null; NumberExpression<Long> selectedCnt = null;
// G2
NumberExpression<Long> wasteCnt = NumberExpression<Long> wasteCnt =
datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum(); datasetObjEntity
.targetClassCd
NumberExpression<Long> elseCnt = .when(DetectionClassification.WASTE.getId())
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.notIn("building", "container", "waste"))
.then(1L) .then(1L)
.otherwise(0L) .otherwise(0L)
.sum(); .sum();
if (StringUtils.isNotBlank(req.getModelNo())) { // G3 (G1, G2, G4 제외)
if (req.getModelNo().equals(ModelType.G2.getId())) { NumberExpression<Long> elseCnt =
selectedCnt = wasteCnt; new CaseBuilder()
} else { .when(datasetObjEntity.targetClassCd.notIn(building, container, waste, solar))
selectedCnt = elseCnt; .then(1L)
} .otherwise(0L)
.sum();
if (req.getModelNo().equals(ModelType.G2.getId())) {
selectedCnt = wasteCnt;
} else {
selectedCnt = elseCnt;
} }
if (StringUtils.isNotBlank(req.getDataType())) { if (StringUtils.isNotBlank(req.getDataType())) {
@@ -481,4 +492,51 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
.where(dataset.uid.eq(uid), dataset.deleted.isFalse()) .where(dataset.uid.eq(uid), dataset.deleted.isFalse())
.fetchOne(); .fetchOne();
} }
@Override
public List<SelectDataSet> getDatasetSelectG4List(DatasetReq req) {
BooleanBuilder builder = new BooleanBuilder();
builder.and(dataset.deleted.isFalse());
if (StringUtils.isNotBlank(req.getDataType()) && !"CURRENT".equals(req.getDataType())) {
builder.and(dataset.dataType.eq(req.getDataType()));
}
if (req.getIds() != null) {
builder.and(dataset.id.in(req.getIds()));
}
return queryFactory
.select(
Projections.constructor(
SelectDataSet.class,
Expressions.constant(req.getModelNo()),
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.compareYyyy,
dataset.targetYyyy,
dataset.memo,
new CaseBuilder()
.when(datasetObjEntity.targetClassCd.eq(DetectionClassification.SOLAR.getId()))
.then(1L)
.otherwise(0L)
.sum()))
.from(dataset)
.leftJoin(datasetObjEntity)
.on(dataset.id.eq(datasetObjEntity.datasetUid))
.where(builder)
.groupBy(
dataset.id,
dataset.uuid,
dataset.dataType,
dataset.title,
dataset.roundNo,
dataset.memo)
.orderBy(dataset.createdDttm.desc())
.fetch();
}
} }

View File

@@ -1,9 +1,11 @@
package com.kamco.cd.training.postgres.repository.model; package com.kamco.cd.training.postgres.repository.model;
import static com.kamco.cd.training.postgres.entity.QDatasetEntity.datasetEntity; import static com.kamco.cd.training.postgres.entity.QDatasetEntity.datasetEntity;
import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObjEntity;
import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity; import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity;
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
import com.kamco.cd.training.common.enums.DetectionClassification;
import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.enums.ModelType;
import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity; import com.kamco.cd.training.postgres.entity.ModelDatasetMappEntity;
import com.kamco.cd.training.postgres.entity.QDatasetObjEntity; import com.kamco.cd.training.postgres.entity.QDatasetObjEntity;
@@ -11,6 +13,7 @@ import com.kamco.cd.training.postgres.entity.QDatasetTestObjEntity;
import com.kamco.cd.training.postgres.entity.QDatasetValObjEntity; import com.kamco.cd.training.postgres.entity.QDatasetValObjEntity;
import com.kamco.cd.training.train.dto.ModelTrainLinkDto; import com.kamco.cd.training.train.dto.ModelTrainLinkDto;
import com.querydsl.core.types.Projections; import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.BooleanExpression;
import com.querydsl.jpa.impl.JPAQueryFactory; import com.querydsl.jpa.impl.JPAQueryFactory;
import java.util.List; import java.util.List;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@@ -33,9 +36,44 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
@Override @Override
public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) { public List<ModelTrainLinkDto> findDatasetTrainPath(Long modelId) {
QDatasetObjEntity datasetObjEntity = QDatasetObjEntity.datasetObjEntity; QDatasetObjEntity datasetObjEntity = QDatasetObjEntity.datasetObjEntity;
// =====================
// 조건 분리
// =====================
BooleanExpression g1 =
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(
datasetObjEntity.targetClassCd.in(
DetectionClassification.CONTAINER.getId(),
DetectionClassification.BUILDING.getId()));
BooleanExpression g2 =
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetObjEntity.targetClassCd.eq(DetectionClassification.WASTE.getId()));
BooleanExpression g4 =
modelMasterEntity
.modelNo
.eq(ModelType.G4.getId())
.and(datasetObjEntity.targetClassCd.eq(DetectionClassification.SOLAR.getId()));
// G3 = every remaining class: anything other than CONTAINER, BUILDING, WASTE and SOLAR
BooleanExpression g3 =
modelMasterEntity
.modelNo
.eq(ModelType.G3.getId())
.and(
datasetObjEntity.targetClassCd.notIn(
DetectionClassification.CONTAINER.getId(),
DetectionClassification.BUILDING.getId(),
DetectionClassification.WASTE.getId(),
DetectionClassification.SOLAR.getId()));
return queryFactory return queryFactory
.select( .select(
Projections.constructor( Projections.constructor(
@@ -60,17 +98,7 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
datasetObjEntity datasetObjEntity
.datasetUid .datasetUid
.eq(modelDatasetMappEntity.datasetUid) .eq(modelDatasetMappEntity.datasetUid)
.and( .and(g1.or(g2).or(g4).or(g3)))
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId)) .where(modelMasterEntity.id.eq(modelId))
.fetch(); .fetch();
} }
@@ -80,6 +108,42 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
QDatasetValObjEntity datasetValObjEntity = QDatasetValObjEntity.datasetValObjEntity; QDatasetValObjEntity datasetValObjEntity = QDatasetValObjEntity.datasetValObjEntity;
// =====================
// 조건 분리
// =====================
BooleanExpression g1 =
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(
datasetValObjEntity.targetClassCd.in(
DetectionClassification.CONTAINER.getId(),
DetectionClassification.BUILDING.getId()));
BooleanExpression g2 =
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetValObjEntity.targetClassCd.eq(DetectionClassification.WASTE.getId()));
BooleanExpression g4 =
modelMasterEntity
.modelNo
.eq(ModelType.G4.getId())
.and(datasetValObjEntity.targetClassCd.eq(DetectionClassification.SOLAR.getId()));
// G3 = every remaining class: anything other than CONTAINER, BUILDING, WASTE and SOLAR
BooleanExpression g3 =
modelMasterEntity
.modelNo
.eq(ModelType.G3.getId())
.and(
datasetValObjEntity.targetClassCd.notIn(
DetectionClassification.CONTAINER.getId(),
DetectionClassification.BUILDING.getId(),
DetectionClassification.WASTE.getId(),
DetectionClassification.SOLAR.getId()));
return queryFactory return queryFactory
.select( .select(
Projections.constructor( Projections.constructor(
@@ -104,17 +168,7 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
datasetValObjEntity datasetValObjEntity
.datasetUid .datasetUid
.eq(modelDatasetMappEntity.datasetUid) .eq(modelDatasetMappEntity.datasetUid)
.and( .and(g1.or(g2).or(g4).or(g3)))
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetValObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetValObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId)) .where(modelMasterEntity.id.eq(modelId))
.fetch(); .fetch();
} }
@@ -124,6 +178,42 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
QDatasetTestObjEntity datasetTestObjEntity = QDatasetTestObjEntity.datasetTestObjEntity; QDatasetTestObjEntity datasetTestObjEntity = QDatasetTestObjEntity.datasetTestObjEntity;
// =====================
// 조건 분리
// =====================
BooleanExpression g1 =
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(
datasetTestObjEntity.targetClassCd.in(
DetectionClassification.CONTAINER.getId(),
DetectionClassification.BUILDING.getId()));
BooleanExpression g2 =
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetTestObjEntity.targetClassCd.eq(DetectionClassification.WASTE.getId()));
BooleanExpression g4 =
modelMasterEntity
.modelNo
.eq(ModelType.G4.getId())
.and(datasetTestObjEntity.targetClassCd.eq(DetectionClassification.SOLAR.getId()));
// G3 = every remaining class: anything other than CONTAINER, BUILDING, WASTE and SOLAR
BooleanExpression g3 =
modelMasterEntity
.modelNo
.eq(ModelType.G3.getId())
.and(
datasetTestObjEntity.targetClassCd.notIn(
DetectionClassification.CONTAINER.getId(),
DetectionClassification.BUILDING.getId(),
DetectionClassification.WASTE.getId(),
DetectionClassification.SOLAR.getId()));
return queryFactory return queryFactory
.select( .select(
Projections.constructor( Projections.constructor(
@@ -148,17 +238,7 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
datasetTestObjEntity datasetTestObjEntity
.datasetUid .datasetUid
.eq(modelDatasetMappEntity.datasetUid) .eq(modelDatasetMappEntity.datasetUid)
.and( .and(g1.or(g2).or(g4).or(g3)))
modelMasterEntity
.modelNo
.eq(ModelType.G1.getId())
.and(datasetTestObjEntity.targetClassCd.upper().in("CONTAINER", "BUILDING"))
.or(
modelMasterEntity
.modelNo
.eq(ModelType.G2.getId())
.and(datasetTestObjEntity.targetClassCd.upper().eq("WASTE")))
.or(modelMasterEntity.modelNo.eq(ModelType.G3.getId()))))
.where(modelMasterEntity.id.eq(modelId)) .where(modelMasterEntity.id.eq(modelId))
.fetch(); .fetch();
} }

View File

@@ -56,6 +56,12 @@ public class DockerTrainService {
@Value("${spring.profiles.active}") @Value("${spring.profiles.active}")
private String profile; private String profile;
@Value("${hyper.parameter.gpus}")
private String hyperGpus;
@Value("${hyper.parameter.gpu-ids}")
private String hyperGpuIds;
private final ModelTrainJobCoreService modelTrainJobCoreService; private final ModelTrainJobCoreService modelTrainJobCoreService;
/** /**
@@ -262,9 +268,9 @@ public class DockerTrainService {
c.add("-v"); c.add("-v");
c.add(basePath + ":" + basePath); // 심볼릭 링크와 연결되는 실제 파일 경로도 마운트를 해줘야 함 c.add(basePath + ":" + basePath); // 심볼릭 링크와 연결되는 실제 파일 경로도 마운트를 해줘야 함
c.add("-v"); c.add("-v");
c.add(symbolicDir + ":/data"); //요청할경로 c.add(symbolicDir + ":/data"); // 요청할경로
c.add("-v"); c.add("-v");
c.add(responseDir + ":/checkpoints"); //저장될경로 c.add(responseDir + ":/checkpoints"); // 저장될경로
// 표준입력 유지 (-it 대신 -i만 사용) // 표준입력 유지 (-it 대신 -i만 사용)
c.add("-i"); c.add("-i");
@@ -285,11 +291,13 @@ public class DockerTrainService {
// addArg(c, "--gpu-ids", req.getGpuIds()); // null // addArg(c, "--gpu-ids", req.getGpuIds()); // null
if ("prod".equals(profile)) { if ("prod".equals(profile)) {
addArg(c, "--batch-size", 2); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능 addArg(c, "--batch-size", 2); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능
addArg(c, "--gpus", "1"); // 학습서버 GPU 1개인 곳은 1이어야 함
addArg(c, "--gpu-ids", "0"); // 학습서버 GPU 1개인 곳은 0이어야 함
} else { } else {
addArg(c, "--batch-size", req.getBatchSize()); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능 addArg(c, "--batch-size", req.getBatchSize()); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능
} }
addArg(c, "--gpus", hyperGpus); // 학습서버 GPU 1개인 곳은 1이어야 함
addArg(c, "--gpu-ids", hyperGpuIds); // 학습서버 GPU 1개인 곳은 0이어야 함
addArg(c, "--lr", req.getLearningRate()); addArg(c, "--lr", req.getLearningRate());
addArg(c, "--backbone", req.getBackbone()); addArg(c, "--backbone", req.getBackbone());
addArg(c, "--epochs", req.getEpochs()); addArg(c, "--epochs", req.getEpochs());

View File

@@ -384,7 +384,20 @@ public class JobRecoveryOnStartupService {
return new OutputResult(false, "total-epoch-missing"); return new OutputResult(false, "total-epoch-missing");
} }
log.info("[RECOVERY] totalEpoch={}. jobId={}", totalEpoch, job.getId()); Integer valInterval = extractValInterval(job).orElse(null);
if (valInterval == null || valInterval <= 0) {
log.warn(
"[RECOVERY] valInterval missing or invalid. jobId={}, valInterval={}",
job.getId(),
valInterval);
return new OutputResult(false, "val-interval-missing");
}
log.info(
"[RECOVERY] totalEpoch={}. valInterval={}. jobId={}",
totalEpoch,
valInterval,
job.getId());
// 3) val.csv 존재 확인 // 3) val.csv 존재 확인
Path valCsv = outDir.resolve("val.csv"); Path valCsv = outDir.resolve("val.csv");
@@ -396,14 +409,17 @@ public class JobRecoveryOnStartupService {
// 4) val.csv 라인 수 확인 // 4) val.csv 라인 수 확인
long lines = countNonHeaderLines(valCsv); long lines = countNonHeaderLines(valCsv);
// expected = 실제 val 실행 횟수
int expectedLines = totalEpoch / valInterval;
log.info( log.info(
"[RECOVERY] val.csv lines counted. jobId={}, lines={}, expected={}", "[RECOVERY] val.csv lines counted. jobId={}, lines={}, expected={}",
job.getId(), job.getId(),
lines, lines,
totalEpoch); expectedLines);
// 5) 완료 판정 // 5) 완료 판정
if (lines == totalEpoch) { if (lines >= expectedLines) {
log.info("[RECOVERY] outputs look COMPLETE. jobId={}", job.getId()); log.info("[RECOVERY] outputs look COMPLETE. jobId={}", job.getId());
return new OutputResult(true, "ok"); return new OutputResult(true, "ok");
} }
@@ -412,7 +428,7 @@ public class JobRecoveryOnStartupService {
"[RECOVERY] val.csv line mismatch. jobId={}, lines={}, expected={}", "[RECOVERY] val.csv line mismatch. jobId={}, lines={}, expected={}",
job.getId(), job.getId(),
lines, lines,
totalEpoch); expectedLines);
return new OutputResult( return new OutputResult(
false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch); false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch);
@@ -530,4 +546,19 @@ public class JobRecoveryOnStartupService {
return reason; return reason;
} }
} }
/**
 * Extracts the {@code valInterval} entry from the job's {@code paramsJson} map.
 *
 * <p>Numeric values are accepted when they are integral and fit in an {@code int} (so {@code 5}
 * and {@code 5.0} both yield 5). Any other value is converted to a string, trimmed and parsed as
 * a decimal integer.
 *
 * @param job training job whose parameter map is inspected
 * @return the interval, or empty when the map or the entry is missing or the value cannot be
 *     read as an {@code int}
 */
private Optional<Integer> extractValInterval(ModelTrainJobDto job) {
  Map<String, Object> params = job.getParamsJson();
  if (params == null) {
    return Optional.empty();
  }

  Object v = params.get("valInterval");
  if (v == null) {
    return Optional.empty();
  }

  if (v instanceof Number) {
    // A number may arrive as a floating-point type; accept it only when no information is lost.
    double d = ((Number) v).doubleValue();
    if (d == Math.rint(d) && d >= Integer.MIN_VALUE && d <= Integer.MAX_VALUE) {
      return Optional.of((int) d);
    }
    return Optional.empty();
  }

  try {
    return Optional.of(Integer.parseInt(String.valueOf(v).trim()));
  } catch (NumberFormatException ignored) {
    // Not a parsable integer: the caller treats an empty result as "val interval missing".
    return Optional.empty();
  }
}
} }

View File

@@ -85,34 +85,20 @@ public class ModelTrainMetricsJobService {
int epoch = Integer.parseInt(record.get("Epoch")); int epoch = Integer.parseInt(record.get("Epoch"));
float aAcc = parseFloatSafe(record.get("aAcc")); Float aAcc = parseFloatSafe(record.get("aAcc"));
float mFscore = parseFloatSafe(record.get("mFscore")); Float mFscore = parseFloatSafe(record.get("mFscore"));
float mPrecision = parseFloatSafe(record.get("mPrecision")); Float mPrecision = parseFloatSafe(record.get("mPrecision"));
float mRecall = parseFloatSafe(record.get("mRecall")); Float mRecall = parseFloatSafe(record.get("mRecall"));
float mIoU = parseFloatSafe(record.get("mIoU")); Float mIoU = parseFloatSafe(record.get("mIoU"));
float mAcc = parseFloatSafe(record.get("mAcc")); Float mAcc = parseFloatSafe(record.get("mAcc"));
float changed_fscore = parseFloatSafe(record.get("changed_fscore")); Float changed_fscore = parseFloatSafe(record.get("changed_fscore"));
float changed_precision = parseFloatSafe(record.get("changed_precision")); Float changed_precision = parseFloatSafe(record.get("changed_precision"));
float changed_recall = parseFloatSafe(record.get("changed_recall")); Float changed_recall = parseFloatSafe(record.get("changed_recall"));
float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore")); Float unchanged_fscore = parseFloatSafe(record.get("unchanged_fscore"));
float unchanged_precision = parseFloatSafe(record.get("unchanged_precision")); Float unchanged_precision = parseFloatSafe(record.get("unchanged_precision"));
float unchanged_recall = parseFloatSafe(record.get("unchanged_recall")); Float unchanged_recall = parseFloatSafe(record.get("unchanged_recall"));
// int epoch = Integer.parseInt(record.get("Epoch"));
// float aAcc = Float.parseFloat(record.get("aAcc"));
// float mFscore = Float.parseFloat(record.get("mFscore"));
// float mPrecision = Float.parseFloat(record.get("mPrecision"));
// float mRecall = Float.parseFloat(record.get("mRecall"));
// float mIoU = Float.parseFloat(record.get("mIoU"));
// float mAcc = Float.parseFloat(record.get("mAcc"));
// float changed_fscore = Float.parseFloat(record.get("changed_fscore"));
// float changed_precision = Float.parseFloat(record.get("changed_precision"));
// float changed_recall = Float.parseFloat(record.get("changed_recall"));
// float unchanged_fscore = Float.parseFloat(record.get("unchanged_fscore"));
// float unchanged_precision =
// Float.parseFloat(record.get("unchanged_precision"));
// float unchanged_recall = Float.parseFloat(record.get("unchanged_recall"));
batchArgs.add( batchArgs.add(
new Object[] { new Object[] {

View File

@@ -19,6 +19,7 @@ public class TmpDatasetService {
@Value("${train.docker.symbolic_link_dir}") @Value("${train.docker.symbolic_link_dir}")
private String symbolicDir; private String symbolicDir;
/** /**
* train, val, test 폴더별로 link * train, val, test 폴더별로 link
* *

View File

@@ -132,7 +132,9 @@ public class TrainJobWorker {
String failMsg = result.getStatus() + "\n" + result.getLogs(); String failMsg = result.getStatus() + "\n" + result.getLogs();
log.info("training fail exitCode={} Msg ={}", result.getExitCode(), failMsg); log.info("training fail exitCode={} Msg ={}", result.getExitCode(), failMsg);
if (result.getExitCode() == -1 || result.getExitCode() == 143) { if (result.getExitCode() == -1
|| result.getExitCode() == 143
|| result.getExitCode() == 137) {
// 실패 처리 // 실패 처리
modelTrainJobCoreService.markPaused( modelTrainJobCoreService.markPaused(
jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs()); jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());

View File

@@ -41,3 +41,7 @@ train:
container_prefix: kamco-cd-train container_prefix: kamco-cd-train
shm_size: 16g shm_size: 16g
ipc_host: true ipc_host: true
hyper:
parameter:
gpus: 4
gpu-ids: 0,1,2,3

View File

@@ -78,3 +78,8 @@ management:
exposure: exposure:
include: include:
- "health" - "health"
hyper:
parameter:
gpus: 1
gpu-ids: 0