[BugFix] [Feature] "_reset" flag for env reset #800
Conversation
@vmoens In the meantime, could you have a quick look at the changes I have made?
Minor comment but otherwise LGTM
torchrl/envs/common.py
if tensordict is not None and "_reset" in tensordict.keys():
    self._assert_tensordict_shape(tensordict)
    _reset = tensordict.get("_reset")
else:
    _reset = None

if (_reset is None and tensordict_reset.get("done").any()) or (
    _reset is not None and tensordict_reset.get("done")[_reset].any()
):
What about
if tensordict is not None:
    self._assert_tensordict_shape(tensordict)
    _reset = tensordict.get("_reset", None)
if (_reset is None and tensordict_reset.get("done").any()) or (
    _reset is not None and tensordict_reset.get("done")[_reset].any()
):
_reset = tensordict.get("_reset", None)

This crashes when tensordict is None: _reset is never assigned, so the check below raises a NameError.
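A minimal sketch of a variant that avoids the unbound name (reusing the surrounding names from the suggestion above):

_reset = None  # default, so the check below is well-defined when tensordict is None
if tensordict is not None:
    self._assert_tensordict_shape(tensordict)
    _reset = tensordict.get("_reset", None)
if (_reset is None and tensordict_reset.get("done").any()) or (
    _reset is not None and tensordict_reset.get("done")[_reset].any()
):
    ...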
I did not see that, but the CI seems broken; let me see what happened.
Force-pushed from 3297ad5 to 23a82fc.
The current version of the changes deletes the "_reset" flag right after it has been used.
I added some tests; the CI seems broken on some HTTP stuff.
While finishing up this PR I stumbled upon an unexpected behavior that I am not sure is intended. Imagine having an env with several sub-environments. You wrap it in a ParallelEnv and reset only some of them: the sub-environments that were not reset come back with arbitrary values from the buffer rather than their current state. Is this expected?
Good point, but what would be the expected behaviour?
The rule I use in vmas is to reset only the ones that need resetting and call … Option 1: call …
I agree with you that options 3 and 4 look better than the others. To build the fake_tensordict that serves as a buffer for the parallel env, we use rand for no particular reason; we could perfectly well use NaN, zeros or anything else. Also, with the collectors we return a "mask" entry that represents which indices of the data are valid. It's a way to make sure that users have access to what was or wasn't the result of a reset. Would you consider something like that?
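A hand-built illustration of that "mask" mechanism (the data here is made up; the point is just a boolean "mask" entry aligned with the batch dims):

import torch
from tensordict import TensorDict

# 4 workers, 10 steps each; "mask" marks which steps are valid
data = TensorDict(
    {
        "observation": torch.randn(4, 10, 3),
        "mask": torch.rand(4, 10) > 0.2,
    },
    batch_size=[4, 10],
)
valid_steps = data[data["mask"]]  # boolean indexing keeps only the valid steps
print(valid_steps.batch_size)  # 1-D batch of the surviving steps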
That makes sense, yes; we can leave it as is. Does the mask in the collectors already take this problem into account when resetting parallel workers?
No, it just keeps track of padding operations. I was mentioning this as a mechanism to track valid steps. Another thing to consider: if your env supports sub-envs that have not been reset but exist (e.g. in some libs you must reset before doing anything else), you could just append a transform that does some step counting. If an env has been reset, the steps will be strictly greater than (or greater than or equal to?) 0. If a step is 0 (or -1?), no step has actually been done. See e.g. this class:

base_env = MyEnv(..., n_envs=2)
env = TransformedEnv(base_env, StepCounter())
tensordict = env.reset(TensorDict({"_reset": [True, False]}, [2]))
print(tensordict["steps"])  # prints [0, -1]

@riiswa I don't think this behaviour is currently supported, but do you think it would make sense?
Yes, I think that makes sense. The class documentation will need to be written carefully. Can you create an issue for this and assign me?
@matteobettini do we want to wait for the discussion mentioned above to be resolved before merging, or are we happy with the current state?
I think this PR can be merged standalone, since it mainly refactors "reset_workers" to "_reset" and puts in place some mechanism to handle it. If users want to use "_reset", it will only work in envs that support it. The only libs with sub-envs to my knowledge are vmas and brax. vmas supports it. Brax, not having access to the state during a call to the reset function, cannot return a state which is partially reset and partially the old one. To enable support for brax we would need to pass the state during reset, or leave it as is and users will do:

td = brax.reset()
for _ in range(n_rollout_samples):
    td = brax.step(td_with_action)
    _reset = td["done"]
    td["state"][_reset] = brax.reset()["state"][_reset]
LGTM, a few minor issues to solve and we're good to go.
test/test_env.py
TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) | ||
) | ||
env.close() | ||
if _reset.any(): |
How can we make sure that this branch is reached? (I agree there's a low probability that it isn't.) A few ideas:
- put an else: raise RuntimeError at the end
- same as above + repeat the test X times (e.g. X=3) with a different seed if it failed
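A sketch of the repeat-with-a-new-seed idea (the helper and sizes are illustrative, not the actual test):

import torch

def reset_branch_reached(seed: int) -> bool:
    # stand-in for the body of the test above
    torch.manual_seed(seed)
    _reset = torch.randint(low=0, high=2, size=(4, 1), dtype=torch.bool)
    return bool(_reset.any())

for seed in range(3):  # X = 3 attempts, each with a different seed
    if reset_branch_reached(seed):
        break
else:
    raise RuntimeError("the partial-reset branch was never exercised")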
Good point. Does it make sense now?
test/test_env.py
assert (td["next"]["observation"] == max_steps + 1).all() | ||
|
||
_reset = torch.randint( | ||
low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool |
why is the last dim 1? With the latest tensordict version, the dims of a tensor can match the batch size of a tensordict.
I used the convention of the done flag, which has shape (*batch_size, 1). I adopted this convention for _reset in all the PR files.
See here for example https://github.com/pytorch/rl/blob/main/torchrl/envs/common.py#L454
Or the one you and I did recently: https://github.com/pytorch/rl/blob/main/torchrl/envs/vec_env.py#L1013
For "done"
, we want the last dim to be 1 because we do this kind of thing:
value = reward + value * gamma * (1-done)
we want reward and done to have a last dim of 1 otherwise they will be casted to the size of value
which is usually the output of a neural net like nn.Sequential(..., nn.Linear(..., 1))
.
So either we squeeze value or we unsqueeze reward and done.
Squeezing requires more brain power from the users IMO, whereas unsqueezing can be hidden from them.
For "_reset"
I don't think we need that
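A quick runnable illustration of the shapes above (the sizes and net are made up):

import torch
import torch.nn as nn

batch, gamma = 8, 0.99
value_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))
value = value_net(torch.randn(batch, 4))  # shape (8, 1), typical net output
reward = torch.randn(batch, 1)            # last dim kept at 1
done = torch.zeros(batch, 1)

target = reward + gamma * value * (1 - done)
print(target.shape)  # torch.Size([8, 1]); with done of shape (8,) this would silently broadcast to (8, 8)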
For now I have also removed the last dim of 1 from "step_count".
IMO StepCounter should have steps that match the tensordict batch size.
@riiswa
OK, so like it is now in the PR; I agree.
In the future it might be worth considering forcing one of the two conventions for all tensordicts, IMO: either all keys of shape (*batch_size, 1) have to be squeezed, or the opposite. I understand that for reward and done one way might be more comfortable, but I find it confusing that some keys and specs are squeezed while others are not.
Yep, it's been on my radar for quite some time. It's a conflicting thing:
Before, tensordict was designed such that you could not store tensors with a shape identical to the tensordict's (you had to unsqueeze, or it was done for you).
That was surprising to many users, so we dropped it.
But as mentioned above, reward and done interact so much with neural nets that it's a pain to work with batched data that does not end with a singleton dimension.
Maybe unsqueezing in the env is not the solution though, I don't know really. We could delegate that to the value functions and such. My main worry is that it may lead to silent errors, e.g.:

done = torch.zeros(10)
next_value = torch.ones(10, 1)
value = done * next_value  # silently gives the wrong 10 x 10 tensor

It's not only a problem for us but also for the users. That's why I'm not sure we're doing them a favour by removing the unsqueeze.
But I'm more than happy to keep the conversation going.
@tcbegley was involved in some of this refactoring, he may have his 2 cents to share.
Great work! I love that the library is becoming more and more compatible with multi-agent settings!
Description
This PR addresses issue #790.
The changes replace the "reset_workers" flag (only designed for ParallelEnvs wrapping environments with an empty batch_size) with the "_reset" flag, which spans all batch_size dimensions.
This makes it possible to tell the wrapped environments more precisely which dimensions need to be reset.
In accordance with this, the reset() methods on EnvBase and ParallelEnv now only check that the indices flagged for reset are not done, instead of checking assert not done.any().
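A minimal sketch of how the flag is passed (the batch size and mask are made-up values; env.reset is left as a comment since it needs a real env):

import torch
from tensordict import TensorDict

batch_size = [2]  # assumed batch_size of the wrapped env
# the mask follows the (*batch_size, 1) convention used for "done"
reset_mask = torch.tensor([[True], [False]])
reset_td = TensorDict({"_reset": reset_mask}, batch_size=batch_size)
# env.reset(reset_td) would then reset only the first sub-env and check
# that the flagged indices are not done (rather than `assert not done.any()`)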