Skip to content

Commit 1b6b68c

Browse files
[Dance Diffusion] Better naming (huggingface#981)
uP
1 parent 2db92f8 commit 1b6b68c

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __call__(
4747
batch_size: int = 1,
4848
num_inference_steps: int = 100,
4949
generator: Optional[torch.Generator] = None,
50-
sample_length_in_s: Optional[float] = None,
50+
audio_length_in_s: Optional[float] = None,
5151
return_dict: bool = True,
5252
) -> Union[AudioPipelineOutput, Tuple]:
5353
r"""
@@ -60,6 +60,9 @@ def __call__(
6060
generator (`torch.Generator`, *optional*):
6161
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
6262
deterministic.
63+
audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
64+
The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
65+
`sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
6366
return_dict (`bool`, *optional*, defaults to `True`):
6467
Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple.
6568
@@ -69,23 +72,23 @@ def __call__(
6972
generated images.
7073
"""
7174

72-
if sample_length_in_s is None:
73-
sample_length_in_s = self.unet.sample_size / self.unet.sample_rate
75+
if audio_length_in_s is None:
76+
audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate
7477

75-
sample_size = sample_length_in_s * self.unet.sample_rate
78+
sample_size = audio_length_in_s * self.unet.sample_rate
7679

7780
down_scale_factor = 2 ** len(self.unet.up_blocks)
7881
if sample_size < 3 * down_scale_factor:
7982
raise ValueError(
80-
f"{sample_length_in_s} is too small. Make sure it's bigger or equal to"
83+
f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
8184
f" {3 * down_scale_factor / self.unet.sample_rate}."
8285
)
8386

8487
original_sample_size = int(sample_size)
8588
if sample_size % down_scale_factor != 0:
86-
sample_size = ((sample_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
89+
sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
8790
logger.info(
88-
f"{sample_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
91+
f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
8992
f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
9093
" process."
9194
)

0 commit comments

Comments
 (0)