Skip to content

Commit 0deed2e

Browse files
committed
implement semantic flux
1 parent 98930ee commit 0deed2e

File tree

5 files changed

+1061
-0
lines changed

5 files changed

+1061
-0
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@
300300
"PixArtAlphaPipeline",
301301
"PixArtSigmaPAGPipeline",
302302
"PixArtSigmaPipeline",
303+
"SemanticFluxPipeline",
303304
"SemanticStableDiffusionPipeline",
304305
"ShapEImg2ImgPipeline",
305306
"ShapEPipeline",
@@ -739,6 +740,7 @@
739740
PixArtAlphaPipeline,
740741
PixArtSigmaPAGPipeline,
741742
PixArtSigmaPipeline,
743+
SemanticFluxPipeline,
742744
SemanticStableDiffusionPipeline,
743745
ShapEImg2ImgPipeline,
744746
ShapEPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@
233233
_import_structure["pia"] = ["PIAPipeline"]
234234
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
235235
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
236+
_import_structure["semantic_flux"] = ["SemanticFluxPipeline"]
236237
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
237238
_import_structure["stable_audio"] = [
238239
"StableAudioProjectionModel",
@@ -557,6 +558,7 @@
557558
from .pia import PIAPipeline
558559
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
559560
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
561+
from .semantic_flux import SemanticFluxPipeline
560562
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
561563
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
562564
from .stable_cascade import (
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import TYPE_CHECKING
2+
3+
from ...utils import (
4+
DIFFUSERS_SLOW_IMPORT,
5+
OptionalDependencyNotAvailable,
6+
_LazyModule,
7+
get_objects_from_module,
8+
is_torch_available,
9+
is_transformers_available,
10+
)
11+
12+
13+
_dummy_objects = {}
14+
_import_structure = {}
15+
16+
try:
17+
if not (is_transformers_available() and is_torch_available()):
18+
raise OptionalDependencyNotAvailable()
19+
except OptionalDependencyNotAvailable:
20+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
21+
22+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
23+
else:
24+
_import_structure["pipeline_output"] = ["SemanticFluxPipelineOutput"]
25+
_import_structure["pipeline_semantic_flux"] = ["SemanticFluxPipeline"]
26+
27+
28+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
29+
try:
30+
if not (is_transformers_available() and is_torch_available()):
31+
raise OptionalDependencyNotAvailable()
32+
33+
except OptionalDependencyNotAvailable:
34+
from ...utils.dummy_torch_and_transformers_objects import *
35+
else:
36+
from .pipeline_semantic_flux import SemanticFluxPipeline
37+
38+
else:
39+
import sys
40+
41+
sys.modules[__name__] = _LazyModule(
42+
__name__,
43+
globals()["__file__"],
44+
_import_structure,
45+
module_spec=__spec__,
46+
)
47+
48+
for name, value in _dummy_objects.items():
49+
setattr(sys.modules[__name__], name, value)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from dataclasses import dataclass
2+
from typing import List, Optional, Union
3+
4+
import numpy as np
5+
import PIL.Image
6+
7+
from ...utils import BaseOutput
8+
9+
10+
@dataclass
11+
class SemanticFluxPipelineOutput(BaseOutput):
12+
"""
13+
Output class for Flux pipelines.
14+
15+
Args:
16+
images (`List[PIL.Image.Image]` or `np.ndarray`)
17+
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
18+
num_channels)`.
19+
nsfw_content_detected (`List[bool]`)
20+
List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or
21+
`None` if safety checking could not be performed.
22+
"""
23+
24+
images: Union[List[PIL.Image.Image], np.ndarray]
25+
nsfw_content_detected: Optional[List[bool]]

0 commit comments

Comments
 (0)