File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -698,11 +698,13 @@ def _test_xformers_attention_forwardGenerator_pass(
698698 pipe .set_progress_bar_config (disable = None )
699699
700700 inputs = self .get_dummy_inputs (torch_device )
701- output_without_offload = pipe (** inputs )[0 ].cpu ()
701+ output_without_offload = pipe (** inputs )[0 ]
702+ output_without_offload .cpu () if torch .is_tensor (output_without_offload ) else output_without_offload
702703
703704 pipe .enable_xformers_memory_efficient_attention ()
704705 inputs = self .get_dummy_inputs (torch_device )
705- output_with_offload = pipe (** inputs )[0 ].cpu ()
706+ output_with_offload = pipe (** inputs )[0 ]
707+ output_with_offload .cpu () if torch .is_tensor (output_with_offload ) else output_without_offload
706708
707709 if test_max_difference :
708710 max_diff = np .abs (output_with_offload - output_without_offload ).max ()
You can’t perform that action at this time.
0 commit comments