Incorrect attention mask computation #75

@yushijinhun

I found that the `generate_attention_mask` function in `block_transformer.py` seems to calculate the attention mask incorrectly. Here is an example:

[Image: Attention mask debug info]

According to the attention mask diagram printed above, the token at index 656 (in t=1 obs_wrist) should NOT attend to the token at index 657 (in t=1 readout_action). However, `attention_mask[656, 657]` is True. You can reproduce this using jdb. It seems that the `get_token_metadata` function does not correctly determine which group a token belongs to.
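To make the expected value of `attention_mask[656, 657]` concrete, here is a minimal sketch of how rules like the ones printed for obs_wrist in the jdb session could be evaluated. It assumes, without confirmation from the Octo source, that the rule keys are glob patterns over group names, that CAUSAL means `other.timestep <= self.timestep` as printed, and that a group matched by no pattern is masked out. `Token`, `Rule`, and `should_attend` are hypothetical names.

```python
from dataclasses import dataclass, field
from enum import Enum
from fnmatch import fnmatch


class Rule(Enum):
    # Stand-in for the CAUSAL rule printed in the session:
    # 'other.timestep <= self.timestep'
    CAUSAL = "causal"


@dataclass
class Token:
    name: str
    timestep: int
    # Assumed semantics: glob pattern over group names -> rule to apply.
    rules: dict = field(default_factory=dict)


def should_attend(me, other):
    """Hypothetical: `me` may attend to `other` only if some pattern in
    me.rules matches other.name and that pattern's rule holds.
    Groups matched by no pattern are masked out."""
    for pattern, rule in me.rules.items():
        if fnmatch(other.name, pattern) and rule is Rule.CAUSAL:
            return other.timestep <= me.timestep
    return False


# Rules printed for obs_wrist in the jdb session.
OBS_RULES = {"task_*": Rule.CAUSAL, "obs_*": Rule.CAUSAL}
wrist_t1 = Token("obs_wrist", 1, OBS_RULES)  # token 656
readout_t1 = Token("readout_action", 1)      # token 657; its own rules are not needed here

print(should_attend(wrist_t1, readout_t1))  # False
print(should_attend(wrist_t1, wrist_t1))    # True
```

Under these assumptions, no obs_wrist rule matches readout_action, so the entry should be False regardless of timestep. The mask can only come out True if token 657 is classified as something other than readout_action.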

```
(jdb) l
> /home/yushijinhun/octo/octo/octo/model/components/block_transformer.py(325)
                    mask = int(metadata_i.should_attend_to(metadata_j))
                    attention_mask[i, j] = mask

            pad_attention_mask = self.generate_pad_attention_mask(
                prefix_groups, timestep_groups
            )
->          jax.debug.breakpoint()
            attention_mask = jnp.logical_and(attention_mask, pad_attention_mask)
            return attention_mask

(jdb) bt
Traceback:
  File "/home/yushijinhun/octo/octo-experiment/test.py", line 11
    actions = model.sample_actions(
  File "/home/yushijinhun/octo/octo/octo/model/octo_model.py", line 187
    transformer_outputs = self.run_transformer(
  File "/home/yushijinhun/octo/octo/octo/model/octo_model.py", line 152
    return self.module.apply(
  File "/home/yushijinhun/octo/octo/octo/model/octo_module.py", line 249
    prefix_outputs, timestep_outputs = BlockTransformer(self.transformer_kwargs)(
  File "/home/yushijinhun/octo/octo/octo/model/components/block_transformer.py", line 172
    attention_mask = self.generate_attention_mask(prefix_groups, timestep_groups)
  File "/home/yushijinhun/octo/octo/octo/model/components/block_transformer.py", line 325
    jax.debug.breakpoint()
(jdb) tokens_per_prefix_group
[16]
(jdb) tokens_per_timestep_group
[256, 64, 1]
(jdb) horizon
2
(jdb) tokens_for_prefix
16
(jdb) tokens_per_time_step
321
(jdb) total_tokens
658
(jdb) get_token_metadata(657)    #### <--- Token 657 should belong to group "t=1 readout_action", NOT "t=1 obs_wrist"
TokenMetadata(name='obs_wrist', timestep=1, attention_rules={'task_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>, 'obs_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>})
(jdb) get_token_metadata(656)
TokenMetadata(name='obs_wrist', timestep=1, attention_rules={'task_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>, 'obs_*': <AttentionRule.CAUSAL: 'other.timestep <= self.timestep'>})
(jdb) attention_mask[657, 656]
1
(jdb) attention_mask[656, 657]    #### <--- This should be FALSE, group "t=1 obs_wrist" should NOT attend to group "t=1 readout_action"
1
```
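The sizes printed in the jdb session are enough to work out the expected grouping by hand. The sketch below does that arithmetic and contrasts two ways of searching the cumulative group sizes. It is only a hypothesis about the cause: the session does not show how `get_token_metadata` finds a token's group, `group_of` is a hypothetical helper, and "obs_primary" is an assumed name for the 256-token group. `bisect_left` and `bisect_right` behave like `np.searchsorted` with `side="left"` and `side="right"`.

```python
from bisect import bisect_left, bisect_right
from itertools import accumulate

# Layout values printed in the jdb session.
TOKENS_FOR_PREFIX = 16
TOKENS_PER_TIMESTEP_GROUP = [256, 64, 1]
TOKENS_PER_TIME_STEP = sum(TOKENS_PER_TIMESTEP_GROUP)  # 321
HORIZON = 2
TOTAL_TOKENS = TOKENS_FOR_PREFIX + HORIZON * TOKENS_PER_TIME_STEP  # 658

# "obs_wrist" and "readout_action" appear in the session;
# "obs_primary" is an assumed name for the 256-token group.
TIMESTEP_GROUP_NAMES = ["obs_primary", "obs_wrist", "readout_action"]

# Exclusive end offset of each group within one timestep: [256, 320, 321].
GROUP_ENDS = list(accumulate(TOKENS_PER_TIMESTEP_GROUP))


def group_of(index, search):
    """Hypothetical helper: map a flat token index after the prefix to
    (timestep, group name), locating the group with `search`
    (bisect_left or bisect_right) over the cumulative group sizes."""
    if not TOKENS_FOR_PREFIX <= index < TOTAL_TOKENS:
        raise ValueError(f"index {index} is not a timestep token")
    timestep, offset = divmod(index - TOKENS_FOR_PREFIX, TOKENS_PER_TIME_STEP)
    return timestep, TIMESTEP_GROUP_NAMES[search(GROUP_ENDS, offset)]


# Compare both searches on tokens around the t=1 group boundaries.
for idx in (592, 593, 656, 657):
    print(idx, group_of(idx, bisect_left), group_of(idx, bisect_right))
```

With the left-sided search, both 656 and 657 land in obs_wrist at t=1, matching the `get_token_metadata` output in the session, and the first wrist token (593) is filed under obs_primary. The right-sided search puts 657 in readout_action and 593 in obs_wrist. A left-sided search misassigns exactly the tokens whose in-timestep offset equals a cumulative group size, which would explain the symptom if the real lookup works this way.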