Skip to content

Commit a45dca0

Browse files
authored
Fix BaseOutput initialization from dict (#570)
* Fix BaseOutput initialization from dict * style * Simplify post-init, add tests * remove debug
1 parent c01ec2d commit a45dca0

File tree

2 files changed

+71
-4
lines changed

2 files changed

+71
-4
lines changed

src/diffusers/utils/outputs.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,17 @@ def __post_init__(self):
5959
if not len(class_fields):
6060
raise ValueError(f"{self.__class__.__name__} has no fields.")
6161

62-
for field in class_fields:
63-
v = getattr(self, field.name)
64-
if v is not None:
65-
self[field.name] = v
62+
first_field = getattr(self, class_fields[0].name)
63+
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
64+
65+
if other_fields_are_none and isinstance(first_field, dict):
66+
for key, value in first_field.items():
67+
self[key] = value
68+
else:
69+
for field in class_fields:
70+
v = getattr(self, field.name)
71+
if v is not None:
72+
self[field.name] = v
6673

6774
def __delitem__(self, *args, **kwargs):
6875
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

tests/test_outputs.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import unittest
2+
from dataclasses import dataclass
3+
from typing import List, Union
4+
5+
import numpy as np
6+
7+
import PIL.Image
8+
from diffusers.utils.outputs import BaseOutput
9+
10+
11+
@dataclass
12+
class CustomOutput(BaseOutput):
13+
images: Union[List[PIL.Image.Image], np.ndarray]
14+
15+
16+
class ConfigTester(unittest.TestCase):
17+
def test_outputs_single_attribute(self):
18+
outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))
19+
20+
# check every way of getting the attribute
21+
assert isinstance(outputs.images, np.ndarray)
22+
assert outputs.images.shape == (1, 3, 4, 4)
23+
assert isinstance(outputs["images"], np.ndarray)
24+
assert outputs["images"].shape == (1, 3, 4, 4)
25+
assert isinstance(outputs[0], np.ndarray)
26+
assert outputs[0].shape == (1, 3, 4, 4)
27+
28+
# test with a non-tensor attribute
29+
outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
30+
31+
# check every way of getting the attribute
32+
assert isinstance(outputs.images, list)
33+
assert isinstance(outputs.images[0], PIL.Image.Image)
34+
assert isinstance(outputs["images"], list)
35+
assert isinstance(outputs["images"][0], PIL.Image.Image)
36+
assert isinstance(outputs[0], list)
37+
assert isinstance(outputs[0][0], PIL.Image.Image)
38+
39+
def test_outputs_dict_init(self):
40+
# test output reinitialization with a `dict` for compatibility with `accelerate`
41+
outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})
42+
43+
# check every way of getting the attribute
44+
assert isinstance(outputs.images, np.ndarray)
45+
assert outputs.images.shape == (1, 3, 4, 4)
46+
assert isinstance(outputs["images"], np.ndarray)
47+
assert outputs["images"].shape == (1, 3, 4, 4)
48+
assert isinstance(outputs[0], np.ndarray)
49+
assert outputs[0].shape == (1, 3, 4, 4)
50+
51+
# test with a non-tensor attribute
52+
outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
53+
54+
# check every way of getting the attribute
55+
assert isinstance(outputs.images, list)
56+
assert isinstance(outputs.images[0], PIL.Image.Image)
57+
assert isinstance(outputs["images"], list)
58+
assert isinstance(outputs["images"][0], PIL.Image.Image)
59+
assert isinstance(outputs[0], list)
60+
assert isinstance(outputs[0][0], PIL.Image.Image)

0 commit comments

Comments
 (0)