# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py

import torch
import torch.nn as nn

from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return get_timestep_embedding(x, self.dim)

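# Usage sketch (illustrative, not part of the original module): each integer
# timestep is mapped to a `dim`-sized sinusoidal embedding vector, e.g.
#     SinusoidalPosEmb(32)(torch.tensor([0, 10, 100])).shape  # -> torch.Size([3, 32])
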
class RearrangeDim(nn.Module):
    """Inserts or removes a singleton dimension so `nn.GroupNorm` can be applied
    between 1D convolutions (see `Conv1dBlock` below)."""

    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        elif len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")


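# Shape behavior sketch (illustrative): 3D in -> 4D out, and 4D in -> 3D out, e.g.
#     RearrangeDim()(torch.randn(2, 8, 16)).shape     # -> torch.Size([2, 8, 1, 16])
#     RearrangeDim()(torch.randn(2, 8, 1, 16)).shape  # -> torch.Size([2, 8, 16])

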
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)

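# Usage sketch (illustrative): the `kernel_size // 2` padding preserves the
# horizon length for odd kernel sizes, e.g.
#     Conv1dBlock(14, 32, kernel_size=5)(torch.randn(4, 14, 128)).shape  # -> torch.Size([4, 32, 128])
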
class TemporalUNet(ModelMixin, ConfigMixin):
    """A 1D U-Net over the horizon dimension of trajectory tensors
    `[batch, horizon, transition]`, conditioned on the diffusion timestep."""

    def __init__(
        self,
        training_horizon=128,
        transition_dim=14,
        cond_dim=3,
        predict_epsilon=False,
        clip_denoised=True,
        dim=32,
        dim_mults=(1, 4, 8),
    ):
        super().__init__()

        self.transition_dim = transition_dim
        self.cond_dim = cond_dim
        self.predict_epsilon = predict_epsilon
        self.clip_denoised = clip_denoised

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, sample, timesteps):
        """
        sample : [ batch x horizon x transition ]
        """
        x = sample

        x = x.permute(0, 2, 1)

        t = self.time_mlp(timesteps)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = x.permute(0, 2, 1)
        return x


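# Minimal usage sketch (illustrative; assumes the relative imports above resolve
# inside the `diffusers` package): a forward pass returns a tensor with the same
# shape as the input trajectories.
#
#     model = TemporalUNet(training_horizon=128, transition_dim=14)
#     trajectories = torch.randn(4, 128, 14)      # [batch, horizon, transition]
#     timesteps = torch.randint(0, 1000, (4,))    # one diffusion step per sample
#     out = model(trajectories, timesteps)        # -> torch.Size([4, 128, 14])

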
class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.blocks = nn.ModuleList([])

        for dim_in, dim_out in in_out:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        Downsample1D(dim_out, use_conv=True),
                    ]
                )
            )

            horizon = horizon // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        """
        x : [ batch x horizon x transition ]
        """
        x = x.permute(0, 2, 1)

        # `cond` and `*args` are accepted but currently unused
        t = self.time_mlp(time)

        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        x = x.view(len(x), -1)
        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
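

# Usage sketch (illustrative): the value head downsamples the trajectory,
# flattens it, concatenates the time embedding, and regresses one scalar per sample.
#
#     value_fn = TemporalValue(horizon=32, transition_dim=14, cond_dim=3)
#     x = torch.randn(4, 32, 14)                  # [batch, horizon, transition]
#     t = torch.randint(0, 1000, (4,))
#     value_fn(x, cond=None, time=t).shape        # -> torch.Size([4, 1])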