From 36961362a88de9770d5fa448c1e605b9ccbd200f Mon Sep 17 00:00:00 2001 From: AndrejPer Date: Tue, 11 Jun 2024 11:36:22 +0200 Subject: [PATCH] Adding support for Apple Silicon --- models.py | 2 +- train_dual_decoder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/models.py b/models.py index 370281f..83118d0 100644 --- a/models.py +++ b/models.py @@ -11,7 +11,7 @@ import time import sys -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") def resnet_block(stride=1): layers = [] diff --git a/train_dual_decoder.py b/train_dual_decoder.py index 9a77bf8..73cead7 100644 --- a/train_dual_decoder.py +++ b/train_dual_decoder.py @@ -157,7 +157,7 @@ def load_checkpoint(checkpoint): parser.add_argument('--predict_bbox', dest='predict_bbox', default=False, action='store_true', help='Predict cell bbox') args = parser.parse_args() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors + device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") # sets device for model and PyTorch tensors cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead # Read word map