We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a8025b7, commit a2f65c1 (Copy full SHA for a2f65c1)
intermediate_source/reinforcement_q_learning.py
@@ -408,7 +408,7 @@ def optimize_model():
408
# Compute a mask of non-final states and concatenate the batch elements
409
# (a final state would've been the one after which simulation ended)
410
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
411
- batch.next_state)), device=device, dtype=torch.uint8)
+ batch.next_state)), device=device, dtype=torch.bool)
412
non_final_next_states = torch.cat([s for s in batch.next_state
413
if s is not None])
414
state_batch = torch.cat(batch.state)
0 commit comments