From 285a7e24a7749f52b0311d2604266fdb3072b284 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Fri, 30 Apr 2021 14:14:54 -0700 Subject: [PATCH 1/5] Add GPU RNNT Loss Summary: In #1479, we added support for CPU RNNT loss. This PR adds a GPU version of RNNT loss Differential Revision: D28128853 fbshipit-source-id: 3c610c7e9c3dda3fb309586d5dc71397752cd2e0 --- .../rnnt/rnnt_loss_cuda_test.py | 10 + torchaudio/csrc/rnnt/gpu/compute.cu | 105 +++++ torchaudio/csrc/rnnt/gpu/compute_alphas.cu | 73 +++ torchaudio/csrc/rnnt/gpu/compute_betas.cu | 78 ++++ torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh | 99 ++++ torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh | 431 ++++++++++++++++++ torchaudio/csrc/rnnt/gpu/gpu_transducer.h | 393 ++++++++++++++++ torchaudio/csrc/rnnt/gpu/half.cuh | 38 ++ torchaudio/csrc/rnnt/gpu/kernel_utils.h | 59 +++ torchaudio/csrc/rnnt/gpu/kernels.h | 133 ++++++ torchaudio/csrc/rnnt/gpu/math.cuh | 41 ++ 11 files changed, 1460 insertions(+) create mode 100644 test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py create mode 100644 torchaudio/csrc/rnnt/gpu/compute.cu create mode 100644 torchaudio/csrc/rnnt/gpu/compute_alphas.cu create mode 100644 torchaudio/csrc/rnnt/gpu/compute_betas.cu create mode 100644 torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh create mode 100644 torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh create mode 100644 torchaudio/csrc/rnnt/gpu/gpu_transducer.h create mode 100644 torchaudio/csrc/rnnt/gpu/half.cuh create mode 100644 torchaudio/csrc/rnnt/gpu/kernel_utils.h create mode 100644 torchaudio/csrc/rnnt/gpu/kernels.h create mode 100644 torchaudio/csrc/rnnt/gpu/math.cuh diff --git a/test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py b/test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py new file mode 100644 index 0000000000..c64fd687fe --- /dev/null +++ b/test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py @@ -0,0 +1,10 @@ +import torch +from .rnnt_loss_impl import RNNTLossTest +from torchaudio_unittest import common_utils +from .utils import skipIfNoTransducer + + +@skipIfNoTransducer +@common_utils.skipIfNoCuda +class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase): + device = torch.device('cuda') diff --git a/torchaudio/csrc/rnnt/gpu/compute.cu b/torchaudio/csrc/rnnt/gpu/compute.cu new file mode 100644 index 0000000000..8582f36d48 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/compute.cu @@ -0,0 +1,105 @@ +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +// Entry point into RNNT Loss +std::tuple> +compute( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp, + bool fused_log_smax = true, + bool reuse_logits_for_grads = true) { + + Options options; + options.batchSize_ = src_lengths.size(0); + options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + options.fusedLogSmax_ = fused_log_smax; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); + options.stream_ = at::cuda::getCurrentCUDAStream(); + cudaSetDevice(logits.get_device()); + options.device_ = GPU; + + torch::Tensor costs = torch::empty( + options.batchSize_ * options.nHypos_, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + c10::optional gradients = c10::nullopt; + if (logits.requires_grad()) { + if (reuse_logits_for_grads) { + gradients = logits; 
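+ // when reusing the buffer, the logits tensor is overwritten in place with
+ // gradients during the gradient pass below.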
+ } else { + gradients = torch::zeros_like(logits); + } + } + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions().device(logits.device()).dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions().device(logits.device()).dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data(), + /*int_size=*/int_workspace.numel()); + + switch (logits.type().scalarType()) { + case torch::ScalarType::Float: + { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*gradients=*/(gradients == c10::nullopt)? nullptr : gradients->data()); + break; + } + case torch::ScalarType::Half: + { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*gradients=*/(gradients == c10::nullopt)? nullptr : gradients->data()); + break; + } + default: + { + LOG(ERROR) << "unsupported logits.type().scalarType() = " + << logits.type().scalarType(); + break; + } + }; + + return std::make_tuple(costs, gradients); +} + +TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss", &compute); +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/compute_alphas.cu b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu new file mode 100644 index 0000000000..dc101b0ecb --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu @@ -0,0 +1,73 @@ +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +torch::Tensor compute_alphas( + const torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp) { + Options options; + options.batchSize_ = src_lengths.size(0); + options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); + options.stream_ = at::cuda::getCurrentCUDAStream(); + cudaSetDevice(logits.get_device()); + options.device_ = GPU; + + torch::Tensor alphas = torch::zeros( + {options.batchSize_ * options.nHypos_, + options.maxSrcLen_, + options.maxTgtLen_}, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data(), + /*int_size=*/int_workspace.numel()); + + // Only support float, this is mainly to enable easy + // unit-testing + ComputeAlphas( + /*workspace=*/workspace, + 
/*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*alphas=*/alphas.data()); + return alphas; +} + +TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_alphas", &compute_alphas); +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/compute_betas.cu b/torchaudio/csrc/rnnt/gpu/compute_betas.cu new file mode 100644 index 0000000000..63b8fd636a --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/compute_betas.cu @@ -0,0 +1,78 @@ +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +torch::Tensor compute_betas( + const torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp) { + Options options; + options.batchSize_ = src_lengths.size(0); + options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); + options.stream_ = at::cuda::getCurrentCUDAStream(); + cudaSetDevice(logits.get_device()); + options.device_ = GPU; + + torch::Tensor costs = torch::empty( + tgt_lengths.size(0), + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor betas = torch::zeros( + {options.batchSize_ * options.nHypos_, + options.maxSrcLen_, + options.maxTgtLen_}, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data(), + /*int_size=*/int_workspace.numel()); + + // Only support float, this is mainly to enable easy + // unit-testing + ComputeBetas( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*betas=*/betas.data()); + return betas; +} + +TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_betas", &compute_betas); +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh new file mode 100644 index 0000000000..f42e35b46f --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh @@ -0,0 +1,99 @@ +#pragma once + +#ifdef USE_CUDA + +#include + +namespace torchaudio { +namespace rnnt { + +template +__global__ void ReduceMax2D( + int dim, + const DTYPE* inputs, // [N, dim] + CAST_DTYPE* outputs) { + + __shared__ CAST_DTYPE shared[NUM_THREADS]; + + // each thread reduces one matrix row + int offset = blockIdx.x * dim; // [n, 0] + CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0) + for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { + CAST_DTYPE next = inputs[offset + d]; + if (next > val) { + val = next; + } + } + + shared[threadIdx.x] = val; + 
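+ // reduce the per-thread maxima held in shared memory down to WARP_SIZE
+ // entries, then finish with the warp-shuffle reduction below.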
__syncthreads(); + + for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) { + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + if (shared[threadIdx.x + stride] > shared[threadIdx.x]) { + shared[threadIdx.x] = shared[threadIdx.x + stride]; + val = shared[threadIdx.x]; + } + } + __syncthreads(); + } + + CAST_DTYPE shf; + for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { + shf = __shfl_down_sync(0xFFFFFFFF, val, stride); + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + if (shf > val) { + val = shf; + } + } + } + + if (threadIdx.x == 0) { + outputs[blockIdx.x] = val; + } +} + +template +__global__ void ReduceLogSumExpGivenMax2D( + int dim, + const DTYPE* inputs, // [N, dim] + CAST_DTYPE* outputs) { // in: max -> out: logsum + + __shared__ CAST_DTYPE shared[NUM_THREADS]; + + CAST_DTYPE max = outputs[blockIdx.x]; + CAST_DTYPE val = 0; + + int offset = blockIdx.x * dim; + for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { + val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max); + } + + shared[threadIdx.x] = val; + __syncthreads(); + + for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) { + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + val = shared[threadIdx.x] + shared[threadIdx.x + stride]; + shared[threadIdx.x] = val; + } + __syncthreads(); + } + + CAST_DTYPE shf; + for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { + shf = __shfl_down_sync(0xFFFFFFFF, val, stride); + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + val = val + shf; + } + } + + if (threadIdx.x == 0) { + outputs[blockIdx.x] = max + std::log(val); + } +} + +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_CUDA diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh new file mode 100644 index 0000000000..a2d724bf8f --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh @@ -0,0 +1,431 @@ +#pragma once + +#ifdef USE_CUDA + +#include + +#include +#include +#include + +namespace torchaudio { +namespace rnnt { + + +template +__global__ void ComputeLogProbs( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + CAST_DTYPE* logProbs, + int H=1, + bool fusedLogSmax=true) { + + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + const int& D = numTargets; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = blockIdx.x * blockDim.x + threadIdx.x; + const int u = blockIdx.y; + + if (t >= T || u >= U) { // out of boundary. + return; + } + + Indexer3D indexer(maxT, maxU); + + int idx = indexer(bTgt, t, u); + + // skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u). + logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] = + CAST_DTYPE(logits[idx * D + blank]) - denominators[idx]; + + if (!fusedLogSmax) { + logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] = + CAST_DTYPE(logits[idx * D + blank]); + } + + if (u < U - 1) { + // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t, u). 
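+ // logProbs packs two entries per (b, t, u) cell: LOG_PROBS_SKIP_IDX for the
+ // blank and LOG_PROBS_EMIT_IDX for the target token, hence the (idx << 1).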
+ int target = targets[Indexer2D(maxU - 1)(bTgt, u)]; + logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] = + CAST_DTYPE(logits[idx * D + target]) - denominators[idx]; + + if (!fusedLogSmax) { + logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] = + CAST_DTYPE(logits[idx * D + target]); + } + } + + +} + + +template +__device__ void ComputeAlphas( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int H=1) { + + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = blockIdx.x * blockDim.x + threadIdx.x + 1; + const int u = blockIdx.y + 1; + + if (t >= T || u >= U) { // out of boundary. + return; + } + + int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y); + + Indexer3D idxr(maxT, maxU); + + if (t == 1 && u == 1) { + alphas[idxr(bTgt, 0, 0)] = 0; + } + + if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. + while (atomicAdd(counter, 0) < blockIdx.x) {} + } + + if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. + while (atomicAdd(counter - 1, 0) <= blockIdx.x) {} + } + + if (t == 1 && u < U) { + + // alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit(). + alphas[idxr(bTgt, 0, u)] = + alphas[idxr(bTgt, 0, u - 1)] + + logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; + } + + if (blockIdx.y == 0 && t < T) { + CAST_DTYPE skip_prob = logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE val; + +#pragma unroll + for (int i = 1; i < warpSize; i <<= 1) { + val = __shfl_up_sync(0xffffffff, skip_prob, i); + if (i <= threadIdx.x) { + skip_prob = skip_prob + val; + } + } + + val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)]; + alphas[idxr(bTgt, t, 0)] = skip_prob + val; + } + + if (t < T && u < U) { + + CAST_DTYPE skip_prob = logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE emit_prob = logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; + + CAST_DTYPE skip = alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob; + CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob; + + CAST_DTYPE val = math::lse(skip, emit); + CAST_DTYPE out = val; + + for(int i = 1; i < warpSize; ++i) { + val = __shfl_up_sync(0xffffffff, val, 1); + if (i == threadIdx.x) { + val = math::lse(val + skip_prob, emit); + out = val; + } + } + + alphas[idxr(bTgt, t, u)] = out; + } + + if (threadIdx.x == 0) { + __threadfence(); + atomicAdd(counter, 1); + } +} + + +template +__device__ void ComputeBetasCosts( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int H=1) { + + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x; + const int u = U - 2 - blockIdx.y; + + if (t < 0 || u < 0) { // out of boundary. 
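+ // betas are filled in reverse, starting from (T - 1, U - 1); threads mapped
+ // outside the valid range exit immediately.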
+ return; + } + + int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y); + + Indexer3D idxr(maxT, maxU); + + if (t == T - 2 && u == U - 2) { + betas[idxr(bTgt, T - 1, U - 1)] = + logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; + } + + if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. + while (atomicAdd(counter, 0) < blockIdx.x) {} + } + + if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. + while (atomicAdd(counter - 1, 0) <= blockIdx.x) {} + } + + if (t == T - 2 && u >= 0) { + + betas[idxr(bTgt, T - 1, u)] = + betas[idxr(bTgt, T - 1, u + 1)] + + logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX]; + } + + if (blockIdx.y == 0 && t >= 0) { + CAST_DTYPE skip_prob = logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE val; + +#pragma unroll + for(int i = 1; i < warpSize; i <<= 1) { + val = __shfl_up_sync(0xffffffff, skip_prob, i); + if (i <= threadIdx.x) { + skip_prob = skip_prob + val; + } + } + + betas[idxr(bTgt, t, U - 1)] = + betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob; + } + + if (t >= 0 && u >= 0) { + + CAST_DTYPE skip_prob = logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE emit_prob = logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX]; + + CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob; + CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob; + + CAST_DTYPE val = math::lse(skip, emit); + CAST_DTYPE out = val; + + for(int i = 1; i < warpSize; ++i) { + val = __shfl_up_sync(0xffffffff, val, 1); + if (i == threadIdx.x) { + val = math::lse(val + skip_prob, emit); + out = val; + } + } + + betas[idxr(bTgt, t, u)] = out; + + if (t == 0 && u == 0) { // use -beta(0, 0) as cost. + costs[bTgt] = DTYPE(-out); + } + } + + if (threadIdx.x == 0) { + __threadfence(); + atomicAdd(counter, 1); + } +} + + +template +__global__ void ComputeAlphasBetasCosts( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int warpSize = 0, + int numWarps=0, + int H = 1) { + + assert(threadIdx.y == 0 || threadIdx.y == 1); + + if (threadIdx.y == 0) { + + ComputeAlphas( + /*maxSrcLen=*/maxSrcLen, + /*maxTgtLen=*/maxTgtLen, + /*numTargets=*/numTargets, + /*blank=*/blank, + /*logProbs=*/logProbs, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/alpha_counters, + /*alphas=*/alphas, + H); + } else { // threadIdx.y == 1 + ComputeBetasCosts( + /*maxSrcLen=*/maxSrcLen, + /*maxTgtLen=*/maxTgtLen, + /*numTargets=*/numTargets, + /*blank=*/blank, + /*logProbs=*/logProbs, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*betaCounters=*/betaCounters, + /*beta=*/betas, + /*costs=*/costs, + H); + } +} + + +template +__global__ void ComputeGradients( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + CAST_DTYPE clamp, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + const CAST_DTYPE* alphas, + const CAST_DTYPE* betas, + DTYPE* gradients, + int H = 1, + bool fusedLogSmax = true) { + + const int bTgt = blockIdx.z; // 0 <= b < B + const int t = blockIdx.x * blockDim.x + threadIdx.x; + const int u = blockIdx.y; + + ComputeGradientsElement( + bTgt, + t, + u, + maxSrcLen, + maxTgtLen, + numTargets, + blank, + clamp, + logits, + 
targets, + srcLengths, + tgtLengths, + denominators, + alphas, + betas, + gradients, + H, + fusedLogSmax); +} + + +// This is a __global__ wrapper around ComputeAlphas +// device kernel to enable unit testing +template +__global__ void ComputeAlphasWrapper( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int H=1) { + ComputeAlphas( + maxSrcLen, + maxTgtLen, + numTargets, + blank, + logProbs, + srcLengths, + tgtLengths, + alpha_counters, + alphas, + H); +} + +// This is a __global__ wrapper around ComputeBetas +// device kernel to enable unit testing +template +__global__ void ComputeBetasWrapper( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int H=1) { + ComputeBetasCosts( + maxSrcLen, + maxTgtLen, + numTargets, + blank, + logProbs, + srcLengths, + tgtLengths, + betaCounters, + betas, + costs, + H); +} + + +// #undef LOG_PROBS_SKIP_IDX +// #undef LOG_PROBS_EMIT_IDX + +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_CUDA diff --git a/torchaudio/csrc/rnnt/gpu/gpu_transducer.h b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h new file mode 100644 index 0000000000..1720b55645 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h @@ -0,0 +1,393 @@ +#pragma once + +#ifdef USE_CUDA + +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +#define gpuErrchk(ans) \ + { gpuAssert((ans), __FILE__, __LINE__); } + +inline void +gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) { + if (code != cudaSuccess) { + fprintf( + stderr, + "\nGPUassert: %s %s %d\n", + cudaGetErrorString(code), + file, + line); + if (abort) + exit(code); + } +} + +template +status_t LogSumExp2D( + cudaStream_t stream, + int N, + int D, + const DTYPE* logits, // [N, D] + CAST_DTYPE* outputs) { + { // compute max among D. + dim3 block_dims(N); + dim3 thread_dims(REDUCE_THREADS); + + ReduceMax2D + <<>>( + /*dim=*/D, + /*inputs=*/logits, + /*outputs=*/outputs); + + // BUGBUG: These error codes are only accurate when launching with + // blocking. Otherwise they usually reflect earlier errors. + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED; + } + } + + { // compute log(sum(exp(d_i - max))) + dim3 block_dims(N); + dim3 thread_dims(REDUCE_THREADS); + + ReduceLogSumExpGivenMax2D + <<>>( + /*dim=*/D, + /*inputs=*/logits, + /*outputs=*/outputs); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED; + } + } + + return SUCCESS; +} + +// Inputs: +// workspace: workspace. +// logits: pointer to (B, max_T, max_U, D) logits. +// targets: pointer to (B, max_U - 1) targets in the batch. +// srcLengths: pointer to (B, ) source lengths in the batch. +// tgtLengths: pointer to (B, ) target lengths in the batch. +// +// Outputs: +// costs: pointer to (B, ) costs in the batch. +// gradients: pointer to (B, max_T, max_U, D) gradients in the batch. 
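+//
+// A rough call sketch (the template arguments and workspace buffer names
+// below are assumed for illustration, not taken verbatim from a caller):
+//
+//   Workspace<float> workspace(options, float_ws_data, float_ws_size,
+//                              int_ws_data, int_ws_size);
+//   status_t status = Compute<float, float>(
+//       workspace, logits, targets, srcLengths, tgtLengths, costs, gradients);
+//
+// Compute returns SUCCESS on completion; passing gradients == nullptr skips
+// the gradient pass.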
+template +status_t Compute( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* gradients = nullptr) { + const Options& options = workspace.GetOptions(); + + const cudaStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + const CAST_DTYPE clamp = options.clamp_; + + const bool& fusedLogSmax = options.fusedLogSmax_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + ComputeLogProbs<<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H, + fusedLogSmax); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + + { // compute alphas, betas and costs. + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. + // we are using num_warp * max_U * B * H blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 2. 1 for alpha, 1 for beta + dim3 thread_dims(WARP_SIZE, 2); + + ComputeAlphasBetasCosts + <<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToAlphaCounters(), + /*alphas=*/workspace.GetPointerToAlphas(), + /*beta_counters=*/workspace.GetPointerToBetaCounters(), + /*betas=*/workspace.GetPointerToBetas(), + /*costs=*/costs, + /*warp_size=*/WARP_SIZE, + /*num_warps=*/num_warps, + H); + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + if (gradients != nullptr) { // compute gradients. 
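+ // one thread per (t, u, b * h) cell; each thread writes the full
+ // D-dimensional gradient row for its cell.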
+ // don't set gradients to zero to here as gradients might reuse memory from + // logits + + int num_blocks = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_blocks, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + ComputeGradients<<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*clamp=*/clamp, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*alphas=*/workspace.GetPointerToAlphas(), + /*betas=*/workspace.GetPointerToBetas(), + /*gradients=*/gradients, + H, + fusedLogSmax); + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_GRADIENTS_FAILED; + } + } + + return SUCCESS; +} + +template +status_t ComputeAlphas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* alphas) { + const Options& options = workspace.GetOptions(); + + const cudaStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + ComputeLogProbs<<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + { // compute alphas + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. 
+ // we are using num_warp * max_U * B blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 1 for alpha only + dim3 thread_dims(WARP_SIZE, 1); + + + ComputeAlphasWrapper + <<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToAlphaCounters(), + /*alphas=*/(volatile DTYPE*)alphas, + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + return SUCCESS; +} + +template +status_t ComputeBetas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* betas) { + const Options& options = workspace.GetOptions(); + + const cudaStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + ComputeLogProbs<<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + { // compute betas + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. 
+ // we are using num_warp * max_U * B blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 1 for betas only + dim3 thread_dims(WARP_SIZE, 1); + + ComputeBetasWrapper + <<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToBetaCounters(), + /*alphas=*/(volatile DTYPE*)betas, + costs, + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + return SUCCESS; +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_CUDA diff --git a/torchaudio/csrc/rnnt/gpu/half.cuh b/torchaudio/csrc/rnnt/gpu/half.cuh new file mode 100644 index 0000000000..d49ac3bdcf --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/half.cuh @@ -0,0 +1,38 @@ +#pragma once + +#ifdef USE_C10_HALF +#include "c10/util/Half.h" +#endif // USE_C10_HALF + +#include + +namespace torchaudio { +namespace rnnt { + +struct alignas(sizeof(__half)) Half { + __half x; + + HOST_AND_DEVICE Half() = default; + + FORCE_INLINE HOST_AND_DEVICE Half(float f) { + x = __float2half_rn(f); + if (isinf(__half2float(x))) { + x = __float2half_rz(f); // round toward 0. + } + } + + FORCE_INLINE HOST_AND_DEVICE operator float() const { + return __half2float(x); + } + + FORCE_INLINE HOST_AND_DEVICE Half(__half f) { + x = f; + } + + FORCE_INLINE HOST_AND_DEVICE operator __half() const { + return x; + } +}; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/kernel_utils.h b/torchaudio/csrc/rnnt/gpu/kernel_utils.h new file mode 100644 index 0000000000..03eb3d33b4 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/kernel_utils.h @@ -0,0 +1,59 @@ +#pragma once + +#include + +#include + +namespace torchaudio { +namespace rnnt { + +inline HOST_AND_DEVICE bool in_range( + int start, + int end, // inclusive + int val) { + return start <= val && val <= end; +} + +#define LOG_PROBS_SKIP_IDX 0 +#define LOG_PROBS_EMIT_IDX 1 + + +struct Indexer2D { + const int& size2_; + + FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2): size2_(size2) {} + + FORCE_INLINE HOST_AND_DEVICE int operator() (int index1, int index2) { + return index1 * size2_ + index2; + } +}; + + +struct Indexer3D { + const int& size2_; + const int& size3_; + + FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3) + : size2_(size2), size3_(size3) {} + + FORCE_INLINE HOST_AND_DEVICE int operator() (int index1, int index2, int index3) { + return (index1 * size2_ + index2) * size3_ + index3; + } +}; + + +struct Indexer4D { + const int& size2_; + const int& size3_; + const int& size4_; + + HOST_AND_DEVICE Indexer4D(const int& size2, const int& size3, const int& size4) + : size2_(size2), size3_(size3), size4_(size4) {} + + HOST_AND_DEVICE int operator() (int index1, int index2, int index3, int index4) { + return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4; + } +}; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/kernels.h b/torchaudio/csrc/rnnt/gpu/kernels.h new file mode 100644 index 0000000000..2cbef51327 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/kernels.h @@ -0,0 +1,133 @@ +#pragma once + +#include + +#include +#include + +namespace torchaudio { +namespace rnnt { + + +template +HOST_AND_DEVICE void 
ComputeGradientsElement( + int bTgt, + int t, + int u, + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + CAST_DTYPE clamp, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + const CAST_DTYPE* alphas, + const CAST_DTYPE* betas, + DTYPE* gradients, + int H = 1, + bool fusedLogSmax = true) { + + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + const int& D = numTargets; + + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + if (t >= T || u >= U) { // out of boundary. + if (gradients == logits && t < maxT && u < maxU) { + // gradients and logits are pointing to the same memory location + Indexer3D idxr3(maxT, maxU); + int idx_b_t_u_zero = idxr3(bTgt, t, u); + if (idx_b_t_u_zero != -1 ) { + int start = idx_b_t_u_zero * D; + for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) { + gradients[b_t_u_d] = 0; + } + } + } + return; + } + + int costIdx = bTgt * maxT * maxU; + CAST_DTYPE cost = -(betas[costIdx]); + + + Indexer2D idxr2(maxU - 1); + + int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u, idx_b_tp1_up1; + Indexer3D idxr3(maxT, maxU); + idx_b_t_u = idxr3(bTgt, t, u); + idx_b_t_up1 = idxr3(bTgt, t, u+1); + idx_b_tp1_u = idxr3(bTgt, t+1, u); + idx_b_tp1_up1 = idxr3(bTgt, t+1, u+1); + + if (idx_b_t_u == -1 ) { + return; + } + + if (isinf(cost) || isnan(cost)) { + for (int d = 0; d < D; ++d) { + int b_t_u_d = idx_b_t_u * D + d; + gradients[b_t_u_d] = 0; + } + return; + } + + CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u]; + for (int d = 0; d < D; ++d) { + int b_t_u_d = idx_b_t_u * D + d; + CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c; + + if (fusedLogSmax) { + if (d == blank && t == T - 1 && u == U - 1) { // last blank transition. 
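+ // sketch of the fused-softmax gradient: exp(g + beta(t, u)) is the
+ // posterior-weighted softmax term, and the terminal blank additionally
+ // subtracts its own path contribution exp(g).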
+ gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g); + } else if (t < T - 1 && d == blank) { + gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); + if (idx_b_tp1_u != -1) { + gradients[b_t_u_d] = gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]); + } + } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) { + gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); + if (idx_b_t_up1 != -1) { + gradients[b_t_u_d] = gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]); + } + } else { + gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); + } + } else { // Non fused log softmax case + CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]); + if (d == blank && t == T - 1 && u == U - 1) { + gradients[b_t_u_d] = g + alphas[idx_b_t_u]; + } else if (t < T - 1 && d == blank) { + if (idx_b_tp1_u != -1) { + gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u]; + } else { + gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY); + } + } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) { + if (idx_b_t_up1 != -1) { + gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1]; + } else { + gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY); + } + } else { + gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY); + } + gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]); + } + + if (clamp > 0) { + auto g = CAST_DTYPE(gradients[b_t_u_d]); + gradients[b_t_u_d] = math::min(g, clamp); + gradients[b_t_u_d] = math::max(g, -clamp); + } + } +} + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/math.cuh b/torchaudio/csrc/rnnt/gpu/math.cuh new file mode 100644 index 0000000000..a2eaabff93 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/math.cuh @@ -0,0 +1,41 @@ +#pragma once + +#ifdef USE_CUDA + +#include + +#endif // USE_CUDA + +#include + +namespace torchaudio { +namespace rnnt { + +namespace math { + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) { + if (x > y) return x; + else return y; +} + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) { + if (x > y) return y; + else return x; +} + +// log_sum_exp +template +FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y); + +template <> +FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) { + if (y > x) { return y + log1pf(expf(x - y)); } + else { return x + log1pf(expf(y-x)); } +} + +} + +} // namespace rnnt +} // namespace torchaudio From f6e783d18c3a4f5745f074200386c5c7753dc16e Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Fri, 30 Apr 2021 15:02:42 -0700 Subject: [PATCH 2/5] clang-format --- torchaudio/csrc/rnnt/gpu/compute.cu | 77 +++++------ torchaudio/csrc/rnnt/gpu/compute_betas.cu | 2 +- torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh | 13 +- torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh | 127 ++++++++---------- torchaudio/csrc/rnnt/gpu/gpu_transducer.h | 10 +- torchaudio/csrc/rnnt/gpu/half.cuh | 8 +- torchaudio/csrc/rnnt/gpu/kernel_utils.h | 33 +++-- torchaudio/csrc/rnnt/gpu/kernels.h | 31 +++-- torchaudio/csrc/rnnt/gpu/math.cuh | 27 ++-- 9 files changed, 167 insertions(+), 161 deletions(-) diff --git a/torchaudio/csrc/rnnt/gpu/compute.cu b/torchaudio/csrc/rnnt/gpu/compute.cu index 8582f36d48..cb12c369ad 100644 --- a/torchaudio/csrc/rnnt/gpu/compute.cu +++ b/torchaudio/csrc/rnnt/gpu/compute.cu @@ -7,8 +7,7 @@ namespace rnnt { namespace gpu { // Entry point into RNNT Loss -std::tuple> -compute( +std::tuple> compute( torch::Tensor& logits, const torch::Tensor& targets, const torch::Tensor& src_lengths, @@ -17,7 +16,6 @@ compute( double clamp, bool fused_log_smax = true, 
bool reuse_logits_for_grads = true) { - Options options; options.batchSize_ = src_lengths.size(0); options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); @@ -47,11 +45,15 @@ compute( torch::Tensor int_workspace = torch::empty( IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions().device(logits.device()).dtype(torch::ScalarType::Int)); + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); torch::Tensor float_workspace = torch::empty( DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions().device(logits.device()).dtype(torch::ScalarType::Float)); + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); Workspace workspace( /*options=*/options, @@ -61,36 +63,35 @@ compute( /*int_size=*/int_workspace.numel()); switch (logits.type().scalarType()) { - case torch::ScalarType::Float: - { - Compute( - /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), - /*gradients=*/(gradients == c10::nullopt)? nullptr : gradients->data()); - break; - } - case torch::ScalarType::Half: - { - Compute( - /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), - /*gradients=*/(gradients == c10::nullopt)? nullptr : gradients->data()); - break; - } - default: - { - LOG(ERROR) << "unsupported logits.type().scalarType() = " - << logits.type().scalarType(); - break; - } + case torch::ScalarType::Float: { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*gradients=*/ + (gradients == c10::nullopt) ? nullptr : gradients->data()); + break; + } + case torch::ScalarType::Half: { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data(), + /*targets=*/targets.data(), + /*src_lengths=*/src_lengths.data(), + /*tgt_lengths=*/tgt_lengths.data(), + /*costs=*/costs.data(), + /*gradients=*/ + (gradients == c10::nullopt) ? 
nullptr : gradients->data()); + break; + } + default: { + LOG(ERROR) << "unsupported logits.type().scalarType() = " + << logits.type().scalarType(); + break; + } }; return std::make_tuple(costs, gradients); @@ -100,6 +101,6 @@ TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { m.impl("rnnt_loss", &compute); } -} // namespace gpu -} // namespace rnnt -} // namespace torchaudio +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/compute_betas.cu b/torchaudio/csrc/rnnt/gpu/compute_betas.cu index 63b8fd636a..f8f85337db 100644 --- a/torchaudio/csrc/rnnt/gpu/compute_betas.cu +++ b/torchaudio/csrc/rnnt/gpu/compute_betas.cu @@ -73,6 +73,6 @@ TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { m.impl("rnnt_loss_betas", &compute_betas); } -} // namespace gpu +} // namespace gpu } // namespace rnnt } // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh index f42e35b46f..e5f1cfc2df 100644 --- a/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh +++ b/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh @@ -12,12 +12,11 @@ __global__ void ReduceMax2D( int dim, const DTYPE* inputs, // [N, dim] CAST_DTYPE* outputs) { - __shared__ CAST_DTYPE shared[NUM_THREADS]; // each thread reduces one matrix row - int offset = blockIdx.x * dim; // [n, 0] - CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0) + int offset = blockIdx.x * dim; // [n, 0] + CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0) for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { CAST_DTYPE next = inputs[offset + d]; if (next > val) { @@ -57,7 +56,7 @@ template __global__ void ReduceLogSumExpGivenMax2D( int dim, const DTYPE* inputs, // [N, dim] - CAST_DTYPE* outputs) { // in: max -> out: logsum + CAST_DTYPE* outputs) { // in: max -> out: logsum __shared__ CAST_DTYPE shared[NUM_THREADS]; @@ -93,7 +92,7 @@ __global__ void ReduceLogSumExpGivenMax2D( } } -} // namespace rnnt -} // namespace torchaudio +} // namespace rnnt +} // namespace torchaudio -#endif // USE_CUDA +#endif // USE_CUDA diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh index a2d724bf8f..4ba04b68fc 100644 --- a/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh +++ b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh @@ -4,14 +4,13 @@ #include -#include -#include #include +#include +#include namespace torchaudio { namespace rnnt { - template __global__ void ComputeLogProbs( int maxSrcLen, @@ -24,14 +23,13 @@ __global__ void ComputeLogProbs( const int* tgtLengths, const CAST_DTYPE* denominators, CAST_DTYPE* logProbs, - int H=1, - bool fusedLogSmax=true) { - + int H = 1, + bool fusedLogSmax = true) { const int& maxT = maxSrcLen; const int& maxU = maxTgtLen; const int& D = numTargets; - const int bTgt = blockIdx.z; // 0 <= b < B + const int bTgt = blockIdx.z; // 0 <= b < B const int bSrc = bTgt / H; const int T = srcLengths[bSrc]; const int U = tgtLengths[bTgt] + 1; @@ -53,25 +51,23 @@ __global__ void ComputeLogProbs( if (!fusedLogSmax) { logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] = - CAST_DTYPE(logits[idx * D + blank]); + CAST_DTYPE(logits[idx * D + blank]); } if (u < U - 1) { - // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t, u). + // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t, + // u). 
int target = targets[Indexer2D(maxU - 1)(bTgt, u)]; logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] = CAST_DTYPE(logits[idx * D + target]) - denominators[idx]; if (!fusedLogSmax) { logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] = - CAST_DTYPE(logits[idx * D + target]); + CAST_DTYPE(logits[idx * D + target]); } } - - } - template __device__ void ComputeAlphas( int maxSrcLen, @@ -83,12 +79,11 @@ __device__ void ComputeAlphas( const int* tgtLengths, int* alpha_counters, volatile CAST_DTYPE* alphas, - int H=1) { - + int H = 1) { const int& maxT = maxSrcLen; const int& maxU = maxTgtLen; - const int bTgt = blockIdx.z; // 0 <= b < B + const int bTgt = blockIdx.z; // 0 <= b < B const int bSrc = bTgt / H; const int T = srcLengths[bSrc]; const int U = tgtLengths[bTgt] + 1; @@ -108,24 +103,25 @@ __device__ void ComputeAlphas( alphas[idxr(bTgt, 0, 0)] = 0; } - if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. - while (atomicAdd(counter, 0) < blockIdx.x) {} + if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. + while (atomicAdd(counter, 0) < blockIdx.x) { + } } - if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. - while (atomicAdd(counter - 1, 0) <= blockIdx.x) {} + if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. + while (atomicAdd(counter - 1, 0) <= blockIdx.x) { + } } if (t == 1 && u < U) { - // alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit(). - alphas[idxr(bTgt, 0, u)] = - alphas[idxr(bTgt, 0, u - 1)] - + logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; + alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] + + logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; } if (blockIdx.y == 0 && t < T) { - CAST_DTYPE skip_prob = logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX]; CAST_DTYPE val; #pragma unroll @@ -141,17 +137,19 @@ __device__ void ComputeAlphas( } if (t < T && u < U) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE emit_prob = + logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; - CAST_DTYPE skip_prob = logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX]; - CAST_DTYPE emit_prob = logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; - - CAST_DTYPE skip = alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob; + CAST_DTYPE skip = + alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob; CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob; CAST_DTYPE val = math::lse(skip, emit); CAST_DTYPE out = val; - for(int i = 1; i < warpSize; ++i) { + for (int i = 1; i < warpSize; ++i) { val = __shfl_up_sync(0xffffffff, val, 1); if (i == threadIdx.x) { val = math::lse(val + skip_prob, emit); @@ -168,7 +166,6 @@ __device__ void ComputeAlphas( } } - template __device__ void ComputeBetasCosts( int maxSrcLen, @@ -181,12 +178,11 @@ __device__ void ComputeBetasCosts( int* betaCounters, volatile CAST_DTYPE* betas, DTYPE* costs, - int H=1) { - + int H = 1) { const int& maxT = maxSrcLen; const int& maxU = maxTgtLen; - const int bTgt = blockIdx.z; // 0 <= b < B + const int bTgt = blockIdx.z; // 0 <= b < B const int bSrc = bTgt / H; const int T = srcLengths[bSrc]; const int U = tgtLengths[bTgt] + 1; @@ -207,27 +203,28 @@ __device__ void ComputeBetasCosts( logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; } - if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. 
- while (atomicAdd(counter, 0) < blockIdx.x) {} + if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. + while (atomicAdd(counter, 0) < blockIdx.x) { + } } - if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. - while (atomicAdd(counter - 1, 0) <= blockIdx.x) {} + if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. + while (atomicAdd(counter - 1, 0) <= blockIdx.x) { + } } if (t == T - 2 && u >= 0) { - - betas[idxr(bTgt, T - 1, u)] = - betas[idxr(bTgt, T - 1, u + 1)] - + logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX]; + betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] + + logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX]; } if (blockIdx.y == 0 && t >= 0) { - CAST_DTYPE skip_prob = logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; CAST_DTYPE val; #pragma unroll - for(int i = 1; i < warpSize; i <<= 1) { + for (int i = 1; i < warpSize; i <<= 1) { val = __shfl_up_sync(0xffffffff, skip_prob, i); if (i <= threadIdx.x) { skip_prob = skip_prob + val; @@ -239,9 +236,10 @@ __device__ void ComputeBetasCosts( } if (t >= 0 && u >= 0) { - - CAST_DTYPE skip_prob = logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX]; - CAST_DTYPE emit_prob = logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX]; + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE emit_prob = + logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX]; CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob; CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob; @@ -249,17 +247,17 @@ __device__ void ComputeBetasCosts( CAST_DTYPE val = math::lse(skip, emit); CAST_DTYPE out = val; - for(int i = 1; i < warpSize; ++i) { + for (int i = 1; i < warpSize; ++i) { val = __shfl_up_sync(0xffffffff, val, 1); if (i == threadIdx.x) { - val = math::lse(val + skip_prob, emit); - out = val; + val = math::lse(val + skip_prob, emit); + out = val; } } betas[idxr(bTgt, t, u)] = out; - if (t == 0 && u == 0) { // use -beta(0, 0) as cost. + if (t == 0 && u == 0) { // use -beta(0, 0) as cost. 
costs[bTgt] = DTYPE(-out); } } @@ -270,7 +268,6 @@ __device__ void ComputeBetasCosts( } } - template __global__ void ComputeAlphasBetasCosts( int maxSrcLen, @@ -286,13 +283,11 @@ __global__ void ComputeAlphasBetasCosts( volatile CAST_DTYPE* betas, DTYPE* costs, int warpSize = 0, - int numWarps=0, + int numWarps = 0, int H = 1) { - assert(threadIdx.y == 0 || threadIdx.y == 1); if (threadIdx.y == 0) { - ComputeAlphas( /*maxSrcLen=*/maxSrcLen, /*maxTgtLen=*/maxTgtLen, @@ -304,7 +299,7 @@ __global__ void ComputeAlphasBetasCosts( /*alpha_counters=*/alpha_counters, /*alphas=*/alphas, H); - } else { // threadIdx.y == 1 + } else { // threadIdx.y == 1 ComputeBetasCosts( /*maxSrcLen=*/maxSrcLen, /*maxTgtLen=*/maxTgtLen, @@ -320,7 +315,6 @@ __global__ void ComputeAlphasBetasCosts( } } - template __global__ void ComputeGradients( int maxSrcLen, @@ -338,8 +332,7 @@ __global__ void ComputeGradients( DTYPE* gradients, int H = 1, bool fusedLogSmax = true) { - - const int bTgt = blockIdx.z; // 0 <= b < B + const int bTgt = blockIdx.z; // 0 <= b < B const int t = blockIdx.x * blockDim.x + threadIdx.x; const int u = blockIdx.y; @@ -364,7 +357,6 @@ __global__ void ComputeGradients( fusedLogSmax); } - // This is a __global__ wrapper around ComputeAlphas // device kernel to enable unit testing template @@ -378,8 +370,8 @@ __global__ void ComputeAlphasWrapper( const int* tgtLengths, int* alpha_counters, volatile CAST_DTYPE* alphas, - int H=1) { - ComputeAlphas( + int H = 1) { + ComputeAlphas( maxSrcLen, maxTgtLen, numTargets, @@ -406,8 +398,8 @@ __global__ void ComputeBetasWrapper( int* betaCounters, volatile CAST_DTYPE* betas, DTYPE* costs, - int H=1) { - ComputeBetasCosts( + int H = 1) { + ComputeBetasCosts( maxSrcLen, maxTgtLen, numTargets, @@ -421,11 +413,10 @@ __global__ void ComputeBetasWrapper( H); } - // #undef LOG_PROBS_SKIP_IDX // #undef LOG_PROBS_EMIT_IDX -} // namespace rnnt -} // namespace torchaudio +} // namespace rnnt +} // namespace torchaudio -#endif // USE_CUDA +#endif // USE_CUDA diff --git a/torchaudio/csrc/rnnt/gpu/gpu_transducer.h b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h index 1720b55645..72759b39f4 100644 --- a/torchaudio/csrc/rnnt/gpu/gpu_transducer.h +++ b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h @@ -2,9 +2,9 @@ #ifdef USE_CUDA +#include #include #include -#include namespace torchaudio { namespace rnnt { @@ -13,8 +13,11 @@ namespace gpu { #define gpuErrchk(ans) \ { gpuAssert((ans), __FILE__, __LINE__); } -inline void -gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) { +inline void gpuAssert( + cudaError_t code, + const char* file, + int line, + bool abort = true) { if (code != cudaSuccess) { fprintf( stderr, @@ -274,7 +277,6 @@ status_t ComputeAlphas( // 2nd dim is 1 for alpha only dim3 thread_dims(WARP_SIZE, 1); - ComputeAlphasWrapper <<>>( /*max_src_len=*/max_T, diff --git a/torchaudio/csrc/rnnt/gpu/half.cuh b/torchaudio/csrc/rnnt/gpu/half.cuh index d49ac3bdcf..72a2f37e04 100644 --- a/torchaudio/csrc/rnnt/gpu/half.cuh +++ b/torchaudio/csrc/rnnt/gpu/half.cuh @@ -2,7 +2,7 @@ #ifdef USE_C10_HALF #include "c10/util/Half.h" -#endif // USE_C10_HALF +#endif // USE_C10_HALF #include @@ -17,7 +17,7 @@ struct alignas(sizeof(__half)) Half { FORCE_INLINE HOST_AND_DEVICE Half(float f) { x = __float2half_rn(f); if (isinf(__half2float(x))) { - x = __float2half_rz(f); // round toward 0. + x = __float2half_rz(f); // round toward 0. 
} } @@ -34,5 +34,5 @@ struct alignas(sizeof(__half)) Half { } }; -} // namespace rnnt -} // namespace torchaudio +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/kernel_utils.h b/torchaudio/csrc/rnnt/gpu/kernel_utils.h index 03eb3d33b4..3b2989b073 100644 --- a/torchaudio/csrc/rnnt/gpu/kernel_utils.h +++ b/torchaudio/csrc/rnnt/gpu/kernel_utils.h @@ -17,43 +17,50 @@ inline HOST_AND_DEVICE bool in_range( #define LOG_PROBS_SKIP_IDX 0 #define LOG_PROBS_EMIT_IDX 1 - struct Indexer2D { const int& size2_; - FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2): size2_(size2) {} + FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {} - FORCE_INLINE HOST_AND_DEVICE int operator() (int index1, int index2) { + FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) { return index1 * size2_ + index2; } }; - struct Indexer3D { const int& size2_; const int& size3_; FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3) - : size2_(size2), size3_(size3) {} + : size2_(size2), size3_(size3) {} - FORCE_INLINE HOST_AND_DEVICE int operator() (int index1, int index2, int index3) { + FORCE_INLINE HOST_AND_DEVICE int operator()( + int index1, + int index2, + int index3) { return (index1 * size2_ + index2) * size3_ + index3; } }; - struct Indexer4D { const int& size2_; const int& size3_; const int& size4_; - HOST_AND_DEVICE Indexer4D(const int& size2, const int& size3, const int& size4) - : size2_(size2), size3_(size3), size4_(size4) {} - - HOST_AND_DEVICE int operator() (int index1, int index2, int index3, int index4) { + HOST_AND_DEVICE Indexer4D( + const int& size2, + const int& size3, + const int& size4) + : size2_(size2), size3_(size3), size4_(size4) {} + + HOST_AND_DEVICE int operator()( + int index1, + int index2, + int index3, + int index4) { return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4; } }; -} // namespace rnnt -} // namespace torchaudio +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/kernels.h b/torchaudio/csrc/rnnt/gpu/kernels.h index 2cbef51327..97093a34c0 100644 --- a/torchaudio/csrc/rnnt/gpu/kernels.h +++ b/torchaudio/csrc/rnnt/gpu/kernels.h @@ -2,13 +2,12 @@ #include -#include #include +#include namespace torchaudio { namespace rnnt { - template HOST_AND_DEVICE void ComputeGradientsElement( int bTgt, @@ -29,7 +28,6 @@ HOST_AND_DEVICE void ComputeGradientsElement( DTYPE* gradients, int H = 1, bool fusedLogSmax = true) { - const int& maxT = maxSrcLen; const int& maxU = maxTgtLen; const int& D = numTargets; @@ -43,7 +41,7 @@ HOST_AND_DEVICE void ComputeGradientsElement( // gradients and logits are pointing to the same memory location Indexer3D idxr3(maxT, maxU); int idx_b_t_u_zero = idxr3(bTgt, t, u); - if (idx_b_t_u_zero != -1 ) { + if (idx_b_t_u_zero != -1) { int start = idx_b_t_u_zero * D; for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) { gradients[b_t_u_d] = 0; @@ -56,17 +54,16 @@ HOST_AND_DEVICE void ComputeGradientsElement( int costIdx = bTgt * maxT * maxU; CAST_DTYPE cost = -(betas[costIdx]); - Indexer2D idxr2(maxU - 1); int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u, idx_b_tp1_up1; Indexer3D idxr3(maxT, maxU); idx_b_t_u = idxr3(bTgt, t, u); - idx_b_t_up1 = idxr3(bTgt, t, u+1); - idx_b_tp1_u = idxr3(bTgt, t+1, u); - idx_b_tp1_up1 = idxr3(bTgt, t+1, u+1); + idx_b_t_up1 = idxr3(bTgt, t, u + 1); + idx_b_tp1_u = idxr3(bTgt, t + 1, u); + idx_b_tp1_up1 = idxr3(bTgt, t + 1, u + 1); - if (idx_b_t_u == -1 ) { + if (idx_b_t_u == -1) { return; } 
@@ -84,17 +81,19 @@ HOST_AND_DEVICE void ComputeGradientsElement( CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c; if (fusedLogSmax) { - if (d == blank && t == T - 1 && u == U - 1) { // last blank transition. + if (d == blank && t == T - 1 && u == U - 1) { // last blank transition. gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g); } else if (t < T - 1 && d == blank) { gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); if (idx_b_tp1_u != -1) { - gradients[b_t_u_d] = gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]); + gradients[b_t_u_d] = + gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]); } } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) { gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); if (idx_b_t_up1 != -1) { - gradients[b_t_u_d] = gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]); + gradients[b_t_u_d] = + gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]); } } else { gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); @@ -108,7 +107,7 @@ HOST_AND_DEVICE void ComputeGradientsElement( gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u]; } else { gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY); - } + } } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) { if (idx_b_t_up1 != -1) { gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1]; @@ -116,7 +115,7 @@ HOST_AND_DEVICE void ComputeGradientsElement( gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY); } } else { - gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY); + gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY); } gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]); } @@ -129,5 +128,5 @@ HOST_AND_DEVICE void ComputeGradientsElement( } } -} // namespace rnnt -} // namespace torchaudio +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/math.cuh b/torchaudio/csrc/rnnt/gpu/math.cuh index a2eaabff93..643fa98300 100644 --- a/torchaudio/csrc/rnnt/gpu/math.cuh +++ b/torchaudio/csrc/rnnt/gpu/math.cuh @@ -4,7 +4,7 @@ #include -#endif // USE_CUDA +#endif // USE_CUDA #include @@ -15,14 +15,18 @@ namespace math { template FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) { - if (x > y) return x; - else return y; + if (x > y) + return x; + else + return y; } template FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) { - if (x > y) return y; - else return x; + if (x > y) + return y; + else + return x; } // log_sum_exp @@ -31,11 +35,14 @@ FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y); template <> FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) { - if (y > x) { return y + log1pf(expf(x - y)); } - else { return x + log1pf(expf(y-x)); } + if (y > x) { + return y + log1pf(expf(x - y)); + } else { + return x + log1pf(expf(y - x)); + } } -} +} // namespace math -} // namespace rnnt -} // namespace torchaudio +} // namespace rnnt +} // namespace torchaudio From 7865319cb5e1a89c88bab2a726800b4f7e8353f1 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Tue, 4 May 2021 15:29:34 +0000 Subject: [PATCH 3/5] resolve csrc cmake warnings --- torchaudio/csrc/rnnt/gpu/compute.cu | 33 +++++++++++----------- torchaudio/csrc/rnnt/gpu/compute_alphas.cu | 14 ++++----- torchaudio/csrc/rnnt/gpu/compute_betas.cu | 16 +++++------ torchaudio/csrc/rnnt/gpu/kernels.h | 3 +- 4 files changed, 32 insertions(+), 34 deletions(-) diff --git a/torchaudio/csrc/rnnt/gpu/compute.cu b/torchaudio/csrc/rnnt/gpu/compute.cu index cb12c369ad..5e7d794cd7 100644 --- a/torchaudio/csrc/rnnt/gpu/compute.cu +++ b/torchaudio/csrc/rnnt/gpu/compute.cu @@ -57,39 +57,38 @@ 
std::tuple> compute( Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data(), + /*dtype_data=*/float_workspace.data_ptr(), /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data(), + /*int_data=*/int_workspace.data_ptr(), /*int_size=*/int_workspace.numel()); - switch (logits.type().scalarType()) { + switch (logits.scalar_type()) { case torch::ScalarType::Float: { Compute( /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*src_lengths=*/src_lengths.data_ptr(), + /*tgt_lengths=*/tgt_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), /*gradients=*/ - (gradients == c10::nullopt) ? nullptr : gradients->data()); + (gradients == c10::nullopt) ? nullptr : gradients->data_ptr()); break; } case torch::ScalarType::Half: { Compute( /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*src_lengths=*/src_lengths.data_ptr(), + /*tgt_lengths=*/tgt_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), /*gradients=*/ - (gradients == c10::nullopt) ? nullptr : gradients->data()); + (gradients == c10::nullopt) ? nullptr + : gradients->data_ptr()); break; } default: { - LOG(ERROR) << "unsupported logits.type().scalarType() = " - << logits.type().scalarType(); break; } }; diff --git a/torchaudio/csrc/rnnt/gpu/compute_alphas.cu b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu index dc101b0ecb..45433129f5 100644 --- a/torchaudio/csrc/rnnt/gpu/compute_alphas.cu +++ b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu @@ -47,20 +47,20 @@ torch::Tensor compute_alphas( Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data(), + /*dtype_data=*/float_workspace.data_ptr(), /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data(), + /*int_data=*/int_workspace.data_ptr(), /*int_size=*/int_workspace.numel()); // Only support float, this is mainly to enable easy // unit-testing ComputeAlphas( /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*alphas=*/alphas.data()); + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*src_lengths=*/src_lengths.data_ptr(), + /*tgt_lengths=*/tgt_lengths.data_ptr(), + /*alphas=*/alphas.data_ptr()); return alphas; } diff --git a/torchaudio/csrc/rnnt/gpu/compute_betas.cu b/torchaudio/csrc/rnnt/gpu/compute_betas.cu index f8f85337db..03d0385f2c 100644 --- a/torchaudio/csrc/rnnt/gpu/compute_betas.cu +++ b/torchaudio/csrc/rnnt/gpu/compute_betas.cu @@ -51,21 +51,21 @@ torch::Tensor compute_betas( Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data(), + /*dtype_data=*/float_workspace.data_ptr(), /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data(), + /*int_data=*/int_workspace.data_ptr(), /*int_size=*/int_workspace.numel()); // Only support float, this is mainly to enable easy // unit-testing ComputeBetas( /*workspace=*/workspace, - /*logits=*/logits.data(), - /*targets=*/targets.data(), - /*src_lengths=*/src_lengths.data(), - /*tgt_lengths=*/tgt_lengths.data(), - /*costs=*/costs.data(), - /*betas=*/betas.data()); + /*logits=*/logits.data_ptr(), + 
/*targets=*/targets.data_ptr(), + /*src_lengths=*/src_lengths.data_ptr(), + /*tgt_lengths=*/tgt_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), + /*betas=*/betas.data_ptr()); return betas; } diff --git a/torchaudio/csrc/rnnt/gpu/kernels.h b/torchaudio/csrc/rnnt/gpu/kernels.h index 97093a34c0..db8bb5092b 100644 --- a/torchaudio/csrc/rnnt/gpu/kernels.h +++ b/torchaudio/csrc/rnnt/gpu/kernels.h @@ -56,12 +56,11 @@ HOST_AND_DEVICE void ComputeGradientsElement( Indexer2D idxr2(maxU - 1); - int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u, idx_b_tp1_up1; + int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u; Indexer3D idxr3(maxT, maxU); idx_b_t_u = idxr3(bTgt, t, u); idx_b_t_up1 = idxr3(bTgt, t, u + 1); idx_b_tp1_u = idxr3(bTgt, t + 1, u); - idx_b_tp1_up1 = idxr3(bTgt, t + 1, u + 1); if (idx_b_t_u == -1) { return; From 25c286eb3596b2c032796bc0b4b6b26a54b3a7bb Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Mon, 3 May 2021 22:30:07 +0000 Subject: [PATCH 4/5] minimal build --- CMakeLists.txt | 5 +++++ build_tools/setup_helpers/extension.py | 2 ++ torchaudio/csrc/CMakeLists.txt | 15 +++++++++++++++ 3 files changed, 22 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 96360532f2..35b148c058 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,11 @@ option(BUILD_KALDI "Build kaldi statically" ON) option(BUILD_TRANSDUCER "Enable transducer" OFF) option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON) option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF) +option(USE_CUDA "Enable CUDA support" OFF) + +if(USE_CUDA) + enable_language(CUDA) +endif() find_package(Torch REQUIRED) diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 430ff168c3..91002fa0e9 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -38,6 +38,7 @@ def _get_build(var, default=False): _BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True) _BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER") _USE_ROCM = _get_build("USE_ROCM") +_USE_CUDA = torch.cuda.is_available() def get_ext_modules(): @@ -76,6 +77,7 @@ def build_extension(self, ext): "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON", "-DBUILD_LIBTORCHAUDIO:BOOL=OFF", f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}", + f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}", ] build_args = [ '--target', 'install' diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index 79bad4047f..ebf577eb64 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -20,6 +20,17 @@ if(BUILD_TRANSDUCER) rnnt/compute_betas.cpp rnnt/compute.cpp ) + + if (USE_CUDA) + set( + CUDA_TRANSDUCER_SOURCES + rnnt/gpu/compute_alphas.cu + rnnt/gpu/compute_betas.cu + rnnt/gpu/compute.cu + ) + list(APPEND TRANSDUCER_SOURCES ${CUDA_TRANSDUCER_SOURCES}) + endif() + list(APPEND LIBTORCHAUDIO_SOURCES ${TRANSDUCER_SOURCES}) endif() @@ -105,6 +116,10 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION) target_compile_definitions(_torchaudio PRIVATE INCLUDE_KALDI) endif() + if (USE_CUDA) + target_compile_definitions(_torchaudio PRIVATE USE_CUDA) + endif() + target_include_directories( _torchaudio PRIVATE From 5bc42ff5023c2849403c24cc607314466b16ccd4 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Thu, 6 May 2021 13:31:51 +0000 Subject: [PATCH 5/5] remove THC dependency --- torchaudio/csrc/rnnt/gpu/compute.cu | 2 +- torchaudio/csrc/rnnt/gpu/compute_alphas.cu | 2 +- torchaudio/csrc/rnnt/gpu/compute_betas.cu | 2 +- 3 files changed, 3 insertions(+), 3 
deletions(-) diff --git a/torchaudio/csrc/rnnt/gpu/compute.cu b/torchaudio/csrc/rnnt/gpu/compute.cu index 5e7d794cd7..c0f67946df 100644 --- a/torchaudio/csrc/rnnt/gpu/compute.cu +++ b/torchaudio/csrc/rnnt/gpu/compute.cu @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/torchaudio/csrc/rnnt/gpu/compute_alphas.cu b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu index 45433129f5..f24c6a5b0f 100644 --- a/torchaudio/csrc/rnnt/gpu/compute_alphas.cu +++ b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/torchaudio/csrc/rnnt/gpu/compute_betas.cu b/torchaudio/csrc/rnnt/gpu/compute_betas.cu index 03d0385f2c..a225c9cf42 100644 --- a/torchaudio/csrc/rnnt/gpu/compute_betas.cu +++ b/torchaudio/csrc/rnnt/gpu/compute_betas.cu @@ -1,4 +1,4 @@ -#include +#include #include #include
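
To exercise the GPU path added in this series end to end, a minimal forward-only sketch follows. It is illustrative only and not part of the patches: it assumes the extension was built with BUILD_TRANSDUCER=1 on a machine where torch.cuda.is_available() returns True (so the setup helper passes -DUSE_CUDA:BOOL=ON), and it assumes the compute() entry point is registered under the name torch.ops.torchaudio.rnnt_loss, taking (logits, targets, src_lengths, tgt_lengths, blank, clamp) as its leading arguments and returning (costs, optional gradients); neither the op registration nor a Python-side autograd wrapper is shown in this excerpt.

    # Hypothetical smoke test for the CUDA RNNT loss kernels; the op name and
    # argument order marked as assumptions above are not guaranteed by these patches.
    import torch
    import torchaudio  # loads the _torchaudio extension and its registered ops

    assert torch.cuda.is_available(), "this sketch only exercises the GPU kernels"

    B, T, U, D = 2, 10, 5, 16  # batch, source frames, target length + 1, vocab size
    logits = torch.randn(B, T, U, D, device="cuda", dtype=torch.float32)
    targets = torch.randint(1, D, (B, U - 1), device="cuda", dtype=torch.int32)
    src_lengths = torch.full((B,), T, device="cuda", dtype=torch.int32)
    tgt_lengths = torch.full((B,), U - 1, device="cuda", dtype=torch.int32)

    # Assumed schema: rnnt_loss(logits, targets, src_lengths, tgt_lengths,
    # blank, clamp, ...) -> (costs, gradients or None). Forward pass only; the
    # wrapper that consumes the returned gradients is outside this excerpt.
    costs, gradients = torch.ops.torchaudio.rnnt_loss(
        logits, targets, src_lengths, tgt_lengths, 0, -1.0)
    print(costs.shape, costs)

Gating the example on torch.cuda.is_available() mirrors how _USE_CUDA is derived in the setup helper, so the same script fails fast on CPU-only builds instead of silently falling back to the CPU kernels.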