diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index e13ff4faba6..2b89a773e96 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -224,6 +224,7 @@ def conv2d_size_out(size, kernel_size = 5, stride = 2): # Called with either one element to determine next action, or a batch # during optimization. Returns tensor([[left0exp,right0exp]...]). def forward(self, x): + x = x.to(device) x = F.relu(self.bn1(self.conv1(x))) x = F.relu(self.bn2(self.conv2(x))) x = F.relu(self.bn3(self.conv3(x))) @@ -273,7 +274,7 @@ def get_screen(): screen = np.ascontiguousarray(screen, dtype=np.float32) / 255 screen = torch.from_numpy(screen) # Resize, and add a batch dimension (BCHW) - return resize(screen).unsqueeze(0).to(device) + return resize(screen).unsqueeze(0) env.reset()