diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index 5a1625b..60ee90d 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -62,6 +62,9 @@ public class DockerTrainService { @Value("${hyper.parameter.gpu-ids}") private String hyperGpuIds; + @Value("${hyper.parameter.batch_size}") + private Integer batchSize; + private final ModelTrainJobCoreService modelTrainJobCoreService; /** @@ -293,7 +296,7 @@ public class DockerTrainService { addArg(c, "--batch-size", 2); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능 } else { - addArg(c, "--batch-size", req.getBatchSize()); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능 + addArg(c, "--batch-size", batchSize); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능 } addArg(c, "--gpus", hyperGpus); // 학습서버 GPU 1개인 곳은 1이어야 함 addArg(c, "--gpu-ids", hyperGpuIds); // 학습서버 GPU 1개인 곳은 0이어야 함 diff --git a/src/main/resources/application-prod.yml b/src/main/resources/application-prod.yml index 1d0e96a..f1c0ad3 100644 --- a/src/main/resources/application-prod.yml +++ b/src/main/resources/application-prod.yml @@ -45,3 +45,4 @@ hyper: parameter: gpus: 4 gpu-ids: 0,1,2,3 + batch-size: 30 diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index d13e1df..bab9fef 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -82,4 +82,4 @@ hyper: parameter: gpus: 1 gpu-ids: 0 - + batch-size: 2