|
22 | 22 | ] |
23 | 23 |
|
24 | 24 |
|
def _yuv420_to_444(mat):
    """Expand a planar YUV 4:2:0 frame into an interleaved YUV 4:4:4 frame.

    Logic taken from
    https://en.wikipedia.org/wiki/YUV#Y%E2%80%B2UV420p_(and_Y%E2%80%B2V12_or_YV12)_to_RGB888_conversion

    Args:
        mat: 2-D tensor of shape ``(height * 3 // 2, width)``. The first
            ``height`` rows are the luma plane; the remaining rows hold the
            two quarter-resolution chroma planes stored one after the other.

    Returns:
        Tensor of shape ``(height, width, 3)`` with the same dtype as ``mat``,
        where the last dimension is ordered (luma, first chroma plane,
        second chroma plane).

    Raises:
        ValueError: if ``mat`` is not 2-D or its shape cannot describe a
            4:2:0 frame (row count not a multiple of 3, or odd width).
    """
    if mat.dim() != 2:
        raise ValueError(f"expected a 2-D planar YUV420 tensor, got {mat.dim()} dimensions")
    rows, width = mat.shape
    if rows % 3 != 0 or width % 2 != 0:
        raise ValueError(f"shape {tuple(mat.shape)} is not a valid planar YUV420 layout")

    height = rows * 2 // 3
    luma = mat[:height]
    chroma = mat[height:].reshape(2, height // 2, width // 2)
    # Nearest-neighbour 2x upsampling: every chroma sample covers a 2x2 block
    # of luma samples. repeat_interleave works for any dtype, unlike
    # torch.nn.functional.interpolate.
    chroma = chroma.repeat_interleave(2, dim=-2).repeat_interleave(2, dim=-1)
    return torch.cat([luma[None], chroma]).permute(1, 2, 0)
| 35 | + |
| 36 | + |
def _yuv420_to_rgb(mat, limited_color_range=True, standard='bt709'):
    """Convert a planar YUV 4:2:0 frame to floating point RGB.

    Conversion matrices taken from https://en.wikipedia.org/wiki/YCbCr

    Args:
        mat: 2-D planar YUV 4:2:0 tensor, as accepted by ``_yuv420_to_444``.
        limited_color_range: if True, the input is treated as "studio swing"
            (luma in [16, 235], chroma in [16, 240]) and is rescaled to the
            full range. If False, the input is treated as full range: no
            rescaling is applied and luma carries no offset.
        standard: ``'bt601'`` or ``'bt709'``; selects the conversion matrix.

    Returns:
        Float tensor of shape ``(height, width, 3)`` holding unclamped,
        unrounded RGB values on the same device as ``mat``.

    Raises:
        ValueError: if ``standard`` is not one of the supported values.
    """
    if standard == 'bt601':
        # ITU-R BT.601, as used by decord
        # taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
        m = torch.tensor([[1.0000, 0.0000, 1.402],
                          [1.0000, -(1.772 * 0.114 / 0.587), -(1.402 * 0.299 / 0.587)],
                          [1.0000, 1.772, 0.0000]], device=mat.device)
    elif standard == 'bt709':
        # ITU-R BT.709
        # taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion
        m = torch.tensor([[1.0000, 0.0000, 1.5748],
                          [1.0000, -0.1873, -0.4681],
                          [1.0000, 1.8556, 0.0000]], device=mat.device)
    else:
        raise ValueError(f"{standard} not supported")

    if limited_color_range:
        # Compensation for the footroom and headroom, see
        # https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
        # Broadcasting scales the columns, i.e. the Y, U and V coefficients.
        m = m * torch.tensor([255 / 219, 255 / 224, 255 / 224], device=mat.device)
        luma_offset = 16.
    else:
        # Full-range luma starts at 0, so only the chroma planes are centred.
        luma_offset = 0.
    offset = torch.tensor([luma_offset, 128., 128.], dtype=m.dtype, device=mat.device)

    # Cast explicitly so the matmul never mixes dtypes (e.g. float64 input).
    yuv = _yuv420_to_444(mat).to(m.dtype)

    # Rows of m map (Y, U, V) to one RGB channel each; pixels are row vectors,
    # so multiply by the transpose.
    return (yuv - offset) @ m.T
| 68 | + |
| 69 | + |
@pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
class TestVideoGPUDecoder:
    """Checks the GPU video decoder output against frames decoded by PyAV."""

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    def test_frame_reading(self):
        """Decode every test video with both backends and compare frame by frame.

        PyAV frames are read in their native YUV layout and converted to RGB
        with ``_yuv420_to_rgb`` so they can be compared with the decoder output.
        NOTE(review): assumes the decoder yields RGB frames shaped like the
        converted reference, i.e. (H, W, 3) -- confirm against VideoReader.
        """
        for test_video in test_videos:
            full_path = os.path.join(VIDEO_DIR, test_video)
            decoder = VideoReader(full_path, device="cuda:0")
            with av.open(full_path) as container:
                for av_frame in container.decode(container.streams.video[0]):
                    av_frames_yuv = torch.tensor(av_frame.to_ndarray())
                    vision_frames = next(decoder)["data"]

                    av_frames_rgb = _yuv420_to_rgb(av_frames_yuv)
                    diff = torch.abs(av_frames_rgb.float() - vision_frames.cpu().float())
                    # Compare on the 70th percentile of the absolute error
                    # rather than the mean or the maximum.
                    delta = torch.kthvalue(diff.flatten(), int(diff.numel() * 0.7)).values
                    assert delta < 16, (
                        f"{test_video}: 70th percentile abs diff {delta.item():.2f}, "
                        f"median {diff.median().item():.2f}, max {diff.max().item():.2f}"
                    )
38 | 139 |
|
39 | 140 |
|
40 | 141 | if __name__ == "__main__": |
|
0 commit comments