Skip to content

Commit 3c91e7c

Browse files
committed
Merge branch 'refs/heads/semantic_flux' into semantic_dits
# Conflicts: # src/diffusers/__init__.py
2 parents 09eed8f + 0deed2e commit 3c91e7c

File tree

5 files changed

+1061
-0
lines changed

5 files changed

+1061
-0
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@
300300
"PixArtAlphaPipeline",
301301
"PixArtSigmaPAGPipeline",
302302
"PixArtSigmaPipeline",
303+
"SemanticFluxPipeline",
303304
"SemanticHunyuanDiTPipeline",
304305
"SemanticStableDiffusionPipeline",
305306
"ShapEImg2ImgPipeline",
@@ -740,6 +741,7 @@
740741
PixArtAlphaPipeline,
741742
PixArtSigmaPAGPipeline,
742743
PixArtSigmaPipeline,
744+
SemanticFluxPipeline,
743745
SemanticHunyuanDiTPipeline,
744746
SemanticStableDiffusionPipeline,
745747
ShapEImg2ImgPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@
234234
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
235235
_import_structure["semantic_hunyuandit"] = ["SemanticHunyuanDiTPipeline"]
236236
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
237+
_import_structure["semantic_flux"] = ["SemanticFluxPipeline"]
237238
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
238239
_import_structure["stable_audio"] = [
239240
"StableAudioProjectionModel",
@@ -559,6 +560,7 @@
559560
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
560561
from .semantic_hunyuandit import SemanticHunyuanDiTPipeline
561562
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
563+
from .semantic_flux import SemanticFluxPipeline
562564
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
563565
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
564566
from .stable_cascade import (
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import TYPE_CHECKING
2+
3+
from ...utils import (
4+
DIFFUSERS_SLOW_IMPORT,
5+
OptionalDependencyNotAvailable,
6+
_LazyModule,
7+
get_objects_from_module,
8+
is_torch_available,
9+
is_transformers_available,
10+
)
11+
12+
13+
_dummy_objects = {}
14+
_import_structure = {}
15+
16+
try:
17+
if not (is_transformers_available() and is_torch_available()):
18+
raise OptionalDependencyNotAvailable()
19+
except OptionalDependencyNotAvailable:
20+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
21+
22+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
23+
else:
24+
_import_structure["pipeline_output"] = ["SemanticFluxPipelineOutput"]
25+
_import_structure["pipeline_semantic_flux"] = ["SemanticFluxPipeline"]
26+
27+
28+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
29+
try:
30+
if not (is_transformers_available() and is_torch_available()):
31+
raise OptionalDependencyNotAvailable()
32+
33+
except OptionalDependencyNotAvailable:
34+
from ...utils.dummy_torch_and_transformers_objects import *
35+
else:
36+
from .pipeline_semantic_flux import SemanticFluxPipeline
37+
38+
else:
39+
import sys
40+
41+
sys.modules[__name__] = _LazyModule(
42+
__name__,
43+
globals()["__file__"],
44+
_import_structure,
45+
module_spec=__spec__,
46+
)
47+
48+
for name, value in _dummy_objects.items():
49+
setattr(sys.modules[__name__], name, value)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from dataclasses import dataclass
2+
from typing import List, Optional, Union
3+
4+
import numpy as np
5+
import PIL.Image
6+
7+
from ...utils import BaseOutput
8+
9+
10+
@dataclass
11+
class SemanticFluxPipelineOutput(BaseOutput):
12+
"""
13+
Output class for Flux pipelines.
14+
15+
Args:
16+
images (`List[PIL.Image.Image]` or `np.ndarray`)
17+
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
18+
num_channels)`.
19+
nsfw_content_detected (`List[bool]`)
20+
List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or
21+
`None` if safety checking could not be performed.
22+
"""
23+
24+
images: Union[List[PIL.Image.Image], np.ndarray]
25+
nsfw_content_detected: Optional[List[bool]]

0 commit comments

Comments
 (0)