File tree Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change 55
66from diffusers .models .resnet import Downsample1D , ResidualTemporalBlock , Upsample1D
77
8- from ..configuration_utils import ConfigMixin
8+ from ..configuration_utils import ConfigMixin , register_to_config
99from ..modeling_utils import ModelMixin
1010from .embeddings import get_timestep_embedding
1111
@@ -57,6 +57,7 @@ def forward(self, x):
5757
5858
5959class TemporalUNet (ModelMixin , ConfigMixin ): # (nn.Module):
60+ @register_to_config
6061 def __init__ (
6162 self ,
6263 training_horizon = 128 ,
Original file line number Diff line number Diff line change @@ -595,6 +595,12 @@ def input_shape(self):
595595 def output_shape (self ):
596596 return (4 , 16 , 14 )
597597
598+ def test_ema_training (self ):
599+ pass
600+
601+ def test_training (self ):
602+ pass
603+
598604 def prepare_init_args_and_inputs_for_common (self ):
599605 init_dict = {
600606 "training_horizon" : 128 ,
You can’t perform that action at this time.
0 commit comments