Compare commits
86 Commits
feat/dean/
...
da9d47ae4a
| Author | SHA1 | Date | |
|---|---|---|---|
| da9d47ae4a | |||
| bf8515163c | |||
| 7d6a77bf2a | |||
| 26828d0968 | |||
| 2691f6ce16 | |||
| e2dbae15c0 | |||
| b246034632 | |||
| 7e5aa5e713 | |||
| 060a815e1c | |||
| 687ea82d78 | |||
| 1eb4d04779 | |||
| f30c0c6d45 | |||
| 12994aab60 | |||
| 4ac0f19908 | |||
| 11d3afe295 | |||
| 9e5e7595eb | |||
| 1e62a8b097 | |||
| 9cd9274e99 | |||
| 26a4623aa8 | |||
| 5d82f3ecfe | |||
| ce6e4f5aea | |||
| 2ce249ab33 | |||
| c2215836c0 | |||
| e34bf68de0 | |||
| 8c19c996f7 | |||
| 862bda0cb9 | |||
| b5ce3ab1fb | |||
| 90f7b17d07 | |||
| e1ceb769dd | |||
| 2128baa46a | |||
| 4219b88fb3 | |||
| 4f94c99b64 | |||
| 875c30f467 | |||
| d42e1afbd4 | |||
| b3b8016673 | |||
| 2b29cd1ac6 | |||
| 79e8259f28 | |||
| 9206fff5d0 | |||
| 032c82c2f0 | |||
| 6204a6e5fa | |||
| 4d9c9a86b4 | |||
| 83204abfe9 | |||
| 5b682c1386 | |||
| 452494d44d | |||
| 8ada26448b | |||
| e442f105bc | |||
| 5e0a771848 | |||
| b4c2685059 | |||
| e238f3ca88 | |||
| 97b06eb3b3 | |||
| ad32ca18ca | |||
| 98a1283ebe | |||
| a10fccaae3 | |||
| c3c9191d9d | |||
| 9fd5a15a72 | |||
| 12f9de7367 | |||
| 5455da1e96 | |||
| 9e803661cd | |||
| b0cf9e77ec | |||
| c92426aefc | |||
| d5b2b8ecec | |||
| 6185a18a7c | |||
| 49d3e37458 | |||
| 1fb10830b9 | |||
| d7766edd24 | |||
| 0bc4453c9c | |||
| ae0d30e5da | |||
| 37d776dd2c | |||
| 3106d36431 | |||
| ed48f697a4 | |||
| da92b28d97 | |||
| 6c865d26fd | |||
| e3f00876f1 | |||
| 16e156b5b4 | |||
| 60962bbc75 | |||
| 6a939118ff | |||
| 64d37dcc08 | |||
| 0c0ae16c2b | |||
| a2490f30e6 | |||
| 953f95aed6 | |||
| bd04e1f4e8 | |||
| 85633c8bab | |||
| 5fc15937c0 | |||
| 8b3940b446 | |||
| 201cfefb6b | |||
| 9958b0999a |
@@ -1,6 +1,11 @@
|
||||
# Stage 1: Build stage (gradle build는 Jenkins에서 이미 수행)
|
||||
FROM eclipse-temurin:21-jre-jammy
|
||||
|
||||
# docker CLI 설치 (컨테이너에서 호스트 Docker 제어용) 260212 추가
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends docker.io ca-certificates && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 작업 디렉토리 설정
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ services:
|
||||
- /mnt/nfs_share/model_output:/app/model-outputs
|
||||
- /mnt/nfs_share/train_dataset:/app/train-dataset
|
||||
- /home/kcomu/data:/home/kcomu/data
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
networks:
|
||||
- kamco-cds
|
||||
restart: unless-stopped
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.kamco.cd.training.common.download;
|
||||
|
||||
import com.kamco.cd.training.common.download.dto.DownloadSpec;
|
||||
import com.kamco.cd.training.common.utils.UserUtil;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class DownloadExecutor {
|
||||
|
||||
private final UserUtil userUtil;
|
||||
|
||||
public ResponseEntity<StreamingResponseBody> stream(DownloadSpec spec) throws IOException {
|
||||
|
||||
if (!Files.isReadable(spec.filePath())) {
|
||||
return ResponseEntity.notFound().build();
|
||||
}
|
||||
|
||||
StreamingResponseBody body =
|
||||
os -> {
|
||||
try (InputStream in = Files.newInputStream(spec.filePath())) {
|
||||
in.transferTo(os);
|
||||
os.flush();
|
||||
} catch (Exception e) {
|
||||
// 고용량은 중간 끊김 흔하니까 throw 금지
|
||||
}
|
||||
};
|
||||
|
||||
String fileName =
|
||||
spec.downloadName() != null
|
||||
? spec.downloadName()
|
||||
: spec.filePath().getFileName().toString();
|
||||
|
||||
return ResponseEntity.ok()
|
||||
.contentType(
|
||||
spec.contentType() != null ? spec.contentType() : MediaType.APPLICATION_OCTET_STREAM)
|
||||
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"")
|
||||
.body(body);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.kamco.cd.training.common.download;
|
||||
|
||||
import org.springframework.util.AntPathMatcher;
|
||||
|
||||
public final class DownloadPaths {
|
||||
private DownloadPaths() {}
|
||||
|
||||
public static final String[] PATTERNS = {
|
||||
"/api/inference/download/**", "/api/training-data/stage/download/**"
|
||||
};
|
||||
|
||||
public static boolean matches(String uri) {
|
||||
AntPathMatcher m = new AntPathMatcher();
|
||||
for (String p : PATTERNS) {
|
||||
if (m.match(p, uri)) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package com.kamco.cd.training.common.download;
|
||||
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
import org.springframework.core.io.FileSystemResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.core.io.support.ResourceRegion;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpRange;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class RangeDownloadResponder {
|
||||
|
||||
public ResponseEntity<?> buildZipResponse(
|
||||
Path filePath, String downloadFileName, HttpServletRequest request) throws IOException {
|
||||
|
||||
if (!Files.isRegularFile(filePath)) {
|
||||
return ResponseEntity.notFound().build();
|
||||
}
|
||||
|
||||
long totalSize = Files.size(filePath);
|
||||
Resource resource = new FileSystemResource(filePath);
|
||||
|
||||
String disposition = "attachment; filename=\"" + downloadFileName + "\"";
|
||||
String rangeHeader = request.getHeader(HttpHeaders.RANGE);
|
||||
|
||||
// 🔥 공통 헤더 (여기 고정)
|
||||
ResponseEntity.BodyBuilder base =
|
||||
ResponseEntity.ok()
|
||||
.contentType(MediaType.APPLICATION_OCTET_STREAM)
|
||||
.header(HttpHeaders.CONTENT_DISPOSITION, disposition)
|
||||
.header(HttpHeaders.ACCEPT_RANGES, "bytes")
|
||||
.header("Access-Control-Expose-Headers", "Content-Disposition")
|
||||
.header("X-Accel-Buffering", "no");
|
||||
|
||||
if (rangeHeader == null || rangeHeader.isBlank()) {
|
||||
return base.contentLength(totalSize).body(resource);
|
||||
}
|
||||
|
||||
List<HttpRange> ranges;
|
||||
try {
|
||||
ranges = HttpRange.parseRanges(rangeHeader);
|
||||
} catch (IllegalArgumentException ex) {
|
||||
return ResponseEntity.status(416)
|
||||
.header(HttpHeaders.CONTENT_RANGE, "bytes */" + totalSize)
|
||||
.header("X-Accel-Buffering", "no")
|
||||
.build();
|
||||
}
|
||||
|
||||
HttpRange range = ranges.get(0);
|
||||
|
||||
long start = range.getRangeStart(totalSize);
|
||||
long end = range.getRangeEnd(totalSize);
|
||||
|
||||
if (start >= totalSize) {
|
||||
return ResponseEntity.status(416)
|
||||
.header(HttpHeaders.CONTENT_RANGE, "bytes */" + totalSize)
|
||||
.header("X-Accel-Buffering", "no")
|
||||
.build();
|
||||
}
|
||||
|
||||
long regionLength = end - start + 1;
|
||||
ResourceRegion region = new ResourceRegion(resource, start, regionLength);
|
||||
|
||||
return ResponseEntity.status(206)
|
||||
.contentType(MediaType.APPLICATION_OCTET_STREAM)
|
||||
.header(HttpHeaders.CONTENT_DISPOSITION, disposition)
|
||||
.header(HttpHeaders.ACCEPT_RANGES, "bytes")
|
||||
.header("Access-Control-Expose-Headers", "Content-Disposition")
|
||||
.header("X-Accel-Buffering", "no")
|
||||
.header(HttpHeaders.CONTENT_RANGE, "bytes " + start + "-" + end + "/" + totalSize)
|
||||
.contentLength(regionLength)
|
||||
.body(region);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.kamco.cd.training.common.download.dto;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.UUID;
|
||||
import org.springframework.http.MediaType;
|
||||
|
||||
public record DownloadSpec(
|
||||
UUID uuid, // 다운로드 식별(로그/정책용)
|
||||
Path filePath, // 실제 파일 경로
|
||||
String downloadName, // 사용자에게 보일 파일명
|
||||
MediaType contentType // 보통 OCTET_STREAM
|
||||
) {}
|
||||
@@ -17,7 +17,10 @@ public enum ModelType implements EnumType {
|
||||
private String desc;
|
||||
|
||||
public static ModelType getValueData(String modelNo) {
|
||||
return Arrays.stream(ModelType.values()).filter(m -> m.getId().equals(modelNo)).findFirst().orElse(G1);
|
||||
return Arrays.stream(ModelType.values())
|
||||
.filter(m -> m.getId().equals(modelNo))
|
||||
.findFirst()
|
||||
.orElse(G1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -78,7 +78,8 @@ public class SecurityConfig {
|
||||
"/v3/api-docs/**",
|
||||
"/api/members/*/password",
|
||||
"/api/upload/chunk-upload-dataset",
|
||||
"/api/upload/chunk-upload-complete")
|
||||
"/api/upload/chunk-upload-complete",
|
||||
"/download_progress_test.html")
|
||||
.permitAll()
|
||||
|
||||
// default
|
||||
|
||||
@@ -217,7 +217,7 @@ public class DatasetApiController {
|
||||
public ApiResponseDto<ApiResponseDto.ResponseObj> insertDataset(
|
||||
@RequestBody @Valid DatasetDto.AddReq addReq) {
|
||||
|
||||
return ApiResponseDto.ok(datasetService.insertDataset(addReq));
|
||||
return ApiResponseDto.okObject(datasetService.insertDataset(addReq));
|
||||
}
|
||||
|
||||
@Operation(summary = "객체별 파일 Path 조회", description = "파일 Path 조회")
|
||||
|
||||
@@ -77,9 +77,16 @@ public class DatasetDto {
|
||||
}
|
||||
|
||||
public String getTotalSize(Long totalSize) {
|
||||
if (totalSize == null) return "0G";
|
||||
if (totalSize == null || totalSize <= 0) return "0M";
|
||||
|
||||
double giga = totalSize / (1024.0 * 1024 * 1024);
|
||||
return String.format("%.2fG", giga);
|
||||
|
||||
if (giga >= 1) {
|
||||
return String.format("%.2fG", giga);
|
||||
} else {
|
||||
double mega = totalSize / (1024.0 * 1024);
|
||||
return String.format("%.2fM", mega);
|
||||
}
|
||||
}
|
||||
|
||||
public String getStatus(String status) {
|
||||
|
||||
@@ -208,6 +208,13 @@ public class DatasetService {
|
||||
Long datasetUid = null; // master id 값, 등록하면서 가져올 예정
|
||||
|
||||
try {
|
||||
// 같은 uid 로 등록한 파일이 있는지 확인
|
||||
Long existsCnt =
|
||||
datasetCoreService.findDatasetByUidExistsCnt(addReq.getFileName().replace(".zip", ""));
|
||||
if (existsCnt > 0) {
|
||||
return new ResponseObj(ApiResponseCode.DUPLICATE_DATA, "이미 등록된 회차 데이터 파일입니다. 확인 부탁드립니다.");
|
||||
}
|
||||
|
||||
// 압축 해제
|
||||
FIleChecker.unzip(addReq.getFileName(), addReq.getFilePath());
|
||||
|
||||
|
||||
@@ -88,9 +88,8 @@ public class HyperParamApiController {
|
||||
@ApiResponse(responseCode = "404", description = "하이퍼파라미터를 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("{model}/list")
|
||||
@GetMapping("/list")
|
||||
public ApiResponseDto<Page<List>> getHyperParam(
|
||||
@PathVariable ModelType model,
|
||||
@Parameter(
|
||||
description = "구분 CREATE_DATE(생성일), LAST_USED_DATE(최근사용일)",
|
||||
example = "CREATE_DATE")
|
||||
@@ -102,6 +101,9 @@ public class HyperParamApiController {
|
||||
LocalDate endDate,
|
||||
@Parameter(description = "버전명", example = "G_000001") @RequestParam(required = false)
|
||||
String hyperVer,
|
||||
@Parameter(description = "모델 타입 (G1, G2, G3 중 하나)", example = "G1")
|
||||
@RequestParam(required = false)
|
||||
ModelType model,
|
||||
@Parameter(
|
||||
description = "정렬",
|
||||
example = "createdDttm desc",
|
||||
@@ -182,10 +184,8 @@ public class HyperParamApiController {
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/init/{model}")
|
||||
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(
|
||||
@PathVariable ModelType model
|
||||
public ApiResponseDto<HyperParamDto.Basic> getInitHyperParam(@PathVariable ModelType model) {
|
||||
|
||||
) {
|
||||
return ApiResponseDto.ok(hyperParamService.getInitHyperParam(model));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,9 +110,11 @@ public class HyperParamDto {
|
||||
@AllArgsConstructor
|
||||
public static class List {
|
||||
private UUID uuid;
|
||||
private ModelType model;
|
||||
private String hyperVer;
|
||||
@JsonFormatDttm private ZonedDateTime createDttm;
|
||||
@JsonFormatDttm private ZonedDateTime lastUsedDttm;
|
||||
private String memo;
|
||||
private Long m1UseCnt;
|
||||
private Long m2UseCnt;
|
||||
private Long m3UseCnt;
|
||||
|
||||
@@ -1,26 +1,38 @@
|
||||
package com.kamco.cd.training.model;
|
||||
|
||||
import com.kamco.cd.training.common.download.RangeDownloadResponder;
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.TransferDetailDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||
import com.kamco.cd.training.model.service.ModelTrainDetailService;
|
||||
import com.kamco.cd.training.model.service.ModelTrainMngService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.enums.ParameterIn;
|
||||
import io.swagger.v3.oas.annotations.media.ArraySchema;
|
||||
import io.swagger.v3.oas.annotations.media.Content;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.apache.coyote.BadRequestException;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
@@ -32,6 +44,11 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
@RequestMapping("/api/models")
|
||||
public class ModelTrainDetailApiController {
|
||||
private final ModelTrainDetailService modelTrainDetailService;
|
||||
private final ModelTrainMngService modelTrainMngService;
|
||||
private final RangeDownloadResponder rangeDownloadResponder;
|
||||
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
@Operation(summary = "모델학습관리> 모델관리 > 상세정보탭 > 학습 진행정보", description = "학습 진행정보, 모델학습 정보 API")
|
||||
@ApiResponses(
|
||||
@@ -222,4 +239,69 @@ public class ModelTrainDetailApiController {
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainBestEpoch(uuid));
|
||||
}
|
||||
|
||||
@Operation(
|
||||
summary = "학습데이터 파일 다운로드",
|
||||
description = "학습데이터 파일 다운로드",
|
||||
parameters = {
|
||||
@Parameter(
|
||||
name = "kamco-download-uuid",
|
||||
in = ParameterIn.HEADER,
|
||||
required = true,
|
||||
description = "다운로드 요청 UUID",
|
||||
schema =
|
||||
@Schema(
|
||||
type = "string",
|
||||
format = "uuid",
|
||||
example = "6d8d49dc-0c9d-4124-adc7-b9ca610cc394"))
|
||||
})
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "학습데이터 zip파일 다운로드",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/octet-stream",
|
||||
schema = @Schema(type = "string", format = "binary"))),
|
||||
@ApiResponse(responseCode = "404", description = "파일 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/download/{uuid}")
|
||||
public ResponseEntity<?> download(@PathVariable UUID uuid, HttpServletRequest request)
|
||||
throws IOException {
|
||||
|
||||
Basic info = modelTrainDetailService.findByModelByUUID(uuid);
|
||||
Path zipPath =
|
||||
Paths.get(responseDir)
|
||||
.resolve(String.valueOf(info.getUuid()))
|
||||
.resolve(info.getModelVer() + ".zip");
|
||||
|
||||
if (!Files.isRegularFile(zipPath)) {
|
||||
throw new BadRequestException();
|
||||
}
|
||||
|
||||
return rangeDownloadResponder.buildZipResponse(zipPath, info.getModelVer() + ".zip", request);
|
||||
}
|
||||
|
||||
@Operation(summary = "모델관리 > 모델 상세 > 파일 정보", description = "모델 상세 > 파일 정보 API")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "조회 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = TransferDetailDto.class))),
|
||||
@ApiResponse(responseCode = "404", description = "데이터셋을 찾을 수 없음", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/file-info/{uuid}")
|
||||
public ApiResponseDto<ModelFileInfo> getModelTrainFileInfo(
|
||||
@Parameter(description = "모델 uuid", example = "95cb116c-380a-41c0-98d8-4d1142f15bbf")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
return ApiResponseDto.ok(modelTrainDetailService.getModelTrainFileInfo(uuid));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import com.kamco.cd.training.model.dto.ModelConfigDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.Basic;
|
||||
import com.kamco.cd.training.model.service.ModelTrainMngService;
|
||||
import com.kamco.cd.training.train.service.ModelTestMetricsJobService;
|
||||
import com.kamco.cd.training.train.service.ModelTrainMetricsJobService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.media.Content;
|
||||
@@ -16,6 +18,7 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.validation.Valid;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -35,6 +38,8 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
@RequestMapping("/api/models")
|
||||
public class ModelTrainMngApiController {
|
||||
private final ModelTrainMngService modelTrainMngService;
|
||||
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
|
||||
private final ModelTestMetricsJobService modelTestMetricsJobService;
|
||||
|
||||
@Operation(summary = "모델학습 목록 조회", description = "모델학습 목록 조회 API")
|
||||
@ApiResponses(
|
||||
@@ -79,7 +84,8 @@ public class ModelTrainMngApiController {
|
||||
@DeleteMapping("/{uuid}")
|
||||
public ApiResponseDto<Void> deleteModelTrain(
|
||||
@Parameter(description = "학습 모델 uuid", example = "f2b02229-90f2-45f5-92ea-c56cf1c29f79")
|
||||
@PathVariable UUID uuid) {
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
modelTrainMngService.deleteModelTrain(uuid);
|
||||
return ApiResponseDto.ok(null);
|
||||
}
|
||||
@@ -166,4 +172,44 @@ public class ModelTrainMngApiController {
|
||||
public ApiResponseDto<Long> findModelStep1InProgressCnt() {
|
||||
return ApiResponseDto.ok(modelTrainMngService.findModelStep1InProgressCnt());
|
||||
}
|
||||
|
||||
@Operation(
|
||||
summary = "스케줄러 findTrainValidMetricCsvFiles",
|
||||
description = "스케줄러 findTrainValidMetricCsvFiles")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "검색 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = Long.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/schedule-trainvalid")
|
||||
public ApiResponseDto<Long> findTrainValidMetricCsvFiles() {
|
||||
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
|
||||
return ApiResponseDto.ok(null);
|
||||
}
|
||||
|
||||
@Operation(summary = "스케줄러 findTestMetricCsvFiles", description = "스케줄러 findTestMetricCsvFiles")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "검색 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = Long.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping("/schedule-test")
|
||||
public ApiResponseDto<Long> findTestValidMetricCsvFiles() throws IOException {
|
||||
modelTestMetricsJobService.findTestValidMetricCsvFiles();
|
||||
return ApiResponseDto.ok(null);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,4 +245,13 @@ public class ModelTrainDetailDto {
|
||||
private Float iou;
|
||||
private Float accuracy;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class ModelFileInfo {
|
||||
private Boolean fileExistsYn;
|
||||
private String fileName;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ public class ModelTrainMngDto {
|
||||
private String trainType;
|
||||
private String modelNo;
|
||||
private Long currentAttemptId;
|
||||
private String requestPath;
|
||||
|
||||
public String getStatusName() {
|
||||
if (this.statusCd == null || this.statusCd.isBlank()) return null;
|
||||
@@ -60,7 +61,7 @@ public class ModelTrainMngDto {
|
||||
}
|
||||
}
|
||||
|
||||
public String getStep2StatusNAme() {
|
||||
public String getStep2StatusName() {
|
||||
if (this.step2Status == null || this.step2Status.isBlank()) return null;
|
||||
try {
|
||||
return TrainStatusType.valueOf(this.step2Status).getText(); // 또는 getName()
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -116,4 +117,8 @@ public class ModelTrainDetailService {
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
return modelTrainDetailCoreService.getModelTrainBestEpoch(uuid);
|
||||
}
|
||||
|
||||
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
|
||||
return modelTrainDetailCoreService.getModelTrainFileInfo(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,8 +12,7 @@ import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.core.HyperParamCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import com.kamco.cd.training.train.service.TmpDatasetService;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -94,22 +93,6 @@ public class ModelTrainMngService {
|
||||
// 모델 config 저장
|
||||
modelTrainMngCoreService.saveModelConfig(modelId, req.getModelConfig());
|
||||
|
||||
UUID tmpUuid = UUID.randomUUID();
|
||||
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
|
||||
|
||||
List<String> uids =
|
||||
modelTrainMngCoreService.findDatasetUid(req.getTrainingDataset().getDatasetList());
|
||||
|
||||
try {
|
||||
// 데이터셋 심볼링크 생성
|
||||
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||
updateReq.setRequestPath(path.toString());
|
||||
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return modelUuid;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
package com.kamco.cd.training.model.service;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.*;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class TmpDatasetService {
|
||||
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public Path buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
||||
|
||||
// 환경에 맞게 yml로 빼는 걸 추천
|
||||
Path BASE = Paths.get(requestDir);
|
||||
Path tmp = BASE.resolve("tmp").resolve(uid);
|
||||
|
||||
// mkdir -p "$TMP"/train/{input1,input2,label} ...
|
||||
for (String type : List.of("train", "val")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
Files.createDirectories(tmp.resolve(type).resolve(part));
|
||||
}
|
||||
}
|
||||
|
||||
for (String id : datasetUids) {
|
||||
Path srcRoot = BASE.resolve(id);
|
||||
|
||||
for (String type : List.of("train", "val")) {
|
||||
for (String part : List.of("input1", "input2", "label")) {
|
||||
|
||||
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||
|
||||
// zsh NULL_GLOB: 폴더가 없으면 그냥 continue
|
||||
if (!Files.isDirectory(srcDir)) continue;
|
||||
|
||||
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
|
||||
for (Path f : stream) {
|
||||
if (!Files.isRegularFile(f)) continue;
|
||||
|
||||
String dstName = id + "__" + f.getFileName();
|
||||
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
|
||||
|
||||
// 이미 있으면 스킵(원하면 덮어쓰기 로직으로 바꿀 수 있음)
|
||||
if (Files.exists(dst)) continue;
|
||||
|
||||
// ln -s "$f" "$dst" 와 동일
|
||||
Files.createSymbolicLink(dst, f.toAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.info("tmp dataset created: {}", tmp);
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
@@ -246,4 +246,8 @@ public class DatasetCoreService
|
||||
public void insertDatasetValObj(DatasetObjRegDto objRegDto) {
|
||||
datasetObjRepository.insertDatasetValObj(objRegDto);
|
||||
}
|
||||
|
||||
public Long findDatasetByUidExistsCnt(String uid) {
|
||||
return datasetRepository.findDatasetByUidExistsCnt(uid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,15 +50,15 @@ public class HyperParamCoreService {
|
||||
/**
|
||||
* 하이퍼파라미터 수정
|
||||
*
|
||||
* @param uuid uuid
|
||||
* @param uuid uuid
|
||||
* @param createReq 등록 요청
|
||||
* @return ver
|
||||
*/
|
||||
public String updateHyperParam(UUID uuid, HyperParam createReq) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
if (entity.getIsDefault()) {
|
||||
throw new CustomApiException("UNPROCESSABLE_ENTITY_UPDATE", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
@@ -72,7 +72,6 @@ public class HyperParamCoreService {
|
||||
return entity.getHyperVer();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 하이퍼파라미터 삭제
|
||||
*
|
||||
@@ -80,15 +79,15 @@ public class HyperParamCoreService {
|
||||
*/
|
||||
public void deleteHyperParam(UUID uuid) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
|
||||
// if (entity.getHyperVer().equals("HPs_0001")) {
|
||||
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
// }
|
||||
// if (entity.getHyperVer().equals("HPs_0001")) {
|
||||
// throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
// }
|
||||
|
||||
//디폴트면 삭제불가
|
||||
// 디폴트면 삭제불가
|
||||
if (entity.getIsDefault()) {
|
||||
throw new CustomApiException("UNPROCESSABLE_ENTITY", HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
@@ -105,9 +104,10 @@ public class HyperParamCoreService {
|
||||
*/
|
||||
public HyperParamDto.Basic getInitHyperParam(ModelType model) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.getHyperparamByType(model)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
hyperParamRepository.getHyperparamByType(model).stream()
|
||||
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
@@ -118,9 +118,9 @@ public class HyperParamCoreService {
|
||||
*/
|
||||
public HyperParamDto.Basic getHyperParam(UUID uuid) {
|
||||
ModelHyperParamEntity entity =
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
hyperParamRepository
|
||||
.findHyperParamByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
@@ -143,16 +143,16 @@ public class HyperParamCoreService {
|
||||
*/
|
||||
public String getFirstHyperParamVersion(ModelType model) {
|
||||
return hyperParamRepository
|
||||
.findHyperParamVerByModelType(model)
|
||||
.map(ModelHyperParamEntity::getHyperVer)
|
||||
.map(ver -> increase(ver, model))
|
||||
.orElse(model.name() + "_000001");
|
||||
.findHyperParamVerByModelType(model)
|
||||
.map(ModelHyperParamEntity::getHyperVer)
|
||||
.map(ver -> increase(ver, model))
|
||||
.orElse(model.name() + "_000001");
|
||||
}
|
||||
|
||||
/**
|
||||
* 하이퍼 파라미터의 버전을 증가시킨다.
|
||||
*
|
||||
* @param hyperVer 현재 버전
|
||||
* @param hyperVer 현재 버전
|
||||
* @param modelType 모델 타입
|
||||
* @return 증가된 버전
|
||||
*/
|
||||
@@ -214,5 +214,4 @@ public class HyperParamCoreService {
|
||||
// memo
|
||||
entity.setMemo(src.getMemo());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.kamco.cd.training.postgres.core;
|
||||
|
||||
import com.kamco.cd.training.postgres.repository.train.ModelTestMetricsJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -26,4 +28,12 @@ public class ModelTestMetricsJobCoreService {
|
||||
public void insertModelMetricsTest(List<Object[]> batchArgs) {
|
||||
modelTestMetricsJobRepository.insertModelMetricsTest(batchArgs);
|
||||
}
|
||||
|
||||
public ModelMetricJsonDto getTestMetricPackingInfo(Long modelId) {
|
||||
return modelTestMetricsJobRepository.getTestMetricPackingInfo(modelId);
|
||||
}
|
||||
|
||||
public ModelTestFileName findModelTestFileNames(Long modelId) {
|
||||
return modelTestMetricsJobRepository.findModelTestFileNames(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -97,4 +98,8 @@ public class ModelTrainDetailCoreService {
|
||||
public ModelBestEpoch getModelTrainBestEpoch(UUID uuid) {
|
||||
return modelDetailRepository.getModelTrainBestEpoch(uuid);
|
||||
}
|
||||
|
||||
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
|
||||
return modelDetailRepository.getModelTrainFileInfo(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.kamco.cd.training.postgres.repository.train.ModelTrainJobRepository;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -104,4 +105,16 @@ public class ModelTrainJobCoreService {
|
||||
job.setStatusCd("STOPPED");
|
||||
job.setFinishedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void updateEpoch(String containerName, Integer epoch) {
|
||||
ModelTrainJobEntity job =
|
||||
modelTrainJobRepository
|
||||
.findByContainerName(containerName)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Job not found: " + containerName));
|
||||
|
||||
job.setCurrentEpoch(epoch);
|
||||
|
||||
if (Objects.equals(job.getTotalEpoch(), epoch)) {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,4 +29,9 @@ public class ModelTrainMetricsJobCoreService {
|
||||
public void insertModelMetricsValidation(List<Object[]> batchArgs) {
|
||||
modelTrainMetricsJobRepository.insertModelMetricsValidation(batchArgs);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
|
||||
modelTrainMetricsJobRepository.updateModelSelectedBestEpoch(modelId, epoch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,9 +65,9 @@ public class ModelTrainMngCoreService {
|
||||
*/
|
||||
public void deleteModel(UUID uuid) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
entity.setDelYn(true);
|
||||
entity.setUpdatedDttm(ZonedDateTime.now());
|
||||
entity.setUpdatedUid(userUtil.getId());
|
||||
@@ -83,15 +83,19 @@ public class ModelTrainMngCoreService {
|
||||
ModelMasterEntity entity = new ModelMasterEntity();
|
||||
ModelHyperParamEntity hyperParamEntity = new ModelHyperParamEntity();
|
||||
|
||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||
// 최적화 파라미터는 모델 type의 디폴트사용
|
||||
if (HyperParamSelectType.OPTIMIZED.getId().equals(addReq.getHyperParamType())) {
|
||||
ModelType modelType = ModelType.getValueData(addReq.getModelNo());
|
||||
hyperParamEntity = hyperParamRepository.getHyperparamByType(modelType).orElse(null);
|
||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||
hyperParamEntity =
|
||||
hyperParamRepository.getHyperparamByType(modelType).stream()
|
||||
.filter(e -> e.getIsDefault() == Boolean.TRUE)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
// hyperParamEntity = hyperParamRepository.findByHyperVer("HPs_0001").orElse(null);
|
||||
|
||||
} else {
|
||||
hyperParamEntity =
|
||||
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
||||
hyperParamRepository.findHyperParamByUuid(addReq.getHyperUuid()).orElse(null);
|
||||
}
|
||||
|
||||
if (hyperParamEntity == null || hyperParamEntity.getHyperVer() == null) {
|
||||
@@ -99,24 +103,16 @@ public class ModelTrainMngCoreService {
|
||||
}
|
||||
|
||||
String modelVer =
|
||||
String.join(
|
||||
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
|
||||
String.join(
|
||||
".", addReq.getModelNo(), hyperParamEntity.getHyperVer(), entity.getUuid().toString());
|
||||
entity.setModelVer(modelVer);
|
||||
entity.setHyperParamId(hyperParamEntity.getId());
|
||||
entity.setModelNo(addReq.getModelNo());
|
||||
entity.setTrainType(addReq.getTrainType()); // 일반, 전이
|
||||
entity.setBeforeModelId(addReq.getBeforeModelId());
|
||||
|
||||
if (addReq.getIsStart()) {
|
||||
entity.setModelStep((short) 1);
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStrtDttm(ZonedDateTime.now());
|
||||
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||
entity.setStep1State(TrainStatusType.IN_PROGRESS.getId());
|
||||
} else {
|
||||
entity.setStatusCd(TrainStatusType.READY.getId());
|
||||
entity.setStep1State(TrainStatusType.READY.getId());
|
||||
}
|
||||
entity.setStatusCd(TrainStatusType.READY.getId());
|
||||
entity.setStep1State(TrainStatusType.READY.getId());
|
||||
|
||||
entity.setCreatedUid(userUtil.getId());
|
||||
ModelMasterEntity resultEntity = modelMngRepository.save(entity);
|
||||
@@ -132,7 +128,7 @@ public class ModelTrainMngCoreService {
|
||||
* data set 저장
|
||||
*
|
||||
* @param modelId 저장한 모델 학습 id
|
||||
* @param addReq 요청 파라미터
|
||||
* @param addReq 요청 파라미터
|
||||
*/
|
||||
public void saveModelDataset(Long modelId, ModelTrainMngDto.AddReq addReq) {
|
||||
TrainingDataset dataset = addReq.getTrainingDataset();
|
||||
@@ -163,14 +159,17 @@ public class ModelTrainMngCoreService {
|
||||
* @param modelId
|
||||
* @param req
|
||||
*/
|
||||
@Transactional
|
||||
public void updateModelMaster(Long modelId, ModelTrainMngDto.UpdateReq req) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
// 임시폴더 UID업데이트
|
||||
if (req.getRequestPath() != null && !req.getRequestPath().isEmpty()) {
|
||||
entity.setRequestPath(req.getRequestPath());
|
||||
}
|
||||
// TODO 삭제예정
|
||||
|
||||
if (req.getResponsePath() != null && !req.getResponsePath().isEmpty()) {
|
||||
entity.setRequestPath(req.getResponsePath());
|
||||
@@ -180,7 +179,7 @@ public class ModelTrainMngCoreService {
|
||||
/**
|
||||
* 모델 데이터셋 mapping 테이블 저장
|
||||
*
|
||||
* @param modelId 모델학습 id
|
||||
* @param modelId 모델학습 id
|
||||
* @param datasetList 선택한 data set
|
||||
*/
|
||||
public void saveModelDatasetMap(Long modelId, List<Long> datasetList) {
|
||||
@@ -197,7 +196,7 @@ public class ModelTrainMngCoreService {
|
||||
* 모델학습 config 저장
|
||||
*
|
||||
* @param modelId 모델학습 id
|
||||
* @param req 요청 파라미터
|
||||
* @param req 요청 파라미터
|
||||
* @return
|
||||
*/
|
||||
public void saveModelConfig(Long modelId, ModelTrainMngDto.ModelConfig req) {
|
||||
@@ -217,7 +216,7 @@ public class ModelTrainMngCoreService {
|
||||
/**
|
||||
* 데이터셋 매핑 생성
|
||||
*
|
||||
* @param modelUid 모델 UID
|
||||
* @param modelUid 모델 UID
|
||||
* @param datasetIds 데이터셋 ID 목록
|
||||
*/
|
||||
public void createDatasetMappings(Long modelUid, List<Long> datasetIds) {
|
||||
@@ -239,8 +238,8 @@ public class ModelTrainMngCoreService {
|
||||
public ModelMasterEntity findByUuid(UUID uuid) {
|
||||
try {
|
||||
return modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new BadRequestException("잘못된 UUID 형식입니다: " + uuid);
|
||||
}
|
||||
@@ -254,9 +253,9 @@ public class ModelTrainMngCoreService {
|
||||
*/
|
||||
public Long findModelIdByUuid(UUID uuid) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
modelMngRepository
|
||||
.findByUuid(uuid)
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
return entity.getId();
|
||||
}
|
||||
|
||||
@@ -269,8 +268,8 @@ public class ModelTrainMngCoreService {
|
||||
public ModelConfigDto.Basic findModelConfigByModelId(UUID uuid) {
|
||||
ModelMasterEntity modelEntity = findByUuid(uuid);
|
||||
return modelConfigRepository
|
||||
.findModelConfigByModelId(modelEntity.getId())
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
.findModelConfigByModelId(modelEntity.getId())
|
||||
.orElseThrow(() -> new CustomApiException("NOT_FOUND_DATA", HttpStatus.NOT_FOUND));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -301,21 +300,19 @@ public class ModelTrainMngCoreService {
|
||||
*/
|
||||
public ModelTrainMngDto.Basic findModelById(Long id) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(id)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
||||
modelMngRepository
|
||||
.findById(id)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + id));
|
||||
return entity.toDto();
|
||||
}
|
||||
|
||||
/**
|
||||
* 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작
|
||||
*/
|
||||
/** 마스터를 IN_PROGRESS로 전환하고, 현재 실행 jobId를 연결 - UI/중단/상태조회 모두 currentAttemptId를 기준으로 동작 */
|
||||
@Transactional
|
||||
public void markInProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
master.setCurrentAttemptId(jobId);
|
||||
@@ -323,54 +320,46 @@ public class ModelTrainMngCoreService {
|
||||
// 필요하면 시작시간도 여기서 찍어줌
|
||||
}
|
||||
|
||||
/**
|
||||
* 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거
|
||||
*/
|
||||
/** 마지막 에러 메시지 초기화 - 재시작/새 실행 때 이전 에러 흔적 제거 */
|
||||
@Transactional
|
||||
public void clearLastError(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setLastError(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현
|
||||
*/
|
||||
/** 중단 처리(옵션) - cancel에서 쓰려고 하면 같이 구현 */
|
||||
@Transactional
|
||||
public void markStopped(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.STOPPED.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* 완료 처리(옵션) - Worker가 성공 시 호출
|
||||
*/
|
||||
/** 완료 처리(옵션) - Worker가 성공 시 호출 */
|
||||
@Transactional
|
||||
public void markCompleted(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
}
|
||||
|
||||
/**
|
||||
* step 1오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*/
|
||||
/** step 1오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markError(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setStep1State(TrainStatusType.ERROR.getId());
|
||||
@@ -379,15 +368,13 @@ public class ModelTrainMngCoreService {
|
||||
master.setUpdatedDttm(ZonedDateTime.now());
|
||||
}
|
||||
|
||||
/**
|
||||
* step 2오류 처리(옵션) - Worker가 실패 시 호출
|
||||
*/
|
||||
/** step 2오류 처리(옵션) - Worker가 실패 시 호출 */
|
||||
@Transactional
|
||||
public void markStep2Error(Long modelId, String errorMessage) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
master.setStatusCd(TrainStatusType.ERROR.getId());
|
||||
master.setStep2State(TrainStatusType.ERROR.getId());
|
||||
@@ -399,9 +386,9 @@ public class ModelTrainMngCoreService {
|
||||
@Transactional
|
||||
public void markSuccess(Long modelId) {
|
||||
ModelMasterEntity master =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
// 모델 상태 완료 처리
|
||||
master.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
@@ -429,9 +416,9 @@ public class ModelTrainMngCoreService {
|
||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||
public void markStep1InProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStep1StrtDttm(ZonedDateTime.now());
|
||||
@@ -449,9 +436,9 @@ public class ModelTrainMngCoreService {
|
||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||
public void markStep2InProgress(Long modelId, Long jobId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.IN_PROGRESS.getId());
|
||||
entity.setStep2StrtDttm(ZonedDateTime.now());
|
||||
@@ -469,9 +456,9 @@ public class ModelTrainMngCoreService {
|
||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||
public void markStep1Success(Long modelId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep1State(TrainStatusType.COMPLETED.getId());
|
||||
@@ -488,9 +475,9 @@ public class ModelTrainMngCoreService {
|
||||
@Transactional(propagation = Propagation.REQUIRES_NEW)
|
||||
public void markStep2Success(Long modelId) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setStatusCd(TrainStatusType.COMPLETED.getId());
|
||||
entity.setStep2State(TrainStatusType.COMPLETED.getId());
|
||||
@@ -501,9 +488,9 @@ public class ModelTrainMngCoreService {
|
||||
|
||||
public void updateModelMasterBestEpoch(Long modelId, int epoch) {
|
||||
ModelMasterEntity entity =
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
modelMngRepository
|
||||
.findById(modelId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Model not found: " + modelId));
|
||||
|
||||
entity.setBestEpoch(epoch);
|
||||
}
|
||||
|
||||
@@ -316,7 +316,6 @@ public class ModelHyperParamEntity {
|
||||
@Enumerated(EnumType.STRING)
|
||||
private ModelType modelType;
|
||||
|
||||
|
||||
@Column(name = "default_param")
|
||||
private Boolean isDefault = false;
|
||||
|
||||
@@ -395,8 +394,7 @@ public class ModelHyperParamEntity {
|
||||
// -------------------------
|
||||
this.gpuCnt,
|
||||
this.gpuIds,
|
||||
this.masterPort
|
||||
, this.isDefault
|
||||
);
|
||||
this.masterPort,
|
||||
this.isDefault);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,6 +127,7 @@ public class ModelMasterEntity {
|
||||
this.statusCd,
|
||||
this.trainType,
|
||||
this.modelNo,
|
||||
this.currentAttemptId);
|
||||
this.currentAttemptId,
|
||||
this.requestPath);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,4 +24,6 @@ public interface DatasetRepositoryCustom {
|
||||
Long insertDatasetMngData(DatasetMngRegDto mngRegDto);
|
||||
|
||||
List<String> findDatasetUid(List<Long> datasetIds);
|
||||
|
||||
Long findDatasetByUidExistsCnt(String uid);
|
||||
}
|
||||
|
||||
@@ -247,4 +247,13 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom {
|
||||
public List<String> findDatasetUid(List<Long> datasetIds) {
|
||||
return queryFactory.select(dataset.uid).from(dataset).where(dataset.id.in(datasetIds)).fetch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long findDatasetByUidExistsCnt(String uid) {
|
||||
return queryFactory
|
||||
.select(dataset.id.count())
|
||||
.from(dataset)
|
||||
.where(dataset.uid.eq(uid))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.kamco.cd.training.common.enums.ModelType;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto;
|
||||
import com.kamco.cd.training.hyperparam.dto.HyperParamDto.SearchReq;
|
||||
import com.kamco.cd.training.postgres.entity.ModelHyperParamEntity;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.springframework.data.domain.Page;
|
||||
@@ -32,5 +33,5 @@ public interface HyperParamRepositoryCustom {
|
||||
|
||||
Page<HyperParamDto.List> findByHyperVerList(ModelType model, SearchReq req);
|
||||
|
||||
Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
||||
List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType);
|
||||
}
|
||||
|
||||
@@ -91,7 +91,9 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
Pageable pageable = req.toPageable();
|
||||
|
||||
BooleanBuilder builder = new BooleanBuilder();
|
||||
builder.and(modelHyperParamEntity.modelType.eq(model));
|
||||
if (model != null) {
|
||||
builder.and(modelHyperParamEntity.modelType.eq(model));
|
||||
}
|
||||
builder.and(modelHyperParamEntity.delYn.isFalse());
|
||||
|
||||
if (req.getHyperVer() != null && !req.getHyperVer().isEmpty()) {
|
||||
@@ -129,9 +131,11 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
Projections.constructor(
|
||||
HyperParamDto.List.class,
|
||||
modelHyperParamEntity.uuid,
|
||||
modelHyperParamEntity.modelType.as("model"),
|
||||
modelHyperParamEntity.hyperVer,
|
||||
modelHyperParamEntity.createdDttm,
|
||||
modelHyperParamEntity.lastUsedDttm,
|
||||
modelHyperParamEntity.memo,
|
||||
modelHyperParamEntity.m1UseCnt,
|
||||
modelHyperParamEntity.m2UseCnt,
|
||||
modelHyperParamEntity.m3UseCnt,
|
||||
@@ -183,12 +187,15 @@ public class HyperParamRepositoryImpl implements HyperParamRepositoryCustom {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
||||
return Optional.ofNullable(
|
||||
queryFactory
|
||||
public List<ModelHyperParamEntity> getHyperparamByType(ModelType modelType) {
|
||||
return queryFactory
|
||||
.select(modelHyperParamEntity)
|
||||
.from(modelHyperParamEntity)
|
||||
.where(modelHyperParamEntity.delYn.isFalse().and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||
.fetchOne());
|
||||
.where(
|
||||
modelHyperParamEntity
|
||||
.delYn
|
||||
.isFalse()
|
||||
.and(modelHyperParamEntity.modelType.eq(modelType)))
|
||||
.fetch();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,10 +16,10 @@ public class ModelDatasetMappRepositoryImpl implements ModelDatasetMappRepositor
|
||||
|
||||
@Override
|
||||
public List<ModelDatasetMappEntity> findByModelUid(Long modelId) {
|
||||
queryFactory
|
||||
return queryFactory
|
||||
.select(modelDatasetMappEntity)
|
||||
.from(modelDatasetMappEntity)
|
||||
.where(modelDatasetMappEntity.modelUid.eq(modelId));
|
||||
return List.of();
|
||||
.where(modelDatasetMappEntity.modelUid.eq(modelId))
|
||||
.fetch();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -34,4 +35,6 @@ public interface ModelDetailRepositoryCustom {
|
||||
List<ModelTestMetrics> getModelTestMetricResult(UUID uuid);
|
||||
|
||||
ModelBestEpoch getModelTrainBestEpoch(UUID uuid);
|
||||
|
||||
ModelFileInfo getModelTrainFileInfo(UUID uuid);
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import com.kamco.cd.training.model.dto.ModelTrainDetailDto.DetailSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.HyperSummary;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.MappingDataset;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelBestEpoch;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelFileInfo;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTestMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelTrainMetrics;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainDetailDto.ModelValidationMetrics;
|
||||
@@ -269,4 +270,17 @@ public class ModelDetailRepositoryImpl implements ModelDetailRepositoryCustom {
|
||||
modelMetricsTrainEntity.epoch.eq(modelMasterEntity.getBestEpoch()))
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelFileInfo getModelTrainFileInfo(UUID uuid) {
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelFileInfo.class,
|
||||
modelMasterEntity.step2MetricSaveYn,
|
||||
modelMasterEntity.modelVer))
|
||||
.from(modelMasterEntity)
|
||||
.where(modelMasterEntity.uuid.eq(uuid))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,11 @@ public class ModelMngRepositoryImpl implements ModelMngRepositoryCustom {
|
||||
BooleanBuilder builder = new BooleanBuilder();
|
||||
|
||||
if (req.getStatus() != null && !req.getStatus().isEmpty()) {
|
||||
builder.and(modelMasterEntity.statusCd.eq(req.getStatus()));
|
||||
builder.and(
|
||||
modelMasterEntity
|
||||
.step1State
|
||||
.eq(req.getStatus())
|
||||
.or(modelMasterEntity.step2State.eq(req.getStatus())));
|
||||
}
|
||||
|
||||
if (req.getModelNo() != null && !req.getModelNo().isEmpty()) {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.util.List;
|
||||
|
||||
@@ -10,4 +12,8 @@ public interface ModelTestMetricsJobRepositoryCustom {
|
||||
List<ResponsePathDto> getTestMetricSaveNotYetModelIds();
|
||||
|
||||
void insertModelMetricsTest(List<Object[]> batchArgs);
|
||||
|
||||
ModelMetricJsonDto getTestMetricPackingInfo(Long modelId);
|
||||
|
||||
ModelTestFileName findModelTestFileNames(Long modelId);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package com.kamco.cd.training.postgres.repository.train;
|
||||
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMetricsTestEntity.modelMetricsTestEntity;
|
||||
import static com.kamco.cd.training.postgres.entity.QModelMetricsTrainEntity.modelMetricsTrainEntity;
|
||||
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.postgres.entity.ModelMetricsTestEntity;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.Properties;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import com.querydsl.core.types.Projections;
|
||||
import com.querydsl.jpa.impl.JPAQueryFactory;
|
||||
@@ -42,7 +47,10 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||
ResponsePathDto.class,
|
||||
modelMasterEntity.id,
|
||||
modelMasterEntity.responsePath,
|
||||
modelMasterEntity.uuid))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity.step2EndDttm.isNotNull(),
|
||||
@@ -67,4 +75,43 @@ public class ModelTestMetricsJobRepositoryImpl extends QuerydslRepositorySupport
|
||||
|
||||
jdbcTemplate.batchUpdate(sql, batchArgs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelMetricJsonDto getTestMetricPackingInfo(Long modelId) {
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelMetricJsonDto.class,
|
||||
modelMasterEntity.modelNo,
|
||||
modelMasterEntity.modelVer,
|
||||
Projections.constructor(
|
||||
Properties.class,
|
||||
modelMetricsTestEntity.f1Score,
|
||||
modelMetricsTestEntity.precisions,
|
||||
modelMetricsTestEntity.recall,
|
||||
modelMetricsTestEntity.iou,
|
||||
modelMetricsTrainEntity.loss)))
|
||||
.from(modelMetricsTestEntity)
|
||||
.innerJoin(modelMasterEntity)
|
||||
.on(modelMetricsTestEntity.model.id.eq(modelMasterEntity.id))
|
||||
.innerJoin(modelMetricsTrainEntity)
|
||||
.on(
|
||||
modelMetricsTestEntity.model.eq(modelMetricsTrainEntity.model),
|
||||
modelMasterEntity.bestEpoch.eq(modelMetricsTrainEntity.epoch))
|
||||
.where(modelMetricsTestEntity.model.id.eq(modelId))
|
||||
.fetchOne();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelTestFileName findModelTestFileNames(Long modelId) {
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ModelTestFileName.class, modelMetricsTestEntity.model1, modelMasterEntity.modelVer))
|
||||
.from(modelMetricsTestEntity)
|
||||
.innerJoin(modelMasterEntity)
|
||||
.on(modelMetricsTestEntity.model.id.eq(modelMasterEntity.id))
|
||||
.where(modelMetricsTestEntity.model.id.eq(modelId))
|
||||
.fetchOne();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,4 +7,6 @@ public interface ModelTrainJobRepositoryCustom {
|
||||
int findMaxAttemptNo(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findLatestByModelId(Long modelId);
|
||||
|
||||
Optional<ModelTrainJobEntity> findByContainerName(String containerName);
|
||||
}
|
||||
|
||||
@@ -40,4 +40,18 @@ public class ModelTrainJobRepositoryImpl implements ModelTrainJobRepositoryCusto
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ModelTrainJobEntity> findByContainerName(String containerName) {
|
||||
QModelTrainJobEntity j = QModelTrainJobEntity.modelTrainJobEntity;
|
||||
|
||||
ModelTrainJobEntity job =
|
||||
queryFactory
|
||||
.selectFrom(j)
|
||||
.where(j.containerName.eq(containerName))
|
||||
.orderBy(j.id.desc())
|
||||
.fetchFirst();
|
||||
|
||||
return Optional.ofNullable(job);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,4 +12,6 @@ public interface ModelTrainMetricsJobRepositoryCustom {
|
||||
void updateModelMetricsTrainSaveYn(Long modelId, String stepNo);
|
||||
|
||||
void insertModelMetricsValidation(List<Object[]> batchArgs);
|
||||
|
||||
void updateModelSelectedBestEpoch(Long modelId, Integer epoch);
|
||||
}
|
||||
|
||||
@@ -29,7 +29,10 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
||||
return queryFactory
|
||||
.select(
|
||||
Projections.constructor(
|
||||
ResponsePathDto.class, modelMasterEntity.id, modelMasterEntity.responsePath))
|
||||
ResponsePathDto.class,
|
||||
modelMasterEntity.id,
|
||||
modelMasterEntity.responsePath,
|
||||
modelMasterEntity.uuid))
|
||||
.from(modelMasterEntity)
|
||||
.where(
|
||||
modelMasterEntity.step1EndDttm.isNotNull(),
|
||||
@@ -79,4 +82,13 @@ public class ModelTrainMetricsJobRepositoryImpl extends QuerydslRepositorySuppor
|
||||
|
||||
jdbcTemplate.batchUpdate(sql, batchArgs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateModelSelectedBestEpoch(Long modelId, Integer epoch) {
|
||||
queryFactory
|
||||
.update(modelMasterEntity)
|
||||
.set(modelMasterEntity.bestEpoch, epoch)
|
||||
.where(modelMasterEntity.id.eq(modelId))
|
||||
.execute();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.kamco.cd.training.train;
|
||||
|
||||
import com.kamco.cd.training.config.api.ApiResponseDto;
|
||||
import com.kamco.cd.training.train.service.DataSetCountersService;
|
||||
import com.kamco.cd.training.train.service.TestJobService;
|
||||
import com.kamco.cd.training.train.service.TrainJobService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
@@ -12,6 +13,8 @@ import io.swagger.v3.oas.annotations.responses.ApiResponses;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
@@ -25,6 +28,7 @@ public class TrainApiController {
|
||||
|
||||
private final TrainJobService trainJobService;
|
||||
private final TestJobService testJobService;
|
||||
private final DataSetCountersService dataSetCountersService;
|
||||
|
||||
@Operation(summary = "학습 실행", description = "학습 실행 API")
|
||||
@ApiResponses(
|
||||
@@ -45,7 +49,9 @@ public class TrainApiController {
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
trainJobService.createTmpFile(uuid);
|
||||
trainJobService.enqueue(modelId);
|
||||
|
||||
return ApiResponseDto.ok("ok");
|
||||
}
|
||||
|
||||
@@ -186,4 +192,26 @@ public class TrainApiController {
|
||||
|
||||
return ApiResponseDto.ok(trainJobService.createTmpFile(uuid));
|
||||
}
|
||||
|
||||
@Operation(summary = "getCount", description = "getCount 서버 로그확인")
|
||||
@ApiResponses(
|
||||
value = {
|
||||
@ApiResponse(
|
||||
responseCode = "200",
|
||||
description = "데이터셋 tmp 파일생성 성공",
|
||||
content =
|
||||
@Content(
|
||||
mediaType = "application/json",
|
||||
schema = @Schema(implementation = String.class))),
|
||||
@ApiResponse(responseCode = "400", description = "잘못된 검색 조건", content = @Content),
|
||||
@ApiResponse(responseCode = "500", description = "서버 오류", content = @Content)
|
||||
})
|
||||
@GetMapping(path = "/counts/{uuid}", produces = MediaType.APPLICATION_JSON_VALUE)
|
||||
public ApiResponseDto<String> getCount(
|
||||
@Parameter(description = "uuid", example = "e22181eb-2ac4-4100-9941-d06efce25c49")
|
||||
@PathVariable
|
||||
UUID uuid) {
|
||||
Long modelId = trainJobService.getModelIdByUuid(uuid);
|
||||
return ApiResponseDto.ok(dataSetCountersService.getCount(modelId));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,4 +13,10 @@ public class EvalRunRequest {
|
||||
private String uuid;
|
||||
private int epoch; // best_changed_fscore_epoch_1.pth
|
||||
private Integer timeoutSeconds;
|
||||
private String datasetFolder;
|
||||
private String outputFolder;
|
||||
|
||||
public String getOutputFolder() {
|
||||
return this.outputFolder.toString();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package com.kamco.cd.training.train.dto;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import java.util.Properties;
|
||||
import java.util.UUID;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -17,5 +20,40 @@ public class ModelTrainMetricsDto {
|
||||
|
||||
private Long modelId;
|
||||
private String responsePath;
|
||||
private UUID uuid;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public static class ModelMetricJsonDto {
|
||||
|
||||
@JsonProperty("cd_model_type")
|
||||
private String cdModelType;
|
||||
|
||||
@JsonProperty("model_version")
|
||||
private String modelVersion;
|
||||
|
||||
private Properties properties;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public static class Properties {
|
||||
|
||||
@JsonProperty("f1_score")
|
||||
private Float f1Score;
|
||||
|
||||
private Float precision;
|
||||
private Float recall;
|
||||
private Float loss;
|
||||
private Double iou;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public static class ModelTestFileName {
|
||||
|
||||
private String bestEpochFileName;
|
||||
private String modelVersion;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Log4j2
|
||||
@RequiredArgsConstructor
|
||||
public class DataSetCountersService {
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
@Value("${train.docker.basePath}")
|
||||
private String trainBaseDir;
|
||||
|
||||
public String getCount(Long modelId) {
|
||||
ModelTrainMngDto.Basic basic = modelTrainMngCoreService.findModelById(modelId);
|
||||
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
|
||||
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
|
||||
|
||||
StringBuilder allLogs = new StringBuilder();
|
||||
|
||||
try {
|
||||
// request 폴더
|
||||
for (String uid : uids) {
|
||||
Path path = Path.of(requestDir, uid);
|
||||
DatasetCounters counters = countTmpAfterBuild(path);
|
||||
allLogs.append(counters.prints(uid, "REQUEST")).append(System.lineSeparator());
|
||||
}
|
||||
|
||||
// tmp
|
||||
Path tmpPath = Path.of(trainBaseDir, "tmp", basic.getRequestPath());
|
||||
DatasetCounters counters2 = countTmpAfterBuild(tmpPath);
|
||||
allLogs
|
||||
.append(counters2.prints(basic.getRequestPath(), "TMP"))
|
||||
.append(System.lineSeparator());
|
||||
} catch (IOException e) {
|
||||
log.error(e.getMessage());
|
||||
}
|
||||
|
||||
return allLogs.toString();
|
||||
}
|
||||
|
||||
private int countTif(Path dir) throws IOException {
|
||||
if (!Files.isDirectory(dir)) return 0;
|
||||
|
||||
try (var stream = Files.walk(dir)) {
|
||||
return (int)
|
||||
stream.filter(Files::isRegularFile).filter(p -> p.toString().endsWith(".tif")).count();
|
||||
}
|
||||
|
||||
/*
|
||||
대소문자 및 geojson 필요시
|
||||
* try (var stream = Files.walk(dir)) {
|
||||
return (int)
|
||||
stream
|
||||
.filter(Files::isRegularFile)
|
||||
.filter(p -> {
|
||||
String name = p.getFileName().toString().toLowerCase();
|
||||
return name.endsWith(".tif") || name.endsWith(".geojson");
|
||||
})
|
||||
.count();
|
||||
}
|
||||
* */
|
||||
}
|
||||
|
||||
public DatasetCounters countTmpAfterBuild(Path path) throws IOException {
|
||||
|
||||
// input1
|
||||
int in1Train = countTif(path.resolve("train/input1"));
|
||||
int in1Val = countTif(path.resolve("val/input1"));
|
||||
int in1Test = countTif(path.resolve("test/input1"));
|
||||
|
||||
// input2
|
||||
int in2Train = countTif(path.resolve("train/input2"));
|
||||
int in2Val = countTif(path.resolve("val/input2"));
|
||||
int in2Test = countTif(path.resolve("test/input2"));
|
||||
|
||||
List<DatasetCounter> input1List = new ArrayList<>();
|
||||
List<DatasetCounter> input2List = new ArrayList<>();
|
||||
|
||||
input1List.add(new DatasetCounter(in1Train, in1Test, in1Val));
|
||||
input2List.add(new DatasetCounter(in2Train, in2Test, in2Val));
|
||||
|
||||
return new DatasetCounters(input1List, input2List);
|
||||
}
|
||||
|
||||
@Getter
|
||||
public static class DatasetCounter {
|
||||
private int inputTrain = 0;
|
||||
private int inputTest = 0;
|
||||
private int inputVal = 0;
|
||||
|
||||
public DatasetCounter(int inputTrain, int inputTest, int inputVal) {
|
||||
this.inputTrain = inputTrain;
|
||||
this.inputTest = inputTest;
|
||||
this.inputVal = inputVal;
|
||||
}
|
||||
}
|
||||
|
||||
@Getter
|
||||
public static class DatasetCounters {
|
||||
private List<DatasetCounter> input1 = new ArrayList<>();
|
||||
private List<DatasetCounter> input2 = new ArrayList<>();
|
||||
|
||||
public DatasetCounters(List<DatasetCounter> input1, List<DatasetCounter> input2) {
|
||||
this.input1 = input1;
|
||||
this.input2 = input2;
|
||||
}
|
||||
|
||||
public String prints(String uuid, String type) {
|
||||
int train = 0, test = 0, val = 0;
|
||||
int train2 = 0, test2 = 0, val2 = 0;
|
||||
|
||||
for (DatasetCounter datasetCounter : input1) {
|
||||
train += datasetCounter.inputTrain;
|
||||
test += datasetCounter.inputTest;
|
||||
val += datasetCounter.inputVal;
|
||||
}
|
||||
|
||||
for (DatasetCounter datasetCounter : input2) {
|
||||
train2 += datasetCounter.inputTrain;
|
||||
test2 += datasetCounter.inputTest;
|
||||
val2 += datasetCounter.inputVal;
|
||||
}
|
||||
|
||||
log.info("======== UUID FOLDER COUNT {} : {}", type, uuid);
|
||||
log.info("input 1 = train : {} | val : {} | test : {} ", train, val, test);
|
||||
log.info("input 2 = train : {} | val : {} | test : {} ", train2, val2, test2);
|
||||
log.info(
|
||||
"*total* = train : {} | val : {} | test : {} ",
|
||||
train + train2,
|
||||
val + val2,
|
||||
test + test2);
|
||||
|
||||
return String.format(
|
||||
"======== UUID FOLDER COUNT %s : %s%n"
|
||||
+ "input 1 = train : %s | val : %s | test : %s%n"
|
||||
+ "input 2 = train : %s | val : %s | test : %s%n"
|
||||
+ "*total* = train : %s | val : %s | test : %s",
|
||||
type,
|
||||
uuid,
|
||||
train,
|
||||
val,
|
||||
test,
|
||||
train2,
|
||||
val2,
|
||||
test2,
|
||||
train + train2,
|
||||
val + val2,
|
||||
test + test2);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.EvalRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import com.kamco.cd.training.train.dto.TrainRunResult;
|
||||
@@ -9,14 +10,17 @@ import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.log4j.Log4j2;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Log4j2
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class DockerTrainService {
|
||||
|
||||
// 실행할 Docker 이미지명
|
||||
@@ -43,6 +47,8 @@ public class DockerTrainService {
|
||||
@Value("${train.docker.ipcHost:true}")
|
||||
private boolean ipcHost;
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
|
||||
/** Docker 학습 컨테이너를 동기 실행 - 요청 스레드에서 docker run 실행 - 컨테이너 종료까지 대기 - stdout/stderr 로그 수집 후 결과 반환 */
|
||||
public TrainRunResult runTrainSync(TrainRunRequest req, String containerName) throws Exception {
|
||||
|
||||
@@ -56,12 +62,42 @@ public class DockerTrainService {
|
||||
ProcessBuilder pb = new ProcessBuilder(cmd);
|
||||
pb.redirectErrorStream(true);
|
||||
|
||||
Process p = pb.start();
|
||||
|
||||
// 로그는 별도 스레드에서 읽기 (메인 스레드가 readLine에 안 걸리게)
|
||||
StringBuilder logBuilder = new StringBuilder();
|
||||
Process p = pb.start();
|
||||
|
||||
Pattern epochPattern = Pattern.compile("(?i)\\bepoch\\s*\\[?(\\d+)\\s*/\\s*(\\d+)\\]?\\b");
|
||||
log.info("[TRAIN-BOOT] docker run started. container={}", containerName);
|
||||
|
||||
try {
|
||||
log.info("[TRAIN-BOOT] pid={}", p.pid()); // Java 9+
|
||||
} catch (Throwable ignore) {
|
||||
}
|
||||
|
||||
try {
|
||||
// 바로 죽었는지 100ms만 체크
|
||||
if (p.waitFor(100, TimeUnit.MILLISECONDS)) {
|
||||
int exit = p.exitValue();
|
||||
String earlyLogs;
|
||||
synchronized (logBuilder) {
|
||||
earlyLogs = logBuilder.toString();
|
||||
}
|
||||
log.error(
|
||||
"[TRAIN-BOOT] docker run exited immediately. container={} exit={}",
|
||||
containerName,
|
||||
exit);
|
||||
log.error("[TRAIN-BOOT] early logs:\n{}", earlyLogs);
|
||||
} else {
|
||||
log.info("[TRAIN-BOOT] docker run is still running. container={}", containerName);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("[TRAIN-BOOT] early-exit check failed: {}", e.toString(), e);
|
||||
}
|
||||
|
||||
Pattern epochPattern = Pattern.compile("Epoch\\(train\\)\\s+\\[(\\d+)\\]\\[(\\d+)/(\\d+)\\]");
|
||||
|
||||
// 너무 잦은 업데이트 방지용
|
||||
AtomicInteger lastEpoch = new AtomicInteger(0);
|
||||
AtomicInteger lastIter = new AtomicInteger(0);
|
||||
|
||||
Thread logThread =
|
||||
new Thread(
|
||||
@@ -73,23 +109,40 @@ public class DockerTrainService {
|
||||
String line;
|
||||
while ((line = br.readLine()) != null) {
|
||||
|
||||
// 1) 로그 누적
|
||||
synchronized (logBuilder) {
|
||||
logBuilder.append(line).append('\n');
|
||||
}
|
||||
|
||||
// 2) epoch 감지 + DB 업데이트
|
||||
Matcher m = epochPattern.matcher(line);
|
||||
if (m.find()) {
|
||||
int currentEpoch = Integer.parseInt(m.group(1));
|
||||
int totalEpoch = Integer.parseInt(m.group(2));
|
||||
int epoch = Integer.parseInt(m.group(1));
|
||||
int iter = Integer.parseInt(m.group(2));
|
||||
int totalIter = Integer.parseInt(m.group(3));
|
||||
|
||||
log.info("[EPOCH] container={} {}/{}", containerName, currentEpoch, totalEpoch);
|
||||
// (선택) maxEpochs는 req에서 알고 있으니 req.getEpochs() 같은 걸로 사용
|
||||
int maxEpochs = req.getEpochs() != null ? req.getEpochs() : 0;
|
||||
|
||||
// TODO 실행중인 에폭 저장 필요하면 만들어야함
|
||||
// TODO 하지만 여기서 트랜젝션 걸리는 db 작업하면 안좋다고하는데..?
|
||||
// modelTrainMngCoreService.updateCurrentEpoch(modelId,
|
||||
// currentEpoch, totalEpoch);
|
||||
// 쓰로틀링: 에폭 끝 or 10 iter마다
|
||||
boolean shouldUpdate = (iter == totalIter) || (iter % 10 == 0);
|
||||
|
||||
// 중복 방지
|
||||
if (shouldUpdate) {
|
||||
int prevEpoch = lastEpoch.get();
|
||||
int prevIter = lastIter.get();
|
||||
if (epoch != prevEpoch || iter != prevIter) {
|
||||
lastEpoch.set(epoch);
|
||||
lastIter.set(iter);
|
||||
|
||||
log.info(
|
||||
"[TRAIN] container={} epoch={} iter={}/{}",
|
||||
containerName,
|
||||
epoch,
|
||||
iter,
|
||||
totalIter);
|
||||
|
||||
modelTrainJobCoreService.updateEpoch(containerName, epoch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
@@ -97,21 +150,6 @@ public class DockerTrainService {
|
||||
}
|
||||
},
|
||||
"train-log-" + containerName);
|
||||
// new Thread(
|
||||
// () -> {
|
||||
// try (BufferedReader br =
|
||||
// new BufferedReader(
|
||||
// new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
|
||||
// String line;
|
||||
// while ((line = br.readLine()) != null) {
|
||||
// synchronized (log) {
|
||||
// log.append(line).append('\n');
|
||||
// }
|
||||
// }
|
||||
// } catch (Exception ignored) {
|
||||
// }
|
||||
// },
|
||||
// "train-log-" + containerName);
|
||||
|
||||
logThread.setDaemon(true);
|
||||
logThread.start();
|
||||
@@ -169,7 +207,7 @@ public class DockerTrainService {
|
||||
|
||||
// 컨테이너 이름 지정
|
||||
c.add("--name");
|
||||
c.add(containerName + "-" + req.getUuid().substring(0, 8));
|
||||
c.add(containerName);
|
||||
|
||||
// 실행 종료 시 자동 삭제
|
||||
c.add("--rm");
|
||||
@@ -206,7 +244,7 @@ public class DockerTrainService {
|
||||
|
||||
// 요청/결과 디렉토리 볼륨 마운트
|
||||
c.add("-v");
|
||||
c.add(requestDir + ":/data");
|
||||
c.add("/home/kcomu/data" + "/tmp:/data");
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
@@ -264,15 +302,16 @@ public class DockerTrainService {
|
||||
|
||||
// ===== Augmentation =====
|
||||
addArg(c, "--rot-prob", req.getRotProb());
|
||||
addArg(c, "--rot-degree", req.getRotDegree());
|
||||
// addArg(c, "--rot-degree", req.getRotDegree()); // TODO AI 수정되면 주석 해제
|
||||
addArg(c, "--flip-prob", req.getFlipProb());
|
||||
addArg(c, "--exchange-prob", req.getExchangeProb());
|
||||
addArg(c, "--brightness-delta", req.getBrightnessDelta());
|
||||
addArg(c, "--contrast-range", req.getContrastRange());
|
||||
addArg(c, "--saturation-range", req.getSaturationRange());
|
||||
// addArg(c, "--contrast-range", req.getContrastRange()); // TODO AI 수정되면 주석 해제
|
||||
// addArg(c, "--saturation-range", req.getSaturationRange()); // TODO AI 수정되면 주석 해제
|
||||
addArg(c, "--hue-delta", req.getHueDelta());
|
||||
|
||||
addArg(c, "--resume-from", req.getResumeFrom());
|
||||
addArg(c, "--save-interval", 1);
|
||||
return c;
|
||||
}
|
||||
|
||||
@@ -376,30 +415,28 @@ public class DockerTrainService {
|
||||
|
||||
c.add("docker");
|
||||
c.add("run");
|
||||
c.add("--name");
|
||||
c.add(containerName + "=" + req.getUuid().substring(0, 8));
|
||||
c.add("--rm");
|
||||
|
||||
c.add("--gpus");
|
||||
c.add("all");
|
||||
if (ipcHost) c.add("--ipc=host");
|
||||
c.add("--ipc=host");
|
||||
c.add("--shm-size=" + shmSize);
|
||||
|
||||
c.add("-v");
|
||||
c.add(requestDir + ":/data");
|
||||
c.add("/home/kcomu/data" + "/tmp:/data");
|
||||
|
||||
c.add("-v");
|
||||
c.add(responseDir + ":/checkpoints");
|
||||
|
||||
c.add(image);
|
||||
c.add("kamco-cd-train:latest");
|
||||
|
||||
c.add("python");
|
||||
c.add("/workspace/change-detection-code/run_evaluation_pipeline.py");
|
||||
|
||||
c.add("--dataset_dir");
|
||||
c.add("/data/" + uuid);
|
||||
addArg(c, "--dataset-folder", req.getDatasetFolder());
|
||||
addArg(c, "--output-folder", req.getOutputFolder());
|
||||
|
||||
c.add("--model");
|
||||
c.add("/checkpoints/" + uuid + "/" + modelFile);
|
||||
c.add("--epoch");
|
||||
c.add(modelFile);
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||
import com.kamco.cd.training.postgres.core.ModelTestMetricsJobCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelMetricJsonDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ModelTestFileName;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainMetricsDto.ResponsePathDto;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.csv.CSVFormat;
|
||||
@@ -27,6 +39,13 @@ public class ModelTestMetricsJobService {
|
||||
@Value("${spring.profiles.active}")
|
||||
private String profile;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
private String responseDir;
|
||||
|
||||
@Value("${file.pt-path}")
|
||||
private String ptPathDir;
|
||||
|
||||
/**
|
||||
* 실행중인 profile
|
||||
*
|
||||
@@ -36,8 +55,8 @@ public class ModelTestMetricsJobService {
|
||||
return "local".equalsIgnoreCase(profile);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
public void findTestValidMetricCsvFiles() {
|
||||
// @Scheduled(cron = "0 * * * * *")
|
||||
public void findTestValidMetricCsvFiles() throws IOException {
|
||||
// if (isLocalProfile()) {
|
||||
// return;
|
||||
// }
|
||||
@@ -51,7 +70,7 @@ public class ModelTestMetricsJobService {
|
||||
|
||||
for (ResponsePathDto modelInfo : modelIds) {
|
||||
|
||||
String testPath = modelInfo.getResponsePath() + "/metrics/test.csv";
|
||||
String testPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/test.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
@@ -96,7 +115,97 @@ public class ModelTestMetricsJobService {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// 패키징할 파일 만들기
|
||||
ModelMetricJsonDto jsonDto =
|
||||
modelTestMetricsJobCoreService.getTestMetricPackingInfo(modelInfo.getModelId());
|
||||
try {
|
||||
writeJsonFile(
|
||||
jsonDto,
|
||||
Paths.get(
|
||||
responseDir
|
||||
+ "/"
|
||||
+ modelInfo.getUuid()
|
||||
+ "/"
|
||||
+ jsonDto.getModelVersion()
|
||||
+ ".json"));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
|
||||
|
||||
ModelTestFileName fileInfo =
|
||||
modelTestMetricsJobCoreService.findModelTestFileNames(modelInfo.getModelId());
|
||||
|
||||
Path zipPath =
|
||||
Paths.get(
|
||||
responseDir + "/" + modelInfo.getUuid() + "/" + fileInfo.getModelVersion() + ".zip");
|
||||
Set<String> targetNames =
|
||||
Set.of(
|
||||
"model_config.py",
|
||||
fileInfo.getBestEpochFileName() + ".pth",
|
||||
fileInfo.getModelVersion() + ".json");
|
||||
|
||||
List<Path> files = new ArrayList<>();
|
||||
try (Stream<Path> s = Files.list(responsePath)) {
|
||||
files.addAll(
|
||||
s.filter(Files::isRegularFile)
|
||||
.filter(p -> targetNames.contains(p.getFileName().toString()))
|
||||
.collect(Collectors.toList()));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
try (Stream<Path> s = Files.list(Path.of(ptPathDir))) {
|
||||
files.addAll(
|
||||
s.filter(Files::isRegularFile)
|
||||
.limit(1) // yolov8_6th-6m.pt 파일 1개만
|
||||
.collect(Collectors.toList()));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
zipFiles(files, zipPath);
|
||||
modelTestMetricsJobCoreService.updateModelMetricsTrainSaveYn(modelInfo.getModelId(), "step2");
|
||||
}
|
||||
}
|
||||
|
||||
private void writeJsonFile(Object data, Path outputPath) throws IOException {
|
||||
|
||||
Path parent = outputPath.getParent();
|
||||
|
||||
if (parent != null) {
|
||||
Files.createDirectories(parent);
|
||||
}
|
||||
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
|
||||
|
||||
try (OutputStream os =
|
||||
Files.newOutputStream(
|
||||
outputPath, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING)) {
|
||||
objectMapper.writeValue(os, data);
|
||||
}
|
||||
}
|
||||
|
||||
private void zipFiles(List<Path> files, Path zipPath) throws IOException {
|
||||
|
||||
Path parent = zipPath.getParent();
|
||||
if (parent != null) {
|
||||
Files.createDirectories(parent);
|
||||
}
|
||||
|
||||
try (ZipOutputStream zos = new ZipOutputStream(Files.newOutputStream(zipPath))) {
|
||||
|
||||
for (Path file : files) {
|
||||
|
||||
ZipEntry entry = new ZipEntry(file.getFileName().toString());
|
||||
zos.putNextEntry(entry);
|
||||
|
||||
Files.copy(file, zos);
|
||||
|
||||
zos.closeEntry();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,14 @@ import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Stream;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.csv.CSVFormat;
|
||||
@@ -55,7 +60,7 @@ public class ModelTrainMetricsJobService {
|
||||
|
||||
for (ResponsePathDto modelInfo : modelIds) {
|
||||
|
||||
String trainPath = responseDir + "{uuid}/metrics/train.csv"; // TODO
|
||||
String trainPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/train.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(trainPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
@@ -80,7 +85,7 @@ public class ModelTrainMetricsJobService {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
String validationPath = modelInfo.getResponsePath() + "/metrics/val.csv";
|
||||
String validationPath = responseDir + "/" + modelInfo.getUuid() + "/metrics/val.csv";
|
||||
try (BufferedReader reader =
|
||||
Files.newBufferedReader(Paths.get(validationPath), StandardCharsets.UTF_8); ) {
|
||||
|
||||
@@ -129,6 +134,34 @@ public class ModelTrainMetricsJobService {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
Path responsePath = Paths.get(responseDir + "/" + modelInfo.getUuid());
|
||||
Integer epoch = null;
|
||||
boolean exists;
|
||||
Pattern pattern = Pattern.compile("best_changed_fscore_epoch_(\\d+)\\.pth");
|
||||
|
||||
try (Stream<Path> s = Files.list(responsePath)) {
|
||||
epoch =
|
||||
s.filter(Files::isRegularFile)
|
||||
.map(
|
||||
p -> {
|
||||
Matcher matcher = pattern.matcher(p.getFileName().toString());
|
||||
if (matcher.matches()) {
|
||||
return Integer.parseInt(matcher.group(1)); // ← 숫자 부분 추출
|
||||
}
|
||||
return null;
|
||||
})
|
||||
.filter(Objects::nonNull)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// best_changed_fscore_epoch_숫자.pth -> 숫자 값 가지고 와서 베스트 에폭에 업데이트 하기
|
||||
// modelTrainMetricsJobCoreService.updateModelSelectedBestEpoch(modelInfo.getModelId(),
|
||||
// epoch);
|
||||
|
||||
modelTrainMetricsJobCoreService.updateModelMetricsTrainSaveYn(
|
||||
modelInfo.getModelId(), "step1");
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
@@ -20,8 +20,8 @@ public class TestJobService {
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final DataSetCountersService dataSetCounters;
|
||||
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId, UUID uuid, int epoch) {
|
||||
@@ -29,13 +29,21 @@ public class TestJobService {
|
||||
// 마스터 확인
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
// 폴더 카운트
|
||||
dataSetCounters.getCount(modelId);
|
||||
|
||||
// best epoch 업데이트
|
||||
modelTrainMngCoreService.updateModelMasterBestEpoch(modelId, epoch);
|
||||
|
||||
// 파라미터 조회
|
||||
TrainRunRequest trainRunRequest = modelTrainMngCoreService.findTrainRunRequest(modelId);
|
||||
|
||||
Map<String, Object> params = new java.util.LinkedHashMap<>();
|
||||
params.put("jobType", "EVAL");
|
||||
params.put("uuid", String.valueOf(uuid));
|
||||
params.put("epoch", epoch);
|
||||
params.put("datasetFolder", trainRunRequest.getDatasetFolder());
|
||||
params.put("outputFolder", trainRunRequest.getOutputFolder());
|
||||
|
||||
int nextAttemptNo = modelTrainJobCoreService.findMaxAttemptNo(modelId) + 1;
|
||||
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
package com.kamco.cd.training.train.service;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.*;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class TmpDatasetService {
|
||||
|
||||
@Value("${train.docker.requestDir}")
|
||||
private String requestDir;
|
||||
|
||||
@Value("${train.docker.basePath}")
|
||||
private String trainBaseDir;
|
||||
|
||||
public String buildTmpDatasetSymlink(String uid, List<String> datasetUids) throws IOException {
|
||||
|
||||
log.info("========== buildTmpDatasetHardlink START ==========");
|
||||
log.info("uid={}", uid);
|
||||
log.info("datasetUids={}", datasetUids);
|
||||
log.info("requestDir(raw)={}", requestDir);
|
||||
|
||||
Path BASE = toPath(requestDir);
|
||||
Path tmp = Path.of(trainBaseDir, "tmp", uid);
|
||||
|
||||
log.info("BASE={}", BASE);
|
||||
log.info("BASE exists? {}", Files.isDirectory(BASE));
|
||||
log.info("tmp={}", tmp);
|
||||
|
||||
long noDir = 0, scannedDirs = 0, regularFiles = 0, hardlinksMade = 0;
|
||||
|
||||
// tmp 디렉토리 준비
|
||||
for (String type : List.of("train", "val", "test")) {
|
||||
for (String part : List.of("input1", "input2", "label", "label-json")) {
|
||||
Path dir = tmp.resolve(type).resolve(part);
|
||||
Files.createDirectories(dir);
|
||||
log.info("createDirectories: {}", dir);
|
||||
}
|
||||
}
|
||||
|
||||
// 하드링크는 "같은 파일시스템"에서만 가능하므로 BASE/tmp가 같은 FS인지 미리 확인(권장)
|
||||
try {
|
||||
var baseStore = Files.getFileStore(BASE);
|
||||
var tmpStore = Files.getFileStore(tmp.getParent()); // BASE/tmp
|
||||
if (!baseStore.name().equals(tmpStore.name()) || !baseStore.type().equals(tmpStore.type())) {
|
||||
throw new IOException(
|
||||
"Hardlink requires same filesystem. baseStore="
|
||||
+ baseStore.name()
|
||||
+ "("
|
||||
+ baseStore.type()
|
||||
+ "), tmpStore="
|
||||
+ tmpStore.name()
|
||||
+ "("
|
||||
+ tmpStore.type()
|
||||
+ ")");
|
||||
}
|
||||
} catch (Exception e) {
|
||||
// FileStore 비교가 환경마다 애매할 수 있어서, 여기서는 경고만 주고 실제 createLink에서 최종 판단하게 둘 수도 있음.
|
||||
log.warn("FileStore check skipped/failed (will rely on createLink): {}", e.toString());
|
||||
}
|
||||
|
||||
for (String id : datasetUids) {
|
||||
Path srcRoot = BASE.resolve(id);
|
||||
log.info("---- dataset id={} srcRoot={} exists? {}", id, srcRoot, Files.isDirectory(srcRoot));
|
||||
|
||||
for (String type : List.of("train", "val", "test")) {
|
||||
for (String part : List.of("input1", "input2", "label", "label-json")) {
|
||||
|
||||
Path srcDir = srcRoot.resolve(type).resolve(part);
|
||||
if (!Files.isDirectory(srcDir)) {
|
||||
log.warn("SKIP (not directory): {}", srcDir);
|
||||
noDir++;
|
||||
continue;
|
||||
}
|
||||
|
||||
scannedDirs++;
|
||||
log.info("SCAN dir={}", srcDir);
|
||||
|
||||
try (DirectoryStream<Path> stream = Files.newDirectoryStream(srcDir)) {
|
||||
for (Path f : stream) {
|
||||
if (!Files.isRegularFile(f)) {
|
||||
log.debug("skip non-regular file: {}", f);
|
||||
continue;
|
||||
}
|
||||
|
||||
regularFiles++;
|
||||
|
||||
String dstName = id + "__" + f.getFileName();
|
||||
Path dst = tmp.resolve(type).resolve(part).resolve(dstName);
|
||||
|
||||
// dst가 남아있으면 삭제(심볼릭링크든 파일이든)
|
||||
if (Files.exists(dst) || Files.isSymbolicLink(dst)) {
|
||||
Files.delete(dst);
|
||||
log.debug("deleted existing: {}", dst);
|
||||
}
|
||||
|
||||
try {
|
||||
// 하드링크 생성 (dst가 새 파일로 생기지만 inode는 f와 동일)
|
||||
Files.createLink(dst, f);
|
||||
hardlinksMade++;
|
||||
log.debug("created hardlink: {} => {}", dst, f);
|
||||
} catch (IOException e) {
|
||||
// 여기서 바로 실패시키면 “tmp는 만들었는데 내용은 0개” 같은 상태를 방지할 수 있음
|
||||
log.error("FAILED create hardlink: {} => {}", dst, f, e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hardlinksMade == 0) {
|
||||
throw new IOException(
|
||||
"No hardlinks created. regularFiles="
|
||||
+ regularFiles
|
||||
+ ", scannedDirs="
|
||||
+ scannedDirs
|
||||
+ ", noDir="
|
||||
+ noDir);
|
||||
}
|
||||
|
||||
log.info("tmp dataset created: {}", tmp);
|
||||
log.info(
|
||||
"summary: scannedDirs={}, noDir={}, regularFiles={}, hardlinksMade={}",
|
||||
scannedDirs,
|
||||
noDir,
|
||||
regularFiles,
|
||||
hardlinksMade);
|
||||
|
||||
return uid;
|
||||
}
|
||||
|
||||
private static Path toPath(String p) {
|
||||
if (p.startsWith("~/")) {
|
||||
return Paths.get(System.getProperty("user.home")).resolve(p.substring(2)).normalize();
|
||||
}
|
||||
return Paths.get(p).toAbsolutePath().normalize();
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,9 @@ package com.kamco.cd.training.train.service;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.kamco.cd.training.common.enums.TrainStatusType;
|
||||
import com.kamco.cd.training.model.dto.ModelTrainMngDto;
|
||||
import com.kamco.cd.training.model.service.TmpDatasetService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainJobCoreService;
|
||||
import com.kamco.cd.training.postgres.core.ModelTrainMngCoreService;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobDto;
|
||||
import com.kamco.cd.training.train.dto.ModelTrainJobQueuedEvent;
|
||||
import com.kamco.cd.training.train.dto.TrainRunRequest;
|
||||
import java.io.IOException;
|
||||
@@ -33,6 +33,7 @@ public class TrainJobService {
|
||||
private final ObjectMapper objectMapper;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final TmpDatasetService tmpDatasetService;
|
||||
private final DataSetCountersService dataSetCounters;
|
||||
|
||||
// 학습 결과가 저장될 호스트 디렉토리
|
||||
@Value("${train.docker.responseDir}")
|
||||
@@ -46,6 +47,9 @@ public class TrainJobService {
|
||||
@Transactional
|
||||
public Long enqueue(Long modelId) {
|
||||
|
||||
// 폴더 카운트
|
||||
dataSetCounters.getCount(modelId);
|
||||
|
||||
// 마스터 존재 확인(없으면 예외)
|
||||
modelTrainMngCoreService.findModelById(modelId);
|
||||
|
||||
@@ -139,7 +143,7 @@ public class TrainJobService {
|
||||
throw new IllegalStateException("이미 진행중입니다.");
|
||||
}
|
||||
|
||||
var lastJob =
|
||||
ModelTrainJobDto lastJob =
|
||||
modelTrainJobCoreService
|
||||
.findLatestByModelId(modelId)
|
||||
.orElseThrow(() -> new IllegalStateException("이전 실행 이력이 없습니다."));
|
||||
@@ -219,21 +223,21 @@ public class TrainJobService {
|
||||
UUID tmpUuid = UUID.randomUUID();
|
||||
String raw = tmpUuid.toString().toUpperCase().replace("-", "");
|
||||
|
||||
// MODELID 가져오기
|
||||
Long modelId = modelTrainMngCoreService.findModelIdByUuid(modelUuid);
|
||||
List<Long> datasetIds = modelTrainMngCoreService.findModelDatasetMapp(modelId);
|
||||
|
||||
List<String> uids = modelTrainMngCoreService.findDatasetUid(datasetIds);
|
||||
|
||||
try {
|
||||
// 데이터셋 심볼링크 생성
|
||||
Path path = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
String pathUid = tmpDatasetService.buildTmpDatasetSymlink(raw, uids);
|
||||
ModelTrainMngDto.UpdateReq updateReq = new ModelTrainMngDto.UpdateReq();
|
||||
updateReq.setRequestPath(path.toString());
|
||||
updateReq.setRequestPath(pathUid);
|
||||
// 학습모델을 수정한다.
|
||||
modelTrainMngCoreService.updateModelMaster(modelId, updateReq);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return modelUuid;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ public class TrainJobWorker {
|
||||
|
||||
private final ModelTrainJobCoreService modelTrainJobCoreService;
|
||||
private final ModelTrainMngCoreService modelTrainMngCoreService;
|
||||
private final ModelTrainMetricsJobService modelTrainMetricsJobService;
|
||||
private final ModelTestMetricsJobService modelTestMetricsJobService;
|
||||
private final DockerTrainService dockerTrainService;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@@ -66,8 +68,16 @@ public class TrainJobWorker {
|
||||
modelTrainMngCoreService.markStep2InProgress(modelId, jobId);
|
||||
String uuid = String.valueOf(params.get("uuid"));
|
||||
int epoch = (int) params.get("epoch");
|
||||
String datasetFolder = String.valueOf(params.get("datasetFolder"));
|
||||
String outputFolder = String.valueOf(params.get("outputFolder"));
|
||||
|
||||
EvalRunRequest evalReq = new EvalRunRequest();
|
||||
evalReq.setUuid(uuid);
|
||||
evalReq.setEpoch(epoch);
|
||||
evalReq.setTimeoutSeconds(null);
|
||||
evalReq.setDatasetFolder(datasetFolder);
|
||||
evalReq.setOutputFolder(outputFolder);
|
||||
|
||||
EvalRunRequest evalReq = new EvalRunRequest(uuid, epoch, null);
|
||||
result = dockerTrainService.runEvalSync(evalReq, containerName);
|
||||
|
||||
} else {
|
||||
@@ -90,8 +100,10 @@ public class TrainJobWorker {
|
||||
|
||||
if (isEval) {
|
||||
modelTrainMngCoreService.markStep2Success(modelId);
|
||||
modelTestMetricsJobService.findTestValidMetricCsvFiles();
|
||||
} else {
|
||||
modelTrainMngCoreService.markStep1Success(modelId);
|
||||
modelTrainMetricsJobService.findTrainValidMetricCsvFiles();
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
@@ -58,11 +58,15 @@ file:
|
||||
dataset-dir: /home/kcomu/data/request/
|
||||
dataset-tmp-dir: ${file.dataset-dir}tmp/
|
||||
|
||||
pt-path: /home/kcomu/data/response/v6-cls-checkpoints/
|
||||
pt-FileName: yolov8_6th-6m.pt
|
||||
|
||||
train:
|
||||
docker:
|
||||
image: "kamco-cd-train:love_latest"
|
||||
requestDir: "/home/kcomu/data/request"
|
||||
responseDir: "/home/kcomu/data/response"
|
||||
containerPrefix: "kamco-cd-train"
|
||||
shmSize: "16g"
|
||||
image: kamco-cd-train:latest
|
||||
requestDir: /home/kcomu/data/request
|
||||
responseDir: /home/kcomu/data/response
|
||||
basePath: /home/kcomu/data
|
||||
containerPrefix: kamco-cd-train
|
||||
shmSize: 16g
|
||||
ipcHost: true
|
||||
|
||||
@@ -46,11 +46,27 @@ springdoc:
|
||||
member:
|
||||
init_password: kamco1234!
|
||||
|
||||
swagger:
|
||||
local-port: 9080
|
||||
|
||||
file:
|
||||
sync-root-dir: /app/original-images/
|
||||
sync-tmp-dir: ${file.sync-root-dir}tmp/
|
||||
sync-file-extention: tfw,tif
|
||||
|
||||
dataset-dir: /home/kcomu/data/request/
|
||||
dataset-tmp-dir: ${file.dataset-dir}tmp/
|
||||
|
||||
pt-path: /home/kcomu/data/response/v6-cls-checkpoints/
|
||||
pt-FileName: yolov8_6th-6m.pt
|
||||
|
||||
train:
|
||||
docker:
|
||||
image: "kamco-cd-train:latest"
|
||||
requestDir: "/home/kcomu/data/request"
|
||||
responseDir: "/home/kcomu/data/response"
|
||||
containerPrefix: "kamco-cd-train"
|
||||
shmSize: "16g"
|
||||
image: kamco-cd-train:latest
|
||||
requestDir: /home/kcomu/data/request
|
||||
responseDir: /home/kcomu/data/response
|
||||
basePath: /home/kcomu/data
|
||||
containerPrefix: kamco-cd-train
|
||||
shmSize: 16g
|
||||
ipcHost: true
|
||||
|
||||
|
||||
@@ -51,8 +51,8 @@ logging:
|
||||
level:
|
||||
org:
|
||||
springframework:
|
||||
web: DEBUG
|
||||
security: DEBUG
|
||||
web: INFO
|
||||
security: INFO
|
||||
root: INFO
|
||||
# actuator
|
||||
management:
|
||||
|
||||
87
src/main/resources/static/download_progress_test.html
Normal file
87
src/main/resources/static/download_progress_test.html
Normal file
@@ -0,0 +1,87 @@
|
||||
<!doctype html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>학습데이터 ZIP 다운로드</title>
|
||||
</head>
|
||||
<body>
|
||||
<h3>학습데이터 ZIP 다운로드</h3>
|
||||
|
||||
UUID:
|
||||
<input id="uuid" value="95cb116c-380a-41c0-98d8-4d1142f15bbf" />
|
||||
<br><br>
|
||||
modelVer:
|
||||
<input id="moderVer" value="G2.HPs_0001.95cb116c-380a-41c0-98d8-4d1142f15bbf" />
|
||||
<br><br>
|
||||
|
||||
JWT Token:
|
||||
<input id="token" style="width:600px;" placeholder="Bearer 토큰 붙여넣기" />
|
||||
<br><br>
|
||||
|
||||
<button onclick="download()">다운로드</button>
|
||||
|
||||
<br><br>
|
||||
<progress id="bar" value="0" max="100" style="width:400px;"></progress>
|
||||
<div id="status"></div>
|
||||
|
||||
<script>
|
||||
/**
 * Downloads the training-data ZIP for the UUID typed into the form and saves it
 * through a temporary anchor element, updating the progress bar while streaming.
 *
 * Reads the "uuid", "moderVer" and "token" inputs, calls
 * GET /api/models/download/{uuid} with a Bearer Authorization header, and writes
 * the outcome into the "status" element.
 */
async function download() {
  const uuid = document.getElementById("uuid").value.trim();
  const moderVer = document.getElementById("moderVer").value.trim();
  const token = document.getElementById("token").value.trim();
  const statusEl = document.getElementById("status");
  const bar = document.getElementById("bar");

  if (!uuid) {
    alert("UUID 입력하세요");
    return;
  }

  if (!token) {
    alert("토큰 입력하세요");
    return;
  }

  // Start every attempt from a clean state so a previous run does not linger.
  bar.value = 0;
  statusEl.innerText = "";

  // Encode the user-typed value so it cannot alter the request path.
  const url = `/api/models/download/${encodeURIComponent(uuid)}`;

  try {
    const res = await fetch(url, {
      headers: {
        // Accept the token with or without the "Bearer " prefix.
        "Authorization": token.startsWith("Bearer ")
          ? token
          : `Bearer ${token}`,
        "kamco-download-uuid": uuid
      }
    });

    if (!res.ok) {
      statusEl.innerText = "실패: " + res.status;
      return;
    }

    // Content-Length may be absent (e.g. chunked responses); then progress stays at 0.
    const total = parseInt(res.headers.get("Content-Length") || "0", 10);
    const reader = res.body.getReader();
    const chunks = [];
    let received = 0;

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      chunks.push(value);
      received += value.length;

      if (total) {
        bar.value = (received / total) * 100;
      }
    }

    const blob = new Blob(chunks);
    const objectUrl = URL.createObjectURL(blob);
    const a = document.createElement("a");
    a.href = objectUrl;
    // Fall back to the UUID so the file is never saved as just ".zip".
    a.download = (moderVer || uuid) + ".zip";
    a.click();
    // Release the blob once the click has been dispatched; deferring avoids
    // revoking the URL before the browser has started the download.
    setTimeout(() => URL.revokeObjectURL(objectUrl), 0);

    statusEl.innerText = "완료 ✅";
  } catch (err) {
    // Network failures and aborted streams reject instead of returning a response.
    statusEl.innerText = "실패: " + (err && err.message ? err.message : err);
  }
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
Reference in New Issue
Block a user