|
1 | 1 | #include "decoder.h" |
2 | 2 | #include <c10/util/Logging.h> |
| 3 | +#include <nppi_color_conversion.h> |
3 | 4 | #include <cmath> |
4 | 5 | #include <cstring> |
5 | 6 | #include <unordered_map> |
@@ -138,38 +139,24 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) { |
138 | 139 | } |
139 | 140 |
|
140 | 141 | auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); |
141 | | - torch::Tensor decoded_frame = torch::empty({get_frame_size()}, options); |
| 142 | + torch::Tensor decoded_frame = torch::empty({get_height(), width, 3}, options); |
142 | 143 | uint8_t* frame_ptr = decoded_frame.data_ptr<uint8_t>(); |
| 144 | + const uint8_t* const source_arr[] = { |
| 145 | + (const uint8_t* const)source_frame, |
| 146 | + (const uint8_t* const)(source_frame + source_pitch * ((surface_height + 1) & ~1))}; |
| 147 | + |
| 148 | + auto err = nppiNV12ToRGB_709CSC_8u_P2C3R( |
| 149 | + source_arr, |
| 150 | + source_pitch, |
| 151 | + frame_ptr, |
| 152 | + width * 3, |
| 153 | + {(int)decoded_frame.size(1), (int)decoded_frame.size(0)}); |
| 154 | + |
| 155 | + TORCH_CHECK( |
| 156 | + err == NPP_NO_ERROR, |
| 157 | + "Failed to convert from NV12 to RGB. Error code:", |
| 158 | + err); |
143 | 159 |
|
144 | | - // Copy luma plane |
145 | | - CUDA_MEMCPY2D m = {0}; |
146 | | - m.srcMemoryType = CU_MEMORYTYPE_DEVICE; |
147 | | - m.srcDevice = source_frame; |
148 | | - m.srcPitch = source_pitch; |
149 | | - m.dstMemoryType = CU_MEMORYTYPE_DEVICE; |
150 | | - m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr); |
151 | | - m.dstPitch = get_width() * bytes_per_pixel; |
152 | | - m.WidthInBytes = get_width() * bytes_per_pixel; |
153 | | - m.Height = luma_height; |
154 | | - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); |
155 | | - |
156 | | - // Copy chroma plane |
157 | | - // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning |
158 | | - // height |
159 | | - m.srcDevice = |
160 | | - (CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1)); |
161 | | - m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height); |
162 | | - m.Height = chroma_height; |
163 | | - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); |
164 | | - |
165 | | - if (num_chroma_planes == 2) { |
166 | | - m.srcDevice = |
167 | | - (CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1) * 2); |
168 | | - m.dstDevice = |
169 | | - (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height * 2); |
170 | | - m.Height = chroma_height; |
171 | | - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); |
172 | | - } |
173 | 160 | check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__); |
174 | 161 | decoded_frames.push(decoded_frame); |
175 | 162 | check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); |
|
0 commit comments