diff --git a/src/main/java/com/kamco/cd/training/dataset/dto/DatasetDto.java b/src/main/java/com/kamco/cd/training/dataset/dto/DatasetDto.java index 52db4c4..73cb7a9 100644 --- a/src/main/java/com/kamco/cd/training/dataset/dto/DatasetDto.java +++ b/src/main/java/com/kamco/cd/training/dataset/dto/DatasetDto.java @@ -254,6 +254,7 @@ public class DatasetDto { private Long wasteCnt; private Long landCoverCnt; + private Integer solarPanelCnt; public SelectDataSet( String modelNo, @@ -308,6 +309,29 @@ public class DatasetDto { this.containerCnt = containerCnt; } + public SelectDataSet( + String modelNo, + Long datasetId, + UUID uuid, + String dataType, + String title, + Long roundNo, + Integer compareYyyy, + Integer targetYyyy, + String memo, + Integer solarPanelCnt) { + this.datasetId = datasetId; + this.uuid = uuid; + this.dataType = dataType; + this.dataTypeName = getDataTypeName(dataType); + this.title = title; + this.roundNo = roundNo; + this.compareYyyy = compareYyyy; + this.targetYyyy = targetYyyy; + this.memo = memo; + this.solarPanelCnt = solarPanelCnt; + } + public String getDataTypeName(String groupTitleCd) { LearnDataType type = Enums.fromId(LearnDataType.class, groupTitleCd); return type == null ? null : type.getText(); diff --git a/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java b/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java index 07a0bdb..0439cff 100644 --- a/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java +++ b/src/main/java/com/kamco/cd/training/model/ModelTrainMngApiController.java @@ -1,6 +1,7 @@ package com.kamco.cd.training.model; import com.kamco.cd.training.common.dto.MonitorDto; +import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.common.service.SystemMonitorService; import com.kamco.cd.training.config.api.ApiResponseDto; import com.kamco.cd.training.dataset.dto.DatasetDto; @@ -143,9 +144,9 @@ public class ModelTrainMngApiController { @Parameter( description = "모델 구분", example = "", - schema = @Schema(allowableValues = {"G1", "G2", "G3"})) + schema = @Schema(allowableValues = {"G1", "G2", "G3", "G4"})) @RequestParam - String modelType, + ModelType modelType, @Parameter( description = "선택 구분", example = "", @@ -153,7 +154,7 @@ public class ModelTrainMngApiController { @RequestParam String selectType) { DatasetReq req = new DatasetReq(); - req.setModelNo(modelType); + req.setModelNo(modelType.getId()); req.setDataType(selectType); return ApiResponseDto.ok(modelTrainMngService.getDatasetSelectList(req)); } diff --git a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java index d4016e0..59f71b5 100644 --- a/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java +++ b/src/main/java/com/kamco/cd/training/model/service/ModelTrainMngService.java @@ -324,6 +324,8 @@ public class ModelTrainMngService { public List getDatasetSelectList(DatasetReq req) { if (req.getModelNo().equals(ModelType.G1.getId())) { return modelTrainMngCoreService.getDatasetSelectG1List(req); + } else if (req.getModelNo().equals(ModelType.G4.getId())) { + return modelTrainMngCoreService.getDatasetSelectG4List(req); } else { return modelTrainMngCoreService.getDatasetSelectG2G3List(req); } diff --git a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java index 081393d..3dc9b25 100644 --- a/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java +++ b/src/main/java/com/kamco/cd/training/postgres/core/ModelTrainMngCoreService.java @@ -337,6 +337,16 @@ public class ModelTrainMngCoreService { return datasetRepository.getDatasetTransferSelectG2G3List(modelId, modelNo); } + /** + * 데이터셋 G4 목록 + * + * @param req + * @return + */ + public List getDatasetSelectG4List(DatasetReq req) { + return datasetRepository.getDatasetSelectG4List(req); + } + // TODO 미사용 끝 /** diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetRepositoryCustom.java b/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetRepositoryCustom.java index c7bd570..45d4096 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetRepositoryCustom.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetRepositoryCustom.java @@ -27,6 +27,8 @@ public interface DatasetRepositoryCustom { List getDatasetSelectG2G3List(DatasetReq req); + List getDatasetSelectG4List(DatasetReq req); + Long getDatasetMaxStage(int compareYyyy, int targetYyyy); Long insertDatasetMngData(DatasetMngRegDto mngRegDto); diff --git a/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetRepositoryImpl.java b/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetRepositoryImpl.java index dd1f2c4..8e75dc4 100644 --- a/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetRepositoryImpl.java +++ b/src/main/java/com/kamco/cd/training/postgres/repository/dataset/DatasetRepositoryImpl.java @@ -4,6 +4,7 @@ import static com.kamco.cd.training.postgres.entity.QDatasetObjEntity.datasetObj import static com.kamco.cd.training.postgres.entity.QModelDatasetMappEntity.modelDatasetMappEntity; import static com.kamco.cd.training.postgres.entity.QModelMasterEntity.modelMasterEntity; +import com.kamco.cd.training.common.enums.DetectionClassification; import com.kamco.cd.training.common.enums.ModelType; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetMngRegDto; import com.kamco.cd.training.dataset.dto.DatasetDto.DatasetReq; @@ -104,10 +105,6 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom { builder.and(dataset.dataType.eq(req.getDataType())); } - if (StringUtils.isNotBlank(req.getDataType()) && !"CURRENT".equals(req.getDataType())) { - builder.and(dataset.dataType.eq(req.getDataType())); - } - if (req.getIds() != null) { builder.and(dataset.id.in(req.getIds())); } @@ -126,12 +123,15 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom { dataset.targetYyyy, dataset.memo, new CaseBuilder() - .when(datasetObjEntity.targetClassCd.eq("building")) + .when( + datasetObjEntity.targetClassCd.eq(DetectionClassification.BUILDING.getId())) .then(1) .otherwise(0) .sum(), new CaseBuilder() - .when(datasetObjEntity.targetClassCd.eq("container")) + .when( + datasetObjEntity.targetClassCd.eq( + DetectionClassification.CONTAINER.getId())) .then(1) .otherwise(0) .sum())) @@ -249,19 +249,31 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom { } // TODO 미사용 끝 + @Override public List getDatasetSelectG2G3List(DatasetReq req) { + String building = DetectionClassification.BUILDING.getId(); + String container = DetectionClassification.CONTAINER.getId(); + String waste = DetectionClassification.WASTE.getId(); + String solar = DetectionClassification.SOLAR.getId(); + BooleanBuilder builder = new BooleanBuilder(); builder.and(dataset.deleted.isFalse()); NumberExpression selectedCnt = null; NumberExpression wasteCnt = - datasetObjEntity.targetClassCd.when("waste").then(1L).otherwise(0L).sum(); + datasetObjEntity + .targetClassCd + .when(DetectionClassification.WASTE.getId()) + .then(1L) + .otherwise(0L) + .sum(); + // G1, G2, G4 제외 NumberExpression elseCnt = new CaseBuilder() - .when(datasetObjEntity.targetClassCd.notIn("building", "container", "waste")) + .when(datasetObjEntity.targetClassCd.notIn(building, container, waste, solar)) .then(1L) .otherwise(0L) .sum(); @@ -481,4 +493,53 @@ public class DatasetRepositoryImpl implements DatasetRepositoryCustom { .where(dataset.uid.eq(uid), dataset.deleted.isFalse()) .fetchOne(); } + + @Override + public List getDatasetSelectG4List(DatasetReq req) { + BooleanBuilder builder = new BooleanBuilder(); + + builder.and(dataset.deleted.isFalse()); + + if (StringUtils.isNotBlank(req.getDataType()) && !"CURRENT".equals(req.getDataType())) { + builder.and(dataset.dataType.eq(req.getDataType())); + } + + if (req.getIds() != null) { + builder.and(dataset.id.in(req.getIds())); + } + + return queryFactory + .select( + Projections.constructor( + SelectDataSet.class, + Expressions.constant(req.getModelNo()), + dataset.id, + dataset.uuid, + dataset.dataType, + dataset.title, + dataset.roundNo, + dataset.compareYyyy, + dataset.targetYyyy, + dataset.memo, + new CaseBuilder() + .when( + datasetObjEntity.targetClassCd.equalsIgnoreCase( + DetectionClassification.SOLAR.getId())) + .then(1) + .otherwise(0) + .sum())) + .from(dataset) + .leftJoin(datasetObjEntity) + .on(dataset.id.eq(datasetObjEntity.datasetUid)) + .where(builder) + .groupBy( + dataset.id, + dataset.uuid, + dataset.dataType, + dataset.title, + dataset.roundNo, + dataset.memo) + .orderBy(dataset.createdDttm.desc()) + .fetch(); + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/JobRecoveryOnStartupService.java b/src/main/java/com/kamco/cd/training/train/service/JobRecoveryOnStartupService.java index d35422b..ffff705 100644 --- a/src/main/java/com/kamco/cd/training/train/service/JobRecoveryOnStartupService.java +++ b/src/main/java/com/kamco/cd/training/train/service/JobRecoveryOnStartupService.java @@ -384,7 +384,20 @@ public class JobRecoveryOnStartupService { return new OutputResult(false, "total-epoch-missing"); } - log.info("[RECOVERY] totalEpoch={}. jobId={}", totalEpoch, job.getId()); + Integer valInterval = extractValInterval(job).orElse(null); + if (valInterval == null || valInterval <= 0) { + log.warn( + "[RECOVERY] valInterval missing or invalid. jobId={}, valInterval={}", + job.getId(), + valInterval); + return new OutputResult(false, "val-interval-missing"); + } + + log.info( + "[RECOVERY] totalEpoch={}. valInterval={}. jobId={}", + totalEpoch, + valInterval, + job.getId()); // 3) val.csv 존재 확인 Path valCsv = outDir.resolve("val.csv"); @@ -396,14 +409,17 @@ public class JobRecoveryOnStartupService { // 4) val.csv 라인 수 확인 long lines = countNonHeaderLines(valCsv); + // expected = 실제 val 실행 횟수 + int expectedLines = totalEpoch / valInterval; + log.info( "[RECOVERY] val.csv lines counted. jobId={}, lines={}, expected={}", job.getId(), lines, - totalEpoch); + expectedLines); // 5) 완료 판정 - if (lines == totalEpoch) { + if (lines >= expectedLines) { log.info("[RECOVERY] outputs look COMPLETE. jobId={}", job.getId()); return new OutputResult(true, "ok"); } @@ -412,7 +428,7 @@ public class JobRecoveryOnStartupService { "[RECOVERY] val.csv line mismatch. jobId={}, lines={}, expected={}", job.getId(), lines, - totalEpoch); + expectedLines); return new OutputResult( false, "val.csv-lines-mismatch lines=" + lines + " expected=" + totalEpoch); @@ -530,4 +546,19 @@ public class JobRecoveryOnStartupService { return reason; } } + + /** paramsJson에서 valInterval 추출 */ + private Optional extractValInterval(ModelTrainJobDto job) { + Map params = job.getParamsJson(); + if (params == null) return Optional.empty(); + + Object v = params.get("valInterval"); + if (v == null) return Optional.empty(); + + try { + return Optional.of(Integer.parseInt(String.valueOf(v))); + } catch (Exception ignore) { + return Optional.empty(); + } + } } diff --git a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java index 985dabe..968e977 100644 --- a/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java +++ b/src/main/java/com/kamco/cd/training/train/service/TrainJobWorker.java @@ -132,7 +132,9 @@ public class TrainJobWorker { String failMsg = result.getStatus() + "\n" + result.getLogs(); log.info("training fail exitCode={} Msg ={}", result.getExitCode(), failMsg); - if (result.getExitCode() == -1 || result.getExitCode() == 143) { + if (result.getExitCode() == -1 + || result.getExitCode() == 143 + || result.getExitCode() == 137) { // 실패 처리 modelTrainJobCoreService.markPaused( jobId, result.getExitCode(), result.getStatus() + "\n" + result.getLogs());