
Conversation

@patil-suraj (Contributor) commented Dec 3, 2022

Inspired by #1493, this PR makes the set_attention_slice method recursive so that it can easily be applied to the various attention blocks without much boilerplate.

  • Make set_attention_slice a method of DiffusionPipeline, so it can be applied to all pipelines, since all of them have attention.
  • Handle the auto slice logic inside UNet2DConditionModel and recurse there to apply the slice to all blocks.

The logic differs a bit from #1493 in that we keep the set_attention_slice method on main model classes like UNet2DConditionModel and move the recursive application logic there, so that the functionality can be controlled through the individual model classes.
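To make the recursion concrete, here is a minimal sketch of the pattern (illustrative only: set_attention_slice_recursive and the set_slice setter are hypothetical names, and the real UNet2DConditionModel additionally validates slice_size against its sliceable head dimensions):

import torch.nn as nn

def set_attention_slice_recursive(module: nn.Module, slice_size: int) -> None:
    # If this module is itself a sliceable attention block, configure it.
    # `set_slice` stands in for whatever setter the block exposes.
    if hasattr(module, "set_slice"):
        module.set_slice(slice_size)
    # Then walk all children, so a single call at the top level reaches
    # every attention block without per-pipeline boilerplate.
    for child in module.children():
        set_attention_slice_recursive(child, slice_size)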

@HuggingFaceDocBuilderDev commented Dec 3, 2022

The documentation is not available anymore as the PR was closed or merged.

@patil-suraj changed the title from "[wip] make attn slice recursive" to "[refactor] make set_attention_slice recursive" on Dec 5, 2022
    set_requires_grad(self.text_encoder, False)
    set_requires_grad(self.clip_model, False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
Contributor:
Let's not change the community pipelines. As it stands, this would remove functionality from them.
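(For context, pipeline-level slicing is enabled as sketched below; the model id is illustrative. Per the signature quoted above, slice_size defaults to "auto", and an int slices attention in chunks of that size.)

from diffusers import StableDiffusionPipeline

# Sketch of typical usage; any Stable Diffusion checkpoint works here.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_attention_slicing("auto")  # "auto" halves each attention head dim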


-    def _set_attention_slice(self, slice_size):
+    def set_attention_slice(self, slice_size):
Contributor:
Nice! Works for me.

@patrickvonplaten (Contributor) left a comment

Thanks for doing this PR! Let's just remove the changes from the community pipelines.

@pcuenca (Member) left a comment

Very nice! Same comment as Patrick: this would currently break community pipelines, as they are loaded from main.

@patil-suraj (Contributor, Author) left a comment

Looks good to me! Just left one comment.

Comment on lines +260 to +267
if slice_size == "auto":
    # half the attention head size is usually a good trade-off between
    # speed and memory
    slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
    # make smallest slice possible
    slice_size = num_slicable_layers * [1]

Contributor (Author):

As discussed offline, this logic could be moved to the attention block itself. That way it would be easy to support this functionality in new models that use these attention blocks.
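A rough sketch of what that could look like, with the block resolving "auto"/"max" against its own head count (hypothetical class, not the merged implementation):

import torch.nn as nn

class SliceableAttention(nn.Module):
    # Hypothetical attention block that owns its slice-size resolution.
    def __init__(self, heads: int):
        super().__init__()
        self.sliceable_head_dim = heads
        self._slice_size = None

    def set_attention_slice(self, slice_size):
        if slice_size == "auto":
            # half the head count is usually a good speed/memory trade-off
            slice_size = self.sliceable_head_dim // 2
        elif slice_size == "max":
            # make the smallest slice possible
            slice_size = 1
        self._slice_size = slice_size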

Contributor:

That makes sense. As discussed offline, though, I'll leave it as is for now so that we can test that the number of passed slices equals the number of sliceable layers; otherwise we cannot really test it.
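A sketch of the check this makes possible (validate_slice_size is a hypothetical helper mirroring the model-level validation described here):

def validate_slice_size(slice_size, sliceable_head_dims):
    # When a list is passed, one slice size must be provided per sliceable
    # layer; keeping the resolution in the model lets a test assert this.
    if isinstance(slice_size, list) and len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)} slices, but the model has "
            f"{len(sliceable_head_dims)} sliceable layers."
        )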

Comment on lines +338 to +346
if slice_size == "auto":
    # half the attention head size is usually a good trade-off between
    # speed and memory
    slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
    # make smallest slice possible
    slice_size = num_slicable_layers * [1]

slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
Contributor (Author):

Same comment as above.

Comment on lines +337 to +374
def test_model_attention_slicing(self):
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    init_dict["attention_head_dim"] = (8, 16)

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    model.set_attention_slice("auto")
    with torch.no_grad():
        output = model(**inputs_dict)
    assert output is not None

    model.set_attention_slice("max")
    with torch.no_grad():
        output = model(**inputs_dict)
    assert output is not None

    model.set_attention_slice(2)
    with torch.no_grad():
        output = model(**inputs_dict)
    assert output is not None

def test_model_slicable_head_dim(self):
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    init_dict["attention_head_dim"] = (8, 16)

    model = self.model_class(**init_dict)

    def check_slicable_dim_attr(module: torch.nn.Module):
        if hasattr(module, "set_attention_slice"):
            assert isinstance(module.sliceable_head_dim, int)

        for child in module.children():
            check_slicable_dim_attr(child)

Contributor (Author):

Thanks a lot for adding the tests!

@patrickvonplaten merged commit bce65cd into main on Dec 5, 2022
@patrickvonplaten deleted the recirsive-attn-slice branch on December 5, 2022 16:31
tcapelle pushed a commit to tcapelle/diffusers that referenced this pull request Dec 12, 2022
* make attn slice recursive

* remove set_attention_slice from blocks

* fix copies

* make enable_attention_slicing base class method of DiffusionPipeline

* fix set_attention_slice

* fix set_attention_slice

* fix copies

* add tests

* up

* up

* up

* update

* up

* uP

Co-authored-by: Patrick von Platen <[email protected]>
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023