
Conversation

matteobettini
Contributor

Description

This PR addresses issue #790.

The changes replace the "reset_workers" flag (only designed for ParallelEnvs wrapping environments with an empty batch_size) with the "_reset" flag, which spans all batch_size dimensions.

This makes it possible to tell the wrapped environments more precisely which dimensions need to be reset.

Accordingly, the reset() methods on EnvBase and ParallelEnv now only check that the indices flagged for reset are not done, instead of asserting not done.any().
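For illustration, a minimal sketch of the new flag in use (assuming parallel_env is a ParallelEnv; the setup is illustrative):

import torch
from tensordict import TensorDict

# "_reset" spans all batch_size dims, with a trailing singleton dim
# matching the "done" convention; flag only the first worker for reset.
_reset = torch.zeros(*parallel_env.batch_size, 1, dtype=torch.bool)
_reset[0] = True
reset_td = parallel_env.reset(
    TensorDict({"_reset": _reset}, batch_size=parallel_env.batch_size)
)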

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 6, 2023
@matteobettini
Contributor Author

matteobettini commented Jan 6, 2023

@vmoens
I will write a few non-regression tests for this feature.

In the meantime, could you have a quick look at the changes I made to collectors.py? I am not yet very familiar with that part of the codebase, and I just want to make sure I didn't miss any subtleties and that the logic remains unchanged.

@vmoens vmoens left a comment

Minor comment but otherwise LGTM

Comment on lines 460 to 468
if tensordict is not None and "_reset" in tensordict.keys():
    self._assert_tensordict_shape(tensordict)
    _reset = tensordict.get("_reset")
else:
    _reset = None

if (_reset is None and tensordict_reset.get("done").any()) or (
    _reset is not None and tensordict_reset.get("done")[_reset].any()
):
Collaborator

What about

        if tensordict is not None:
            self._assert_tensordict_shape(tensordict)
        _reset = tensordict.get("_reset", None)

        if (_reset is None and tensordict_reset.get("done").any()) or (
            _reset is not None and tensordict_reset.get("done")[_reset].any()
        ):

@matteobettini matteobettini Jan 8, 2023

_reset = tensordict.get("_reset", None)
This crashes when tensordict is None.

@vmoens vmoens left a comment

I did not see it, but the CI seems broken; let me see what happened.

@matteobettini
Contributor Author

The current version of the changes deletes the "_reset" flag right after it has been used: EnvBase, ParallelEnv, and SerialEnv will delete the "_reset" entry, if present, right after resetting. This means that external components like collectors only have to worry about setting the flag, not deleting it.
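A minimal sketch of the resulting contract (assuming env is an EnvBase with these changes and _reset is a boolean mask built as above):

td = env.reset(TensorDict({"_reset": _reset}, batch_size=env.batch_size))
# The env consumes the flag while resetting and deletes it afterwards,
# so callers never have to clean it up themselves.
assert "_reset" not in td.keys()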

@matteobettini
Contributor Author

I added some tests; the CI seems broken on some HTTP stuff.

@matteobettini
Contributor Author

While finishing up this PR I stumbled upon an unexpected behavior that I am not sure is intended.

Imagine having an EnvBase where the reset function returns "done" and "observation".

You wrap it in a ParallelEnv with 2 workers and, before anything else, you call reset, asking to reset only the first worker. You will get observations for both workers, but those for the worker that was not reset are random, since that worker never called its reset method (as intended).

Is this expected?
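A minimal sketch of the scenario (make_env is a hypothetical env factory):

parallel_env = ParallelEnv(2, make_env)
_reset = torch.tensor([[True], [False]])  # reset only the first worker
td = parallel_env.reset(TensorDict({"_reset": _reset}, batch_size=[2]))
# td["observation"][0] comes from worker 0's actual reset;
# td["observation"][1] is uninitialized buffer content, since worker 1
# never ran its reset method.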

@vmoens
Collaborator

vmoens commented Jan 8, 2023

While finishing up this PR I stumbled upon an unexpected behavior that I am not sure is intended.

Imagine having an EnvBase where the reset function returns "done" and "observation".

You wrap it in a ParallelEnv with 2 workers and, before anything else, you call reset, asking to reset only the first worker. You will get observations for both workers, but those for the worker that was not reset are random, since that worker never called its reset method (as intended).

Is this expected?

Good point, but what would be the expected behaviour?

@matteobettini
Contributor Author

matteobettini commented Jan 8, 2023

Good point, but what would be the expected behaviour?

The rule I use in vmas is to reset only the envs that need resetting and call env.observation(agent) for all envs and all agents. Given that such an observation function is not available in all libs here, I think we have a few options:

Option 1: call reset() for all workers independently

This calls reset on all the workers and passes the "_reset" flag, which would be all False for workers that don't need resetting. The envs then have to look at the "_reset" flag and, if it is all False, return only the observation without resetting (sketched below).

Cons

Every environment in torchrl would have to implement the logic to handle the _reset flag and this behavior. That might be asking too much.
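A hypothetical sketch of what Option 1 would ask of each env (helper names are illustrative, not the torchrl API):

def _reset(self, tensordict):
    flag = tensordict.get("_reset", None) if tensordict is not None else None
    if flag is not None and not flag.any():
        # Nothing flagged: skip resetting and just report observations.
        return self._current_observations()  # hypothetical helper
    return self._reset_flagged(flag)  # hypothetical helper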

Option 2: Heterogeneous tensordicts

Only return the observation and keys that are actually available. If some workers were not reset, their observation dimension will be 0 and the returned tensordict will be heterogeneous.

Option 3: Use NaNs or other placeholders.

Instead of putting random values for workers that were not reset, use a placeholder convention (NaN, inf, zeros).
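A minimal sketch of such a convention (assuming a (num_workers, 1) "_reset" mask and observations batched over workers):

not_reset = ~_reset.squeeze(-1)  # workers that were not reset
# Overwrite the meaningless entries with a recognizable placeholder.
reset_td["observation"][not_reset] = float("nan")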

Option 4: leave as is

When reset() is called after a step(), the observations for workers that were not reset keep the last values available from the step. The only case where the behavior is undefined is when reset is called before step for only some workers. We could precisely define what happens in this case in the docs and say that one should only access the values at the reset indices:

_reset = torch.randint(0, 2, size=(*parallel_env.batch_size, 1), dtype=torch.bool)
reset_td = parallel_env.reset(TensorDict({"_reset": _reset}, batch_size=parallel_env.batch_size))
reset_td[_reset]  # the only values a user should access

My opinion

In my opinion, 3 and 4 are still the best and least disruptive. They are synonyms of padding, which I hate, but we live in a world where NestedTensors are not yet available. Users have to be smart, though, and careful to know which values will be meaningless.

@vmoens what's your take?

@vmoens
Collaborator

vmoens commented Jan 8, 2023

I agree with you that options 3 and 4 look better than the others.
In general, while I see the issue, I would prefer not to adopt a solution that forces users who do not use parallel envs / multi-agent to reset envs in a way that is more complex than needed.

To build the fake_tensordict that serves as a buffer for the parallel env, we use rand for no particular reason. We could just as well use NaN, zeros, or anything else.

Also, with the collectors we return a "mask" entry that represents which indices of the data are valid. It's a way to make sure users can tell what was or wasn't the result of a reset. Would you consider something like that?
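A hedged sketch of how such a mask could be consumed (the iteration pattern is illustrative):

for data in collector:
    mask = data.get("mask")   # True where a step holds valid data
    valid_steps = data[mask]  # boolean-mask indexing keeps valid steps only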

@matteobettini
Contributor Author

matteobettini commented Jan 8, 2023

That makes sense, yes; we can leave it as is. Does the mask in the collectors already take this problem into account when resetting parallel workers?

@vmoens vmoens added the Refactoring Refactoring of an existing feature label Jan 9, 2023
@vmoens
Collaborator

vmoens commented Jan 9, 2023

That makes sense, yes; we can leave it as is. Does the mask in the collectors already take this problem into account when resetting parallel workers?

No, it just keeps track of padding operations. I mentioned it as a mechanism to track valid steps.

Another thing to consider is this: if your env supports sub-envs that exist but have not been reset (e.g. in some libs you must reset before doing anything else), you could just append a transform that does some step counting. If an env has been reset, the step count will be strictly greater than (or equal to?) 0. If a step count is 0 (or -1?), no step has actually been done.

See this class

e.g.

base_env = MyEnv(..., n_envs=2)
env = TransformedEnv(base_env, StepCounter())
tensordict = env.reset(TensorDict({"_reset": [True, False]}, [2]))
print(tensordict["steps"])  # prints [0, -1]

@riiswa I don't think this behaviour is currently supported, but do you think it would make sense?

@riiswa
Contributor

riiswa commented Jan 9, 2023

@riiswa I don't think this behaviour is currently supported, but do you think it would make sense?

Yes, I think that makes sense. The documentation of the class will need to be written well. Can you create an issue for this and assign me?

@vmoens
Collaborator

vmoens commented Jan 9, 2023

@matteobettini do we want to wait for the discussion above to be resolved before merging, or are we happy with the current state?

@matteobettini
Contributor Author

matteobettini commented Jan 9, 2023

I think this PR can be merged standalone, since it mainly refactors "reset_workers" to "_reset" and puts some mechanisms in place to handle it. If users want to use "_reset", it will only work in envs that support it.

The only libs with sub-envs, to my knowledge, are vmas and brax. vmas supports it. Brax, not having the state during a call to the reset function, cannot return a state that is partially reset and partially the old one. To enable support for brax we would need to pass the state during reset, or leave it as is and users will do:

td = brax.reset()
for _ in range(n_rollout_samples):
    td = brax.step(td_with_action)
    _reset = td["done"]
    td["state"][_reset] = brax.reset()["state"][_reset]

@vmoens vmoens left a comment

LGTM; a few minor issues to solve and we're good to go.

test/test_env.py Outdated
TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device)
)
env.close()
if _reset.any():
Collaborator

How can we make sure that this branch is reached? (I agree there's a low probability that it isn't.)
A few ideas (see the sketch below):

  • put an else: raise RuntimeError at the end
  • same as above, plus repeat the test X times (e.g. X=3) with a different seed if it failed
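A minimal sketch combining both ideas (assuming the test body can be re-seeded and retried):

for attempt in range(3):  # retry with a different seed if nothing was flagged
    torch.manual_seed(attempt)
    _reset = torch.randint(0, 2, size=(*env.batch_size, 1), dtype=torch.bool)
    if _reset.any():
        break
else:
    raise RuntimeError("_reset never flagged any worker across retries")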

Contributor Author

Good point. Does it make sense now?

test/test_env.py Outdated
assert (td["next"]["observation"] == max_steps + 1).all()

_reset = torch.randint(
low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool
Collaborator

Why is the last dim 1? With the latest tensordict version, the dims of a tensor can match the batch size of a tensordict.

@matteobettini matteobettini Jan 10, 2023

I used the convention of the done flag, which has shape (*batch_size, 1). I adopted this convention for _reset in all the PR files as well.

Collaborator

For "done", we want the last dim to be 1 because we do this kind of thing:

value = reward + value * gamma * (1 - done)

We want reward and done to have a last dim of 1; otherwise they will be broadcast to the size of value, which is usually the output of a neural net like nn.Sequential(..., nn.Linear(..., 1)).
So either we squeeze value, or we unsqueeze reward and done.
Squeezing requires more brain power from the users IMO, whereas unsqueezing can be hidden from them.

For "_reset" I don't think we need that

Contributor Author

For now I have also removed the last dim of 1 from "step_count".

Collaborator

IMO StepCounter should have steps that match the tensordict batch size.
@riiswa

Contributor Author

OK, so like it is now in the PR; I agree.

Contributor Author

In the future it might be worth considering forcing one of the two conventions for all tensordicts, IMO: either all keys of shape (batch_size, 1) have to be squeezed, or the opposite. I understand that for reward and done one way might be more comfortable, but I find it confusing that some keys and specs are squeezed and others are not.

Collaborator

Yep, it's been on my radar for quite some time. It's a conflicting thing:
Before, tensordict was designed such that you could not store tensors with a shape identical to the tensordict's (you had to unsqueeze, or it was done for you).
That was surprising to many users, so we dropped it.
But as mentioned above, reward and done interact so much with neural nets that it's a pain to work with batched data that does not end with a singleton dimension.
Maybe unsqueezing in the env is not the solution, though; I don't really know. We could delegate that to the value functions and such. My main worry is that it may lead to silent errors, e.g.

done = torch.zeros(10)
next_value = torch.ones(10, 1)
value = done * next_value  # silently gives the wrong 10 x 10 tensor

It's not only a problem for us but also for the users. That's why I'm not sure we're doing them a favour by removing the unsqueeze.

But I'm more than happy to keep the conversation going.

@tcbegley was involved in some of this refactoring; he may have his 2 cents to share.

@vmoens vmoens left a comment

Great work!
I love that the library is becoming more and more compatible with multi-agent settings!

@vmoens vmoens merged commit b845cf2 into pytorch:main Jan 11, 2023
@matteobettini matteobettini deleted the reset_flag branch January 11, 2023 09:16