diff --git a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java index e52cdf5..4d9bec2 100644 --- a/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java +++ b/src/main/java/com/kamco/cd/training/train/service/DockerTrainService.java @@ -54,6 +54,9 @@ public class DockerTrainService { @Value("${train.docker.ipcHost:true}") private boolean ipcHost; + @Value("${spring.profiles.active}") + private String profile; + private final ModelTrainJobCoreService modelTrainJobCoreService; /** @@ -228,8 +231,11 @@ public class DockerTrainService { // GPU 전체 사용 c.add("--gpus"); - c.add("1"); // 학습서버 GPU 1개인 곳은 1이어야 함 - // c.add("all"); + if ("prod".equals(profile)) { + c.add("1"); // 학습서버 GPU 1개인 곳은 1이어야 함 + } else { + c.add("all"); + } // IPC host 사용 여부 if (ipcHost) { @@ -282,8 +288,12 @@ public class DockerTrainService { addArg(c, "--crop-size", req.getCropSize()); // addArg(c, "--batch-size", req.getBatchSize()); // addArg(c, "--gpu-ids", req.getGpuIds()); // null - addArg(c, "--batch-size", 2); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능 - addArg(c, "--gpu-ids", "0"); // 학습서버 GPU 1개인 곳은 0이어야 함 + if ("prod".equals(profile)) { + addArg(c, "--batch-size", 2); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능 + addArg(c, "--gpu-ids", "0"); // 학습서버 GPU 1개인 곳은 0이어야 함 + } else { + addArg(c, "--batch-size", req.getBatchSize()); // 학습서버 GPU 1개인 곳은 batch-size:2 까지만 가능 + } addArg(c, "--lr", req.getLearningRate()); addArg(c, "--backbone", req.getBackbone()); addArg(c, "--epochs", req.getEpochs()); @@ -448,8 +458,11 @@ public class DockerTrainService { c.add("run"); c.add("--rm"); c.add("--gpus"); - c.add("1"); // 학습서버 GPU 1개인 곳은 1이어야 함 - // c.add("all"); + if ("prod".equals(profile)) { + c.add("1"); // 학습서버 GPU 1개인 곳은 1이어야 함 + } else { + c.add("all"); + } c.add("--ipc=host"); c.add("--shm-size=" + shmSize);