diff --git a/docs/ops.md b/docs/ops.md index 226cd935d698a..6fc976ac7b452 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -22,7 +22,7 @@ Legend: | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | +| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | @@ -42,7 +42,7 @@ Legend: | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | -| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | +| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | @@ -84,7 +84,7 @@ Legend: | ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | +| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -111,6 +111,6 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | -| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | +| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | | XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/CUDA.csv b/docs/ops/CUDA.csv index 71e47977e31d1..48d0e4ac664df 100644 --- a/docs/ops/CUDA.csv +++ b/docs/ops/CUDA.csv @@ -29,6 +29,14 @@ "CUDA0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" 
"CUDA0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" "CUDA0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" +"CUDA0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" +"CUDA0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" +"CUDA0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" +"CUDA0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" +"CUDA0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" +"CUDA0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" +"CUDA0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" +"CUDA0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" "CUDA0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" "CUDA0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" "CUDA0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" @@ -59,6 +67,14 @@ "CUDA0","EXP","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" "CUDA0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" "CUDA0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" +"CUDA0","FLOOR","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","FLOOR","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" +"CUDA0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" +"CUDA0","ROUND","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","ROUND","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" +"CUDA0","TRUNC","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","TRUNC","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" "CUDA0","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" "CUDA0","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" 
"CUDA0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" @@ -89,6 +105,14 @@ "CUDA0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" "CUDA0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" "CUDA0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" +"CUDA0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" +"CUDA0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" +"CUDA0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" +"CUDA0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" +"CUDA0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" +"CUDA0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" +"CUDA0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CUDA" +"CUDA0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CUDA" "CUDA0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" "CUDA0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" "CUDA0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" @@ -118,6 +142,14 @@ "CUDA0","EXP","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" "CUDA0","EXP","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" "CUDA0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","FLOOR","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","FLOOR","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" +"CUDA0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" +"CUDA0","ROUND","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","ROUND","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" +"CUDA0","TRUNC","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","CUDA" +"CUDA0","TRUNC","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" 
"CUDA0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","CUDA" "CUDA0","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=0","support","1","yes","CUDA" "CUDA0","REGLU","type=f16,ne_a=[5,7,11,13],v=0,swapped=0","support","1","yes","CUDA" diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d948b00cc7f30..3944ff84744e7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1104,6 +1104,38 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_floor( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_floor_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_ceil( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_ceil_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_round( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_round_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_trunc( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_trunc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // GELU using erf (error function) when possible // some backends may fallback to approximation based on Abramowitz and Stegun formula GGML_API struct ggml_tensor * ggml_gelu_erf( @@ -2154,6 +2186,18 @@ extern "C" { int p1, int p2, int p3); + GGML_API struct ggml_tensor * ggml_pad_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3 + ); GGML_API struct ggml_tensor * ggml_pad_ext( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75fd6db14c514..0d22935b462b0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ 
b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2345,6 +2345,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_UNARY_OP_XIELU:
             ggml_cuda_op_xielu(ctx, dst);
             break;
+        case GGML_UNARY_OP_FLOOR:
+            ggml_cuda_op_floor(ctx, dst);
+            break;
+        case GGML_UNARY_OP_CEIL:
+            ggml_cuda_op_ceil(ctx, dst);
+            break;
+        case GGML_UNARY_OP_ROUND:
+            ggml_cuda_op_round(ctx, dst);
+            break;
+        case GGML_UNARY_OP_TRUNC:
+            ggml_cuda_op_trunc(ctx, dst);
+            break;
         default:
             return false;
@@ -3357,6 +3369,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_FLOOR:
+                case GGML_UNARY_OP_CEIL:
+                case GGML_UNARY_OP_ROUND:
+                case GGML_UNARY_OP_TRUNC:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 3c564566a51ff..6b5fe1ff6cea2 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -88,6 +88,22 @@ static __device__ __forceinline__ float op_elu(float x) {
     return (x > 0.f) ?
 x : expm1f(x);
 }
 
+static __device__ __forceinline__ float op_floor(float x) {
+    return floorf(x);
+}
+
+static __device__ __forceinline__ float op_ceil(float x) {
+    return ceilf(x);
+}
+
+static __device__ __forceinline__ float op_round(float x) {
+    return roundf(x);
+}
+
+static __device__ __forceinline__ float op_trunc(float x) {
+    return truncf(x);
+}
+
 template <class T>
 static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -204,9 +220,26 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }
 
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_elu>(ctx, dst);
 }
+
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_floor>(ctx, dst);
+}
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_ceil>(ctx, dst);
+}
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_round>(ctx, dst);
+}
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_trunc>(ctx, dst);
+}
+
 /* gated ops */
 
 template <float (*op)(float), typename T>
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 8e7644fcd9a48..8f28f89ab7916 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -75,3 +75,11 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git 
a/ggml/src/ggml.c b/ggml/src/ggml.c
index 86f1c31afd7a6..fd778ff82887c 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6454,6 +6454,15 @@ static void ggml_compute_backward(
                     ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
                 }
             } break;
+        case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_CEIL:
+        case GGML_UNARY_OP_ROUND:
+        case GGML_UNARY_OP_TRUNC: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, ggml_new_f32(ctx, 0.0f), src0));
+            }
+        } break;
+
         default: {
             fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));