File tree Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change 55
66from diffusers .models .resnet import Downsample1D , ResidualTemporalBlock , Upsample1D
77
8- from ..configuration_utils import ConfigMixin
8+ from ..configuration_utils import ConfigMixin , register_to_config
99from ..modeling_utils import ModelMixin
1010from .embeddings import get_timestep_embedding
1111
@@ -57,6 +57,7 @@ def forward(self, x):
5757
5858
5959class TemporalUNet (ModelMixin , ConfigMixin ): # (nn.Module):
60+ @register_to_config
6061 def __init__ (
6162 self ,
6263 training_horizon = 128 ,
Original file line number Diff line number Diff line change @@ -629,6 +629,12 @@ def input_shape(self):
629629 def output_shape (self ):
630630 return (4 , 16 , 14 )
631631
632+ def test_ema_training (self ):
633+ pass
634+
635+ def test_training (self ):
636+ pass
637+
632638 def prepare_init_args_and_inputs_for_common (self ):
633639 init_dict = {
634640 "training_horizon" : 128 ,
You can’t perform that action at this time.
0 commit comments