@@ -47,7 +47,7 @@ def __call__(
4747 batch_size : int = 1 ,
4848 num_inference_steps : int = 100 ,
4949 generator : Optional [torch .Generator ] = None ,
50- sample_length_in_s : Optional [float ] = None ,
50+ audio_length_in_s : Optional [float ] = None ,
5151 return_dict : bool = True ,
5252 ) -> Union [AudioPipelineOutput , Tuple ]:
5353 r"""
@@ -60,6 +60,9 @@ def __call__(
6060 generator (`torch.Generator`, *optional*):
6161 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
6262 deterministic.
63+ audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
64+ The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
65+ `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
6366 return_dict (`bool`, *optional*, defaults to `True`):
6467 Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple.
6568
@@ -69,23 +72,23 @@ def __call__(
6972 generated images.
7073 """
7174
72- if sample_length_in_s is None :
73- sample_length_in_s = self .unet .sample_size / self .unet .sample_rate
75+ if audio_length_in_s is None :
76+ audio_length_in_s = self .unet .config . sample_size / self .unet . config .sample_rate
7477
75- sample_size = sample_length_in_s * self .unet .sample_rate
78+ sample_size = audio_length_in_s * self .unet .sample_rate
7679
7780 down_scale_factor = 2 ** len (self .unet .up_blocks )
7881 if sample_size < 3 * down_scale_factor :
7982 raise ValueError (
80- f"{ sample_length_in_s } is too small. Make sure it's bigger or equal to"
83+ f"{ audio_length_in_s } is too small. Make sure it's bigger or equal to"
8184 f" { 3 * down_scale_factor / self .unet .sample_rate } ."
8285 )
8386
8487 original_sample_size = int (sample_size )
8588 if sample_size % down_scale_factor != 0 :
86- sample_size = ((sample_length_in_s * self .unet .sample_rate ) // down_scale_factor + 1 ) * down_scale_factor
89+ sample_size = ((audio_length_in_s * self .unet .sample_rate ) // down_scale_factor + 1 ) * down_scale_factor
8790 logger .info (
88- f"{ sample_length_in_s } is increased to { sample_size / self .unet .sample_rate } so that it can be handled"
91+ f"{ audio_length_in_s } is increased to { sample_size / self .unet .sample_rate } so that it can be handled"
8992 f" by the model. It will be cut to { original_sample_size / self .unet .sample_rate } after the denoising"
9093 " process."
9194 )
0 commit comments