Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ def get_extensions():
"z",
"pthread",
"dl",
"nppicc",
],
extra_compile_args=extra_compile_args,
)
Expand Down
107 changes: 104 additions & 3 deletions test/test_video_gpu_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,120 @@
]


def _yuv420_to_444(mat):
# logic taken from
# https://en.wikipedia.org/wiki/YUV#Y%E2%80%B2UV420p_(and_Y%E2%80%B2V12_or_YV12)_to_RGB888_conversion
width = mat.shape[-1]
height = mat.shape[0] * 2 // 3
luma = mat[:height]
uv = mat[height:].reshape(2, height // 2, width // 2)
uv2 = torch.nn.functional.interpolate(uv[None], scale_factor=2, mode='nearest')[0]
yuv2 = torch.cat([luma[None], uv2]).permute(1, 2, 0)
return yuv2


def _yuv420_to_rgb(mat, limited_color_range=True, standard='bt709'):
# taken from https://en.wikipedia.org/wiki/YCbCr
if standard == 'bt601':
# ITU-R BT.601, as used by decord
# taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
m = torch.tensor([[ 1.0000, 0.0000, 1.402],
[ 1.0000, -(1.772 * 0.114 / 0.587), -(1.402 * 0.299 / 0.587)],
[ 1.0000, 1.772, 0.0000]], device=mat.device)
elif standard == 'bt709':
# ITU-R BT.709
# taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion
m = torch.tensor([[ 1.0000, 0.0000, 1.5748],
[ 1.0000, -0.1873, -0.4681],
[ 1.0000, 1.8556, 0.0000]], device=mat.device)
else:
raise ValueError(f"{standard} not supported")

if limited_color_range:
# also present in https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
# being mentioned as compensation for the footroom and headroom
m = m * torch.tensor([255 / 219, 255 / 224, 255 / 224], device=mat.device)

m = m.T

# TODO: maybe this needs to come together with limited_color_range
offset = torch.tensor([16., 128., 128.], device=mat.device)

yuv2 = _yuv420_to_444(mat)

res = (yuv2 - offset) @ m
return res


@pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
class TestVideoGPUDecoder:
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
def test_frame_reading(self):
for test_video in test_videos:
full_path = os.path.join(VIDEO_DIR, test_video)
decoder = VideoReader(full_path, device="cuda:0")
print(test_video)
with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_ndarray().flatten())
#print(av_frame.format)
av2 = av_frame.to_rgb().to_ndarray()
#print(av2.shape)
av_frames_yuv = torch.tensor(av_frame.to_ndarray())
#av_frames = torch.tensor(av_frame.to_rgb().to_ndarray())
#av2 = torch.tensor(av_frame.to_rgb(dst_colorspace='ITU709').to_ndarray())
#av_frames = torch.tensor(av_frame.to_rgb(dst_colorspace='ITU624').to_ndarray())
vision_frames = next(decoder)["data"]
mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float()))
assert mean_delta < 0.1
if False:
if False:
rr = decoder._reformat(vision_frames)
rr = rr.reshape(av_frames.shape)
rr2 = _transform(rr)
else:
rr2 = vision_frames
print(rr2[:2, :2])
print(av2[:2, :2])
print(_transform(av_frames)[:2, :2])
print((_transform(av_frames) - rr2.cpu()).abs().max())
print((_transform(av_frames) - rr2.cpu()).abs().mean())
print((_transform(av_frames) - rr2.cpu()).abs().median())
print('----------')
print(torch.max(torch.abs(torch.tensor(av2).float() - rr2.cpu().float())))
print(torch.mean(torch.abs(torch.tensor(av2).float() - rr2.cpu().float())))
print(torch.median(torch.abs(torch.tensor(av2).float() - rr2.cpu().float())))
aa = _yuv444(av_frames).flatten(0, -2) - torch.tensor([16., 128., 128.])
bb = torch.tensor(av2).flatten(0, -2).float()
print('----------')
rrr = torch.linalg.lstsq(aa, bb)
print((bb - aa @ rrr.solution).abs().max())
print((bb - aa @ rrr.solution).abs().mean())
print((bb - aa @ rrr.solution).abs().median())

#print(rr[:3, :3], av_frames.shape)
mean_delta = torch.mean(torch.abs(av_frames.float() - rr.float()))
print(torch.max(torch.abs(av_frames.float() - rr.float())))
#mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float()))

#print((av_frames.float() - vision_frames.cpu().float()).abs().max())
#print((av_frames.float() - vision_frames.cpu().float()).abs().flatten().topk(10,largest=False).values)
#v = (av_frames.float() - vision_frames.cpu().float()).abs().flatten()
#v = torch.histogram(v, bins=v.unique())
#print(test_video, (v.hist / v.hist.sum() * 100).int())

av_frames_rgb = _yuv420_to_rgb(av_frames_yuv)
#diff = torch.abs(av_frames_rgb.floor().float() - vision_frames.cpu().float())
diff = torch.abs(av_frames_rgb.float() - vision_frames.cpu().float())
mean_delta = torch.median(diff)
mean_delta = torch.kthvalue(diff.flatten(), int(diff.numel() * 0.7)).values
if mean_delta > 16:
print((torch.abs(diff)).max())
print((torch.abs(diff)).median())
#v = torch.histogram(diff.flatten(), bins=diff.flatten().unique())
v = torch.histogram(diff.flatten(), bins=100)
print((v.hist / v.hist.sum() * 100).int())
print((v.hist / v.hist.sum() * 100).cumsum(0).int())
print((v.hist / v.hist.sum() * 100))
assert mean_delta < 16
#assert mean_delta < 5


if __name__ == "__main__":
Expand Down
44 changes: 14 additions & 30 deletions torchvision/csrc/io/decoder/gpu/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <cmath>
#include <cstring>
#include <unordered_map>
#include <nppi_color_conversion.h>


static float chroma_height_factor(cudaVideoSurfaceFormat surface_format) {
return (surface_format == cudaVideoSurfaceFormat_YUV444 ||
Expand Down Expand Up @@ -138,38 +140,20 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) {
}

auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA);
torch::Tensor decoded_frame = torch::empty({get_frame_size()}, options);
torch::Tensor decoded_frame = torch::empty({get_height(), width, 3}, options);
uint8_t* frame_ptr = decoded_frame.data_ptr<uint8_t>();

// Copy luma plane
CUDA_MEMCPY2D m = {0};
m.srcMemoryType = CU_MEMORYTYPE_DEVICE;
m.srcDevice = source_frame;
m.srcPitch = source_pitch;
m.dstMemoryType = CU_MEMORYTYPE_DEVICE;
m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr);
m.dstPitch = get_width() * bytes_per_pixel;
m.WidthInBytes = get_width() * bytes_per_pixel;
m.Height = luma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);

// Copy chroma plane
// NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning
// height
m.srcDevice =
(CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1));
m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height);
m.Height = chroma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);

if (num_chroma_planes == 2) {
m.srcDevice =
(CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1) * 2);
m.dstDevice =
(CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height * 2);
m.Height = chroma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
}
// TODO: check the surface_height condition in here
const uint8_t *const pSrc[] = {(const uint8_t *const)source_frame,
(const uint8_t *const)(source_frame + source_pitch * ((surface_height + 1) & ~1))};


// TODO: create and reuse NppStreamContext, and thus need to use nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx instead
auto err = nppiNV12ToRGB_709CSC_8u_P2C3R(pSrc, source_pitch, frame_ptr,
width * 3, {(int)decoded_frame.size(1), (int)decoded_frame.size(0)});

TORCH_CHECK(err == NPP_NO_ERROR, "Failed to convert from NV12 to RGB. Error code:", err);

check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__);
decoded_frames.push(decoded_frame);
check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);
Expand Down
42 changes: 1 addition & 41 deletions torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,48 +38,8 @@ torch::Tensor GPUDecoder::decode() {
return frame;
}

/* Convert a tensor with data in NV12 format to a tensor with data in YUV420
* format in-place.
*/
torch::Tensor GPUDecoder::nv12_to_yuv420(torch::Tensor frameTensor) {
int width = decoder.get_width(), height = decoder.get_height();
int pitch = width;
uint8_t* frame = frameTensor.data_ptr<uint8_t>();
uint8_t* ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)];

// sizes of source surface plane
int sizePlaneY = pitch * height;
int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2);
int sizePlaneV = sizePlaneU;

uint8_t* uv = frame + sizePlaneY;
uint8_t* u = uv;
uint8_t* v = uv + sizePlaneU;

// split chroma from interleave to planar
for (int y = 0; y < (height + 1) / 2; y++) {
for (int x = 0; x < (width + 1) / 2; x++) {
u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2];
ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1];
}
}
if (pitch == width) {
memcpy(v, ptr, sizePlaneV * sizeof(uint8_t));
} else {
for (int i = 0; i < (height + 1) / 2; i++) {
memcpy(
v + ((pitch + 1) / 2) * i,
ptr + ((width + 1) / 2) * i,
((width + 1) / 2) * sizeof(uint8_t));
}
}
delete[] ptr;
return frameTensor;
}

TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder")
.def(torch::init<std::string, int64_t>())
.def("next", &GPUDecoder::decode)
.def("reformat", &GPUDecoder::nv12_to_yuv420);
.def("next", &GPUDecoder::decode);
}
10 changes: 0 additions & 10 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,6 @@ def set_current_stream(self, stream: str) -> bool:
print("GPU decoding only works with video stream.")
return self._c.set_current_stream(stream)

def _reformat(self, tensor, output_format: str = "yuv420"):
supported_formats = [
"yuv420",
]
if output_format not in supported_formats:
raise RuntimeError(f"{output_format} not supported, please use one of {', '.join(supported_formats)}")
if not isinstance(tensor, torch.Tensor):
raise RuntimeError("Expected tensor as input parameter!")
return self._c.reformat(tensor.cpu())


__all__ = [
"write_video",
Expand Down