Skip to content

Commit 8c9645a

Browse files
committed
re-add RL model code
1 parent 936cd08 commit 8c9645a

File tree

5 files changed

+399
-1
lines changed

5 files changed

+399
-1
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
__version__ = "0.0.4"
88

99
from .modeling_utils import ModelMixin
10-
from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
10+
from .models import AutoencoderKL, TemporalUNet, UNetConditionalModel, UNetUnconditionalModel, VQModel
1111
from .pipeline_utils import DiffusionPipeline
1212
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
1313
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler

src/diffusers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@
1717
# limitations under the License.
1818

1919
from .unet_conditional import UNetConditionalModel
20+
from .unet_rl import TemporalUNet
2021
from .unet_unconditional import UNetUnconditionalModel
2122
from .vae import AutoencoderKL, VQModel

src/diffusers/models/resnet.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,70 @@
66
import torch.nn.functional as F
77

88

9+
class Upsample1D(nn.Module):
10+
"""
11+
An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
12+
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
13+
If 3D, then
14+
upsampling occurs in the inner-two dimensions.
15+
"""
16+
17+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
18+
super().__init__()
19+
self.channels = channels
20+
self.out_channels = out_channels or channels
21+
self.use_conv = use_conv
22+
self.use_conv_transpose = use_conv_transpose
23+
self.name = name
24+
25+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
26+
self.conv = None
27+
if use_conv_transpose:
28+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
29+
elif use_conv:
30+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
31+
32+
def forward(self, x):
33+
assert x.shape[1] == self.channels
34+
if self.use_conv_transpose:
35+
return self.conv(x)
36+
37+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
38+
39+
if self.use_conv:
40+
x = self.conv(x)
41+
42+
return x
43+
44+
45+
class Downsample1D(nn.Module):
46+
"""
47+
A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
48+
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
49+
If 3D, then
50+
downsampling occurs in the inner-two dimensions.
51+
"""
52+
53+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
54+
super().__init__()
55+
self.channels = channels
56+
self.out_channels = out_channels or channels
57+
self.use_conv = use_conv
58+
self.padding = padding
59+
stride = 2
60+
self.name = name
61+
62+
if use_conv:
63+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
64+
else:
65+
assert self.channels == self.out_channels
66+
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
67+
68+
def forward(self, x):
69+
assert x.shape[1] == self.channels
70+
return self.conv(x)
71+
72+
973
class Upsample2D(nn.Module):
1074
"""
1175
An upsampling layer with an optional convolution.
@@ -763,6 +827,39 @@ def forward(self, tensor):
763827
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
764828

765829

830+
# unet_rl.py
831+
class ResidualTemporalBlock(nn.Module):
832+
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
833+
super().__init__()
834+
835+
self.blocks = nn.ModuleList(
836+
[
837+
Conv1dBlock(inp_channels, out_channels, kernel_size),
838+
Conv1dBlock(out_channels, out_channels, kernel_size),
839+
]
840+
)
841+
842+
self.time_mlp = nn.Sequential(
843+
nn.Mish(),
844+
nn.Linear(embed_dim, out_channels),
845+
RearrangeDim(),
846+
# Rearrange("batch t -> batch t 1"),
847+
)
848+
849+
self.residual_conv = (
850+
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
851+
)
852+
853+
def forward(self, x, t):
854+
"""
855+
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
856+
out_channels x horizon ]
857+
"""
858+
out = self.blocks[0](x) + self.time_mlp(t)
859+
out = self.blocks[1](out)
860+
return out + self.residual_conv(x)
861+
862+
766863
def upsample_2d(x, k=None, factor=2, gain=1):
767864
r"""Upsample2D a batch of 2D images with the given filter.
768865

src/diffusers/models/unet_rl.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D
7+
8+
from ..configuration_utils import ConfigMixin
9+
from ..modeling_utils import ModelMixin
10+
from .embeddings import get_timestep_embedding
11+
12+
13+
class SinusoidalPosEmb(nn.Module):
14+
def __init__(self, dim):
15+
super().__init__()
16+
self.dim = dim
17+
18+
def forward(self, x):
19+
return get_timestep_embedding(x, self.dim)
20+
21+
22+
class RearrangeDim(nn.Module):
23+
def __init__(self):
24+
super().__init__()
25+
26+
def forward(self, tensor):
27+
if len(tensor.shape) == 2:
28+
return tensor[:, :, None]
29+
if len(tensor.shape) == 3:
30+
return tensor[:, :, None, :]
31+
elif len(tensor.shape) == 4:
32+
return tensor[:, :, 0, :]
33+
else:
34+
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
35+
36+
37+
class Conv1dBlock(nn.Module):
38+
"""
39+
Conv1d --> GroupNorm --> Mish
40+
"""
41+
42+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
43+
super().__init__()
44+
45+
self.block = nn.Sequential(
46+
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
47+
RearrangeDim(),
48+
# Rearrange("batch channels horizon -> batch channels 1 horizon"),
49+
nn.GroupNorm(n_groups, out_channels),
50+
RearrangeDim(),
51+
# Rearrange("batch channels 1 horizon -> batch channels horizon"),
52+
nn.Mish(),
53+
)
54+
55+
def forward(self, x):
56+
return self.block(x)
57+
58+
59+
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
60+
def __init__(
61+
self,
62+
training_horizon=128,
63+
transition_dim=14,
64+
cond_dim=3,
65+
predict_epsilon=False,
66+
clip_denoised=True,
67+
dim=32,
68+
dim_mults=(1, 4, 8),
69+
):
70+
super().__init__()
71+
72+
self.transition_dim = transition_dim
73+
self.cond_dim = cond_dim
74+
self.predict_epsilon = predict_epsilon
75+
self.clip_denoised = clip_denoised
76+
77+
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
78+
in_out = list(zip(dims[:-1], dims[1:]))
79+
80+
time_dim = dim
81+
self.time_mlp = nn.Sequential(
82+
SinusoidalPosEmb(dim),
83+
nn.Linear(dim, dim * 4),
84+
nn.Mish(),
85+
nn.Linear(dim * 4, dim),
86+
)
87+
88+
self.downs = nn.ModuleList([])
89+
self.ups = nn.ModuleList([])
90+
num_resolutions = len(in_out)
91+
92+
for ind, (dim_in, dim_out) in enumerate(in_out):
93+
is_last = ind >= (num_resolutions - 1)
94+
95+
self.downs.append(
96+
nn.ModuleList(
97+
[
98+
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
99+
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
100+
Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
101+
]
102+
)
103+
)
104+
105+
if not is_last:
106+
training_horizon = training_horizon // 2
107+
108+
mid_dim = dims[-1]
109+
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
110+
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
111+
112+
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
113+
is_last = ind >= (num_resolutions - 1)
114+
115+
self.ups.append(
116+
nn.ModuleList(
117+
[
118+
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
119+
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
120+
Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
121+
]
122+
)
123+
)
124+
125+
if not is_last:
126+
training_horizon = training_horizon * 2
127+
128+
self.final_conv = nn.Sequential(
129+
Conv1dBlock(dim, dim, kernel_size=5),
130+
nn.Conv1d(dim, transition_dim, 1),
131+
)
132+
133+
def forward(self, sample, timesteps):
134+
"""
135+
x : [ batch x horizon x transition ]
136+
"""
137+
x = sample
138+
139+
x = x.permute(0, 2, 1)
140+
141+
t = self.time_mlp(timesteps)
142+
h = []
143+
144+
for resnet, resnet2, downsample in self.downs:
145+
x = resnet(x, t)
146+
x = resnet2(x, t)
147+
h.append(x)
148+
x = downsample(x)
149+
150+
x = self.mid_block1(x, t)
151+
x = self.mid_block2(x, t)
152+
153+
for resnet, resnet2, upsample in self.ups:
154+
x = torch.cat((x, h.pop()), dim=1)
155+
x = resnet(x, t)
156+
x = resnet2(x, t)
157+
x = upsample(x)
158+
159+
x = self.final_conv(x)
160+
161+
x = x.permute(0, 2, 1)
162+
return x
163+
164+
165+
class TemporalValue(nn.Module):
166+
def __init__(
167+
self,
168+
horizon,
169+
transition_dim,
170+
cond_dim,
171+
dim=32,
172+
time_dim=None,
173+
out_dim=1,
174+
dim_mults=(1, 2, 4, 8),
175+
):
176+
super().__init__()
177+
178+
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
179+
in_out = list(zip(dims[:-1], dims[1:]))
180+
181+
time_dim = time_dim or dim
182+
self.time_mlp = nn.Sequential(
183+
SinusoidalPosEmb(dim),
184+
nn.Linear(dim, dim * 4),
185+
nn.Mish(),
186+
nn.Linear(dim * 4, dim),
187+
)
188+
189+
self.blocks = nn.ModuleList([])
190+
191+
print(in_out)
192+
for dim_in, dim_out in in_out:
193+
self.blocks.append(
194+
nn.ModuleList(
195+
[
196+
ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
197+
ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
198+
Downsample1d(dim_out),
199+
]
200+
)
201+
)
202+
203+
horizon = horizon // 2
204+
205+
fc_dim = dims[-1] * max(horizon, 1)
206+
207+
self.final_block = nn.Sequential(
208+
nn.Linear(fc_dim + time_dim, fc_dim // 2),
209+
nn.Mish(),
210+
nn.Linear(fc_dim // 2, out_dim),
211+
)
212+
213+
def forward(self, x, cond, time, *args):
214+
"""
215+
x : [ batch x horizon x transition ]
216+
"""
217+
x = x.permute(0, 2, 1)
218+
219+
t = self.time_mlp(time)
220+
221+
for resnet, resnet2, downsample in self.blocks:
222+
x = resnet(x, t)
223+
x = resnet2(x, t)
224+
x = downsample(x)
225+
226+
x = x.view(len(x), -1)
227+
out = self.final_block(torch.cat([x, t], dim=-1))
228+
return out

0 commit comments

Comments
 (0)