Description
I found that the `generate_attention_mask` function in `block_transformer.py` seems to calculate the attention mask incorrectly. Here is an example:
According to the attention mask diagram printed above, token 656 (in `t=1 obs_wrist`) should NOT attend to token 657 (in `t=1 readout_action`). However, `attention_mask[656, 657]` is True. You can reproduce this with jdb (the JAX debugger, entered via `jax.debug.breakpoint()`). It seems that the `get_token_metadata` function does not correctly determine which group a token belongs to.
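For reference, the expected group of each token can be computed directly from the group sizes. The sketch below uses a hypothetical helper, `expected_group` (not part of Octo), and assumes tokens are laid out as all prefix groups first, followed by each timestep's groups in order. The token counts (16 prefix tokens; 256, 64, and 1 tokens per timestep group; horizon 2) come from the jdb session in this issue. The group names `task` and `obs_primary` are assumed placeholders, and timestep -1 for prefix tokens is just this sketch's convention.

```python
# Token counts taken from the jdb session in this issue.
# ASSUMPTION: the names "task" and "obs_primary" are placeholders.
PREFIX_GROUPS = [("task", 16)]
TIMESTEP_GROUPS = [("obs_primary", 256), ("obs_wrist", 64), ("readout_action", 1)]
HORIZON = 2


def expected_group(i: int) -> tuple[str, int]:
    """Return (group name, timestep) for flat token index i (timestep -1 = prefix)."""
    tokens_for_prefix = sum(n for _, n in PREFIX_GROUPS)
    tokens_per_time_step = sum(n for _, n in TIMESTEP_GROUPS)
    total_tokens = tokens_for_prefix + HORIZON * tokens_per_time_step
    if not 0 <= i < total_tokens:
        raise IndexError(f"token index {i} out of range [0, {total_tokens})")

    if i < tokens_for_prefix:
        groups, timestep, offset = PREFIX_GROUPS, -1, i
    else:
        # Split the remaining index into (timestep, offset within that timestep).
        timestep, offset = divmod(i - tokens_for_prefix, tokens_per_time_step)
        groups = TIMESTEP_GROUPS

    # A token belongs to the first group whose exclusive end is strictly past it.
    start = 0
    for name, count in groups:
        if offset < start + count:
            return name, timestep
        start += count
    raise AssertionError("unreachable: offset is always inside some group")


print(expected_group(656))  # ('obs_wrist', 1)
print(expected_group(657))  # ('readout_action', 1)
```

With this layout, token 656 is the last `obs_wrist` token of t=1 and token 657 is the only `readout_action` token of t=1, matching the annotations above.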
```
(jdb) l
> /home/yushijinhun/octo/octo/octo/model/components/block_transformer.py(325)
mask = int(metadata_i.should_attend_to(metadata_j))
attention_mask[i, j] = mask
pad_attention_mask = self.generate_pad_attention_mask(
prefix_groups, timestep_groups
)
-> jax.debug.breakpoint()
attention_mask = jnp.logical_and(attention_mask, pad_attention_mask)
return attention_mask
(jdb) bt
Traceback:
File "/home/yushijinhun/octo/octo-experiment/test.py", line 11
actions = model.sample_actions(
File "/home/yushijinhun/octo/octo/octo/model/octo_model.py", line 187
transformer_outputs = self.run_transformer(
File "/home/yushijinhun/octo/octo/octo/model/octo_model.py", line 152
return self.module.apply(
File "/home/yushijinhun/octo/octo/octo/model/octo_module.py", line 249
prefix_outputs, timestep_outputs = BlockTransformer(self.transformer_kwargs)(
File "/home/yushijinhun/octo/octo/octo/model/components/block_transformer.py", line 172
attention_mask = self.generate_attention_mask(prefix_groups, timestep_groups)
File "/home/yushijinhun/octo/octo/octo/model/components/block_transformer.py", line 325
jax.debug.breakpoint()
(jdb) tokens_per_prefix_group
[16]
(jdb) tokens_per_timestep_group
[256, 64, 1]
(jdb) horizon
2
(jdb) tokens_for_prefix
16
(jdb) tokens_per_time_step
321
(jdb) total_tokens
658
(jdb) get_token_metadata(657) #### <--- Token 657 should belong to group "t=1 readout_action", NOT "t=1 obs_wrist"
TokenMetadata(name='obs_wrist', timestep=1, attention_rules={'task_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>, 'obs_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>})
(jdb) get_token_metadata(656)
TokenMetadata(name='obs_wrist', timestep=1, attention_rules={'task_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>, 'obs_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>})
(jdb) attention_mask[657, 656]
1
(jdb) attention_mask[656, 657] #### <--- This should be FALSE, group "t=1 obs_wrist" should NOT attend to group "t=1 readout_action"
1
```
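The mislabeling looks like an off-by-one at group boundaries. As a hypothesis only (I have not confirmed that `get_token_metadata` is implemented this way), suppose the group position is found with `np.searchsorted` over the cumulative token counts. With the default `side="left"`, an offset equal to a boundary is assigned to the preceding group. For token 657 the offset within its timestep is 657 - 16 - 321 = 320, and `np.searchsorted([256, 320, 321], 320)` returns 1 (`obs_wrist`) instead of 2 (`readout_action`). The sketch below compares both `side` values; `lookup_position` and `timestep_and_offset` are hypothetical helpers.

```python
import numpy as np

# Per-timestep group sizes and prefix length from the jdb session above.
TOKENS_PER_TIMESTEP_GROUP = [256, 64, 1]
TOKENS_FOR_PREFIX = 16
TOKENS_PER_TIME_STEP = sum(TOKENS_PER_TIMESTEP_GROUP)  # 321


def lookup_position(i: int, tokens_per_elem, side: str) -> int:
    """Hypothetical helper: index of the group containing offset i."""
    boundaries = np.cumsum(tokens_per_elem)  # exclusive group ends: [256, 320, 321]
    return int(np.searchsorted(boundaries, i, side=side))


def timestep_and_offset(token: int) -> tuple[int, int]:
    """Split a non-prefix token index into (timestep, offset within that timestep)."""
    return divmod(token - TOKENS_FOR_PREFIX, TOKENS_PER_TIME_STEP)


for token in (656, 657):
    t, offset = timestep_and_offset(token)
    left = lookup_position(offset, TOKENS_PER_TIMESTEP_GROUP, side="left")
    right = lookup_position(offset, TOKENS_PER_TIMESTEP_GROUP, side="right")
    print(token, t, offset, left, right)
# 656 1 319 1 1   -> both give obs_wrist
# 657 1 320 1 2   -> "left" gives obs_wrist (wrong), "right" gives readout_action
```

If this hypothesis holds, `side="left"` mislabels the first token of every group after the first one, which for a single-token group like `readout_action` means the whole group in every timestep. Passing `side="right"` assigns every offset in `[0, 321)` to the correct group.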