Skip to content

Commit ec4f665

Browse files
pcuencaPrathik Rao
authored andcommitted
Flax: add shape argument to set_timesteps (huggingface#690)
* Flax: add shape argument to set_timesteps * style
1 parent a75419a commit ec4f665

File tree

6 files changed

+10
-6
lines changed

6 files changed

+10
-6
lines changed

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
156156

157157
return variance
158158

159-
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
159+
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
160160
"""
161161
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
162162

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def __init__(
133133

134134
self.variance_type = variance_type
135135

136-
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState:
136+
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
137137
"""
138138
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
139139

src/diffusers/schedulers/scheduling_karras_ve_flax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def __init__(
9999
):
100100
self.state = KarrasVeSchedulerState.create()
101101

102-
def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState:
102+
def set_timesteps(
103+
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
104+
) -> KarrasVeSchedulerState:
103105
"""
104106
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
105107

src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def lms_derivative(tau):
111111

112112
return integrated_coeff
113113

114-
def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState:
114+
def set_timesteps(
115+
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
116+
) -> LMSDiscreteSchedulerState:
115117
"""
116118
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
117119

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def __init__(
156156
def create_state(self):
157157
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
158158

159-
def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
159+
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
160160
"""
161161
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
162162

src/diffusers/schedulers/scheduling_sde_ve_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(
9595
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
9696

9797
def set_timesteps(
98-
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None
98+
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
9999
) -> ScoreSdeVeSchedulerState:
100100
"""
101101
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

0 commit comments

Comments
 (0)