|
| 1 | +import numpy as np |
| 2 | +import torch |
| 3 | + |
| 4 | +from torchaudio_unittest import common_utils |
| 5 | +from torchaudio.prototype.transducer import RNNTLoss |
| 6 | + |
| 7 | + |
def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
    """Build the fixed RNN-T fixture with B=2, T=4, U=3, D=3.

    Args:
        dtype: NumPy floating dtype used for the logits, the reference
            costs and the reference gradients.

    Returns:
        A ``(data, ref_costs, ref_gradients)`` tuple.  ``data`` is a dict
        holding ``"logits"`` (shape ``(2, 4, 3, 3)``), ``"targets"``,
        ``"src_lengths"``, ``"tgt_lengths"`` (all ``int32``) and the
        ``"blank"`` index.  ``ref_costs`` has one entry per batch element
        and ``ref_gradients`` has the same shape as the logits.
    """
    grid_shape = (2, 4, 3, 3)

    def as_grid(flat_values):
        # Values are listed flat (three per line, i.e. one D-vector per
        # line) and then folded into (B, T, U, D).
        return np.array(flat_values, dtype=dtype).reshape(*grid_shape)

    logits = as_grid(
        [
            # batch element 0
            0.065357, 0.787530, 0.081592,
            0.529716, 0.750675, 0.754135,
            0.609764, 0.868140, 0.622532,
            0.668522, 0.858039, 0.164539,
            0.989780, 0.944298, 0.603168,
            0.946783, 0.666203, 0.286882,
            0.094184, 0.366674, 0.736168,
            0.166680, 0.714154, 0.399400,
            0.535982, 0.291821, 0.612642,
            0.324241, 0.800764, 0.524106,
            0.779195, 0.183314, 0.113745,
            0.240222, 0.339470, 0.134160,
            # batch element 1
            0.505562, 0.051597, 0.640290,
            0.430733, 0.829473, 0.177467,
            0.320700, 0.042883, 0.302803,
            0.675178, 0.569537, 0.558474,
            0.083132, 0.060165, 0.107958,
            0.748615, 0.943918, 0.486356,
            0.418199, 0.652408, 0.024243,
            0.134582, 0.366342, 0.295830,
            0.923670, 0.689929, 0.741898,
            0.250005, 0.603430, 0.987289,
            0.592606, 0.884672, 0.543450,
            0.660770, 0.377128, 0.358021,
        ]
    )

    # Reference values the loss under test is compared against.
    # NOTE(review): their provenance is not visible here - presumably
    # produced by a reference transducer implementation; confirm.
    ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype)

    ref_gradients = as_grid(
        [
            # batch element 0
            -0.186844, -0.062555, 0.249399,
            -0.203377, 0.202399, 0.000977,
            -0.141016, 0.079123, 0.061893,
            -0.011552, -0.081280, 0.092832,
            -0.154257, 0.229433, -0.075176,
            -0.246593, 0.146405, 0.100188,
            -0.012918, -0.061593, 0.074512,
            -0.055986, 0.219831, -0.163845,
            -0.497627, 0.209240, 0.288387,
            0.013605, -0.030220, 0.016615,
            0.113925, 0.062781, -0.176706,
            -0.667078, 0.367659, 0.299419,
            # batch element 1
            -0.356344, -0.055347, 0.411691,
            -0.096922, 0.029459, 0.067463,
            -0.063518, 0.027654, 0.035863,
            -0.154499, -0.073942, 0.228441,
            -0.166790, -0.000088, 0.166878,
            -0.172370, 0.105565, 0.066804,
            0.023875, -0.118256, 0.094381,
            -0.104707, -0.108934, 0.213642,
            -0.369844, 0.180118, 0.189726,
            0.025714, -0.079462, 0.053748,
            0.122328, -0.238789, 0.116460,
            -0.598687, 0.302203, 0.296484,
        ]
    )

    data = {
        "logits": logits,
        "targets": np.array([[1, 2], [1, 1]], dtype=np.int32),
        "src_lengths": np.array([4, 4], dtype=np.int32),
        "tgt_lengths": np.array([2, 2], dtype=np.int32),
        "blank": 0,
    }

    return data, ref_costs, ref_gradients
| 182 | + |
| 183 | + |
def numpy_to_torch(data, device, requires_grad=True):
    """Replace the NumPy arrays in ``data`` with torch tensors on ``device``.

    The dict is updated in place and also returned.  Entries other than
    ``"logits"``, ``"targets"``, ``"src_lengths"`` and ``"tgt_lengths"``
    (for example ``"blank"``) are left untouched.

    Args:
        data: dict with NumPy arrays under the four keys listed above.
        device: target device, passed to ``Tensor.to``.
        requires_grad: when true, gradient tracking is enabled on the
            logits and a hook is installed that stores a copy of the
            incoming gradient on the tensor as ``logits.saved_grad``.

    Returns:
        The same ``data`` dict, now holding tensors.
    """
    # Move every tensor to the target device so they all live in the same
    # place, matching how test_basic_backward prepares its inputs.
    logits = torch.from_numpy(data["logits"]).to(device)
    targets = torch.from_numpy(data["targets"]).to(device)
    src_lengths = torch.from_numpy(data["src_lengths"]).to(device)
    tgt_lengths = torch.from_numpy(data["tgt_lengths"]).to(device)

    # Enable grad after the device move so the tensor handed to the loss
    # is itself the leaf that autograd tracks.
    logits.requires_grad_(requires_grad)

    # register_hook raises on a tensor that does not require grad, so the
    # hook is only installed when gradients were actually requested.
    if requires_grad:

        def grad_hook(grad):
            logits.saved_grad = grad.clone()

        logits.register_hook(grad_hook)

    data["logits"] = logits
    data["src_lengths"] = src_lengths
    data["tgt_lengths"] = tgt_lengths
    data["targets"] = targets

    return data
| 206 | + |
| 207 | + |
def compute_with_pytorch_transducer(data):
    """Run ``RNNTLoss`` forward and backward on ``data``.

    Args:
        data: dict with ``"blank"``, ``"targets"``, ``"src_lengths"``,
            ``"tgt_lengths"`` and ``"logits"``; an optional
            ``"logits_sparse"`` entry is fed to the loss instead of
            ``"logits"`` when present.

    Returns:
        ``(costs, gradients)`` as NumPy arrays: the unreduced costs and the
        gradient read from ``data["logits"].saved_grad``.
    """
    criterion = RNNTLoss(blank=data["blank"], reduction="none")

    if "logits_sparse" in data:
        activations = data["logits_sparse"]
    else:
        activations = data["logits"]

    per_sequence_costs = criterion(
        acts=activations,
        labels=data["targets"],
        act_lens=data["src_lengths"],
        label_lens=data["tgt_lengths"],
    )

    # Backpropagate the summed cost; the gradient is read back from the
    # ``saved_grad`` attribute, which is expected to have been filled in by
    # a gradient hook attached to data["logits"] before this call.
    torch.sum(per_sequence_costs).backward()

    costs = per_sequence_costs.detach().cpu().numpy()
    gradients = data["logits"].saved_grad.detach().cpu().numpy()
    return costs, gradients
| 221 | + |
| 222 | + |
class TransducerTester:
    """Device-independent checks for ``RNNTLoss``.

    Meant to be mixed into a test-case class that provides a ``device``
    attribute and ``assertEqual``.
    """

    def test_basic_backward(self):
        # Smoke test: the example from the warp-transducer README must run
        # forward and backward without raising.
        # https://github.com/HawkAaron/warp-transducer

        criterion = RNNTLoss()

        readme_acts = [
            [
                [
                    [0.1, 0.6, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.6, 0.1, 0.1],
                    [0.1, 0.1, 0.2, 0.8, 0.1],
                ],
                [
                    [0.1, 0.6, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.2, 0.1, 0.1],
                    [0.7, 0.1, 0.2, 0.1, 0.1],
                ],
            ]
        ]

        acts = torch.tensor(readme_acts, dtype=torch.float32).to(self.device)
        labels = torch.tensor([[1, 2]], dtype=torch.int32).to(self.device)
        act_length = torch.tensor([2], dtype=torch.int32).to(self.device)
        label_length = torch.tensor([2], dtype=torch.int32).to(self.device)

        acts.requires_grad_(True)

        criterion(acts, labels, act_length, label_length).backward()

    def _test_costs_and_gradients(
        self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
    ):
        """Compare the computed costs and gradients against reference values.

        The gradients are first compared in bulk; only when that fails are
        they re-checked one (b, t, u) vector at a time so the failure
        message pinpoints the offending position.
        """
        expected_shape = data["logits"].shape
        costs, gradients = compute_with_pytorch_transducer(data=data)

        np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
        self.assertEqual(expected_shape, gradients.shape)

        if np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
            return

        for b, t, u in np.ndindex(*gradients.shape[:3]):
            T = data["src_lengths"][b]
            U = data["tgt_lengths"][b]
            np.testing.assert_allclose(
                gradients[b, t, u],
                ref_gradients[b, t, u],
                atol=atol,
                rtol=rtol,
                err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
            )

    def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
        # float32 run of the fixed B=2, T=4, U=3, D=3 fixture.
        fixture, expected_costs, expected_grads = get_numpy_data_B2_T4_U3_D3(
            dtype=np.float32
        )
        fixture = numpy_to_torch(
            data=fixture, device=self.device, requires_grad=True
        )
        self._test_costs_and_gradients(
            data=fixture, ref_costs=expected_costs, ref_gradients=expected_grads
        )
| 287 | + |
| 288 | + |
@common_utils.skipIfNoTransducer
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
    """Runs the ``TransducerTester`` checks with all tensors on the CPU."""

    # Device string read by the inherited tests via ``self.device``.
    device = "cpu"
0 commit comments