diff --git a/crates/spirv-std/macros/src/lib.rs b/crates/spirv-std/macros/src/lib.rs index d8ecf7b0cf..424c9ede5d 100644 --- a/crates/spirv-std/macros/src/lib.rs +++ b/crates/spirv-std/macros/src/lib.rs @@ -74,6 +74,7 @@ mod debug_printf; mod image; mod sample_param_permutations; +mod scalar_or_vector_composite; use crate::debug_printf::{DebugPrintfInput, debug_printf_inner}; use proc_macro::TokenStream; @@ -311,3 +312,10 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream { pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream { sample_param_permutations::gen_sample_param_permutations(item) } + +#[proc_macro_derive(ScalarOrVectorComposite)] +pub fn derive_scalar_or_vector_composite(item: TokenStream) -> TokenStream { + scalar_or_vector_composite::derive(item.into()) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/crates/spirv-std/macros/src/scalar_or_vector_composite.rs b/crates/spirv-std/macros/src/scalar_or_vector_composite.rs new file mode 100644 index 0000000000..f8b47c68fc --- /dev/null +++ b/crates/spirv-std/macros/src/scalar_or_vector_composite.rs @@ -0,0 +1,90 @@ +use proc_macro2::TokenStream; +use quote::{ToTokens, quote}; +use syn::punctuated::Punctuated; +use syn::{ + Data, DataStruct, DataUnion, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, GenericParam, + Token, +}; + +pub fn derive(item: TokenStream) -> syn::Result { + // Whenever we'll properly resolve the crate symbol, replace this. + let spirv_std = quote!(spirv_std); + + // Defer all validation to our codegen backend. Rather than erroring here, emit garbage. + let item = syn::parse2::(item)?; + let content = match &item.data { + Data::Enum(_) => derive_enum(&spirv_std, &item), + Data::Struct(data) => derive_struct(&spirv_std, data), + Data::Union(DataUnion { union_token, .. 
}) => { + Err(syn::Error::new_spanned(union_token, "Union not supported")) + } + }?; + + let ident = &item.ident; + let gens = &item.generics.params; + let gen_refs = &item + .generics + .params + .iter() + .map(|p| match p { + GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), + GenericParam::Type(p) => p.ident.to_token_stream(), + GenericParam::Const(p) => p.ident.to_token_stream(), + }) + .collect::>(); + let where_clause = &item.generics.where_clause; + + Ok(quote! { + impl<#gens> #spirv_std::ScalarOrVectorComposite for #ident<#gen_refs> #where_clause { + #[inline] + fn transform(self, f: &mut F) -> Self { + #content + } + } + }) +} + +pub fn derive_struct(spirv_std: &TokenStream, data: &DataStruct) -> syn::Result { + Ok(match &data.fields { + Fields::Named(FieldsNamed { named, .. }) => { + let content = named + .iter() + .map(|f| { + let ident = &f.ident; + quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f)) + }) + .collect::>(); + quote!(Self { #content }) + } + Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => { + let content = (0..unnamed.len()) + .map(|i| { + let i = syn::Index::from(i); + quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f)) + }) + .collect::>(); + quote!(Self(#content)) + } + Fields::Unit => quote!(Self), + }) +} + +pub fn derive_enum(spirv_std: &TokenStream, item: &DeriveInput) -> syn::Result { + let mut attributes = item.attrs.iter().filter(|a| a.path().is_ident("repr")); + let repr = match (attributes.next(), attributes.next()) { + (None, _) => Err(syn::Error::new_spanned( + item, + "Missing #[repr(...)] attribute", + )), + (Some(repr), None) => Ok(repr), + (Some(_), Some(_)) => Err(syn::Error::new_spanned( + item, + "Multiple #[repr(...)] attributes found", + )), + }?; + let prim = &repr.meta.require_list()?.tokens; + Ok(quote! 
{ + #spirv_std::assert_is_integer::<#prim>(); + >::from(#spirv_std::ScalarOrVectorComposite::transform(>::into(self), f)) + }) +} diff --git a/crates/spirv-std/src/arch/subgroup.rs b/crates/spirv-std/src/arch/subgroup.rs index a9690d4190..5985a084c3 100644 --- a/crates/spirv-std/src/arch/subgroup.rs +++ b/crates/spirv-std/src/arch/subgroup.rs @@ -1,11 +1,12 @@ -use crate::ScalarOrVector; #[cfg(target_arch = "spirv")] -use crate::arch::barrier; +use crate::ScalarOrVectorTransform; #[cfg(target_arch = "spirv")] -use crate::memory::{Scope, Semantics}; -use crate::{Float, Integer, SignedInteger, UnsignedInteger}; +use crate::arch::{asm, barrier}; #[cfg(target_arch = "spirv")] -use core::arch::asm; +use crate::memory::{Scope, Semantics}; +use crate::{ + Float, Integer, ScalarOrVector, ScalarOrVectorComposite, SignedInteger, UnsignedInteger, +}; #[cfg(target_arch = "spirv")] const SUBGROUP: u32 = Scope::Subgroup as u32; @@ -243,24 +244,35 @@ pub fn subgroup_any(predicate: bool) -> bool { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformAllEqual")] #[inline] -pub fn subgroup_all_equal(value: T) -> bool { - let mut result = false; +pub fn subgroup_all_equal(value: T) -> bool { + struct Transform(bool); - unsafe { - asm! { - "%bool = OpTypeBool", - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%result = OpGroupNonUniformAllEqual %bool %subgroup %value", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = false; + unsafe { + asm! 
{ + "%bool = OpTypeBool", + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%result = OpGroupNonUniformAllEqual %bool %subgroup %value", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + result = in(reg) &mut result, + } + } + self.0 &= result; + value } } - result + let mut transform = Transform(true); + // ignore returned value + value.transform(&mut transform); + transform.0 } /// Result is the `value` of the invocation identified by the id `id` to all active invocations in the group. @@ -287,25 +299,34 @@ pub fn subgroup_all_equal(value: T) -> bool { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformBroadcast")] #[inline] -pub unsafe fn subgroup_broadcast(value: T, id: u32) -> T { - let mut result = T::default(); +pub unsafe fn subgroup_broadcast(value: T, id: u32) -> T { + struct Transform { + id: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%id = OpLoad _ {id}", - "%result = OpGroupNonUniformBroadcast _ %subgroup %value %id", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - id = in(reg) &id, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%id = OpLoad _ {id}", + "%result = OpGroupNonUniformBroadcast _ %subgroup %value %id", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + id = in(reg) &self.id, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { id }) } /// Result is the `value` of the invocation identified by the id `id` to all active invocations in the group. 
@@ -330,24 +351,31 @@ pub unsafe fn subgroup_broadcast(value: T, id: u32) -> T { #[doc(alias = "OpGroupNonUniformBroadcast")] #[inline] pub unsafe fn subgroup_broadcast_const(value: T) -> T { - let mut result = T::default(); + struct Transform; - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%id = OpConstant %u32 {id}", - "%value = OpLoad _ {value}", - "%result = OpGroupNonUniformBroadcast _ %subgroup %value %id", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - id = const ID, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%id = OpConstant %u32 {id}", + "%value = OpLoad _ {value}", + "%result = OpGroupNonUniformBroadcast _ %subgroup %value %id", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + id = const ID, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform::) } /// Result is the `value` of the invocation from the active invocation with the lowest id in the group to all active invocations in the group. @@ -362,23 +390,30 @@ pub unsafe fn subgroup_broadcast_const(value: #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformBroadcastFirst")] #[inline] -pub fn subgroup_broadcast_first(value: T) -> T { - let mut result = T::default(); +pub fn subgroup_broadcast_first(value: T) -> T { + struct Transform; - unsafe { - asm! 
{ - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%result = OpGroupNonUniformBroadcastFirst _ %subgroup %value", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%result = OpGroupNonUniformBroadcastFirst _ %subgroup %value", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform) } /// Result is a bitfield value combining the `predicate` value from all invocations in the group that execute the same dynamic instance of this instruction. The bit is set to one if the corresponding invocation is active and the `predicate` for that invocation evaluated to true; otherwise, it is set to zero. @@ -637,25 +672,34 @@ pub fn subgroup_ballot_find_msb(value: SubgroupMask) -> u32 { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffle")] #[inline] -pub fn subgroup_shuffle(value: T, id: u32) -> T { - let mut result = T::default(); +pub fn subgroup_shuffle(value: T, id: u32) -> T { + struct Transform { + id: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%id = OpLoad _ {id}", - "%result = OpGroupNonUniformShuffle _ %subgroup %value %id", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - id = in(reg) &id, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! 
{ + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%id = OpLoad _ {id}", + "%result = OpGroupNonUniformShuffle _ %subgroup %value %id", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + id = in(reg) &self.id, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { id }) } /// Result is the `value` of the invocation identified by the current invocation’s id within the group xor’ed with Mask. @@ -678,25 +722,34 @@ pub fn subgroup_shuffle(value: T, id: u32) -> T { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleXor")] #[inline] -pub fn subgroup_shuffle_xor(value: T, mask: u32) -> T { - let mut result = T::default(); +pub fn subgroup_shuffle_xor(value: T, mask: u32) -> T { + struct Transform { + mask: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%mask = OpLoad _ {mask}", - "%result = OpGroupNonUniformShuffleXor _ %subgroup %value %mask", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - mask = in(reg) &mask, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%mask = OpLoad _ {mask}", + "%result = OpGroupNonUniformShuffleXor _ %subgroup %value %mask", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + mask = in(reg) &self.mask, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { mask }) } /// Result is the `value` of the invocation identified by the current invocation’s id within the group - Delta. 
@@ -719,25 +772,34 @@ pub fn subgroup_shuffle_xor(value: T, mask: u32) -> T { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleUp")] #[inline] -pub fn subgroup_shuffle_up(value: T, delta: u32) -> T { - let mut result = T::default(); +pub fn subgroup_shuffle_up(value: T, delta: u32) -> T { + struct Transform { + delta: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%delta = OpLoad _ {delta}", - "%result = OpGroupNonUniformShuffleUp _ %subgroup %value %delta", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - delta = in(reg) &delta, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%delta = OpLoad _ {delta}", + "%result = OpGroupNonUniformShuffleUp _ %subgroup %value %delta", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + delta = in(reg) &self.delta, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { delta }) } /// Result is the `value` of the invocation identified by the current invocation’s id within the group + Delta. @@ -760,25 +822,34 @@ pub fn subgroup_shuffle_up(value: T, delta: u32) -> T { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformShuffleDown")] #[inline] -pub fn subgroup_shuffle_down(value: T, delta: u32) -> T { - let mut result = T::default(); +pub fn subgroup_shuffle_down(value: T, delta: u32) -> T { + struct Transform { + delta: u32, + } - unsafe { - asm! 
{ - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%delta = OpLoad _ {delta}", - "%result = OpGroupNonUniformShuffleDown _ %subgroup %value %delta", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - delta = in(reg) &delta, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%delta = OpLoad _ {delta}", + "%result = OpGroupNonUniformShuffleDown _ %subgroup %value %delta", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + delta = in(reg) &self.delta, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { delta }) } macro_rules! macro_subgroup_op { @@ -1387,25 +1458,34 @@ Requires Capability `GroupNonUniformArithmetic` and `GroupNonUniformClustered`. #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformQuadBroadcast")] #[inline] -pub fn subgroup_quad_broadcast(value: T, index: u32) -> T { - let mut result = T::default(); +pub fn subgroup_quad_broadcast(value: T, index: u32) -> T { + struct Transform { + index: u32, + } - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%value = OpLoad _ {value}", - "%index = OpLoad _ {index}", - "%result = OpGroupNonUniformQuadBroadcast _ %subgroup %value %index", - "OpStore {result} %result", - subgroup = const SUBGROUP, - value = in(reg) &value, - index = in(reg) &index, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! 
{ + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%value = OpLoad _ {value}", + "%index = OpLoad _ {index}", + "%result = OpGroupNonUniformQuadBroadcast _ %subgroup %value %index", + "OpStore {result} %result", + subgroup = const SUBGROUP, + value = in(reg) &value, + index = in(reg) &self.index, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform { index }) } /// Direction is the kind of swap to perform. @@ -1470,23 +1550,30 @@ pub enum QuadDirection { #[spirv_std_macros::gpu_only] #[doc(alias = "OpGroupNonUniformQuadSwap")] #[inline] -pub fn subgroup_quad_swap(value: T) -> T { - let mut result = T::default(); +pub fn subgroup_quad_swap(value: T) -> T { + struct Transform; - unsafe { - asm! { - "%u32 = OpTypeInt 32 0", - "%subgroup = OpConstant %u32 {subgroup}", - "%direction = OpConstant %u32 {direction}", - "%value = OpLoad _ {value}", - "%result = OpGroupNonUniformQuadSwap _ %subgroup %value %direction", - "OpStore {result} %result", - subgroup = const SUBGROUP, - direction = const DIRECTION, - value = in(reg) &value, - result = in(reg) &mut result, + impl ScalarOrVectorTransform for Transform { + #[inline] + fn transform(&mut self, value: T) -> T { + let mut result = T::default(); + unsafe { + asm! { + "%u32 = OpTypeInt 32 0", + "%subgroup = OpConstant %u32 {subgroup}", + "%direction = OpConstant %u32 {direction}", + "%value = OpLoad _ {value}", + "%result = OpGroupNonUniformQuadSwap _ %subgroup %value %direction", + "OpStore {result} %result", + subgroup = const SUBGROUP, + direction = const DIRECTION, + value = in(reg) &value, + result = in(reg) &mut result, + } + } + result } } - result + value.transform(&mut Transform::) } diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 2c85dc9af0..2fdcd5610d 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -87,6 +87,7 @@ /// Public re-export of the `spirv-std-macros` crate. 
#[macro_use] pub extern crate spirv_std_macros as macros; +pub use macros::ScalarOrVectorComposite; pub use macros::spirv; pub use macros::{debug_printf, debug_printfln}; diff --git a/crates/spirv-std/src/scalar.rs b/crates/spirv-std/src/scalar.rs index b4774ea861..52d4b365f2 100644 --- a/crates/spirv-std/src/scalar.rs +++ b/crates/spirv-std/src/scalar.rs @@ -1,7 +1,7 @@ //! Traits related to scalars. -use crate::ScalarOrVector; use crate::sealed::Sealed; +use crate::{ScalarOrVector, ScalarOrVectorComposite, ScalarOrVectorTransform}; use core::num::NonZeroUsize; /// Abstract trait representing a SPIR-V scalar type, which includes: @@ -61,7 +61,13 @@ pub unsafe trait Float: num_traits::Float + Number { macro_rules! impl_scalar { (impl Scalar for $ty:ty;) => { impl Sealed for $ty {} - unsafe impl ScalarOrVector for $ty { + impl ScalarOrVectorComposite for $ty { + #[inline] + fn transform(self, f: &mut F) -> Self { + f.transform_scalar(self) + } + } + unsafe impl ScalarOrVector for $ty { type Scalar = Self; const N: NonZeroUsize = NonZeroUsize::new(1).unwrap(); } @@ -111,3 +117,7 @@ impl_scalar! { impl Float for f64; impl Scalar for bool; } + +/// Used by the `ScalarOrVectorComposite` derive when working with enums. +#[inline] +pub fn assert_is_integer() {} diff --git a/crates/spirv-std/src/scalar_or_vector.rs b/crates/spirv-std/src/scalar_or_vector.rs index 87b0073241..89a76f4b54 100644 --- a/crates/spirv-std/src/scalar_or_vector.rs +++ b/crates/spirv-std/src/scalar_or_vector.rs @@ -1,4 +1,4 @@ -use crate::Scalar; +use crate::{Scalar, Vector}; use core::num::NonZeroUsize; pub(crate) mod sealed { @@ -10,13 +10,78 @@ pub(crate) mod sealed { /// Abstract trait representing either a [`Scalar`] or [`Vector`] type. /// /// # Safety -/// Your type must also implement [`Scalar`] or [`Vector`], see their safety sections as well.
-/// -/// [`Vector`]: crate::Vector -pub unsafe trait ScalarOrVector: Copy + Default + Send + Sync + 'static { +/// Implementing this trait on non-scalar or non-vector types may break assumptions about other +/// unsafe code, and should not be done. +pub unsafe trait ScalarOrVector: ScalarOrVectorComposite + Default { /// Either the scalar component type of the vector or the scalar itself. type Scalar: Scalar; /// The dimension of the vector, or 1 if it is a scalar const N: NonZeroUsize; } + +/// A `ScalarOrVectorComposite` is a type that is either +/// * a [`Scalar`] +/// * a [`Vector`] +/// * an array of `ScalarOrVectorComposite` +/// * a struct where all members are `ScalarOrVectorComposite` +/// * an enum with a `repr` that is a [`Scalar`] +/// +/// By calling [`Self::transform`] you can visit all the individual [`Scalar`] and [`Vector`] values this composite is +/// built out of and transform them into some other value. This is particularly useful for subgroup intrinsics sending +/// data to other threads. +/// +/// To derive `#[derive(ScalarOrVectorComposite)]` on a struct, all members must also implement +/// `ScalarOrVectorComposite`. +/// +/// To derive it on an enum, the enum must implement `From<N>` and `Into<N>` where `N` is defined by the `#[repr(N)]` +/// attribute on the enum and is an [`Integer`], like `u32`. +/// Note that some safe [subgroup operations] may return an "undefined result", so your `From<N>` must gracefully handle +/// arbitrary bit patterns being passed to it. While panicking is legal, it is discouraged as it may result in +/// unexpected control flow. +/// To implement these conversion traits, we recommend [`FromPrimitive`] and [`IntoPrimitive`] from the [`num_enum`] +/// crate. [`FromPrimitive`] requires that either the enum is exhaustive, or you provide it with a variant to default +/// to, by either implementing [`Default`] or marking a variant with `#[num_enum(default)]`.
Note to disable default +/// features on the [`num_enum`] crate, or it won't compile on SPIR-V. +/// +/// [`Integer`]: crate::Integer +/// [subgroup operations]: crate::arch::subgroup_shuffle +/// [`FromPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.FromPrimitive.html +/// [`IntoPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.IntoPrimitive.html +/// [`num_enum`]: https://crates.io/crates/num_enum +pub trait ScalarOrVectorComposite: Copy + Send + Sync + 'static { + /// Transform the individual [`Scalar`] and [`Vector`] values of this type to a different value. + /// + /// See [`Self`] for more detail. + fn transform(self, f: &mut F) -> Self; +} + +/// A transform operation for [`ScalarOrVectorComposite::transform`] +pub trait ScalarOrVectorTransform { + /// transform a [`ScalarOrVector`] + fn transform(&mut self, value: T) -> T; + + /// transform a [`Scalar`], defaults to [`self.transform`] + #[inline] + fn transform_scalar(&mut self, value: T) -> T { + self.transform(value) + } + + /// transform a [`Vector`], defaults to [`self.transform`] + #[inline] + fn transform_vector, S: Scalar, const N: usize>(&mut self, value: V) -> V { + self.transform(value) + } +} + +/// `Default` is unfortunately necessary until rust-gpu improves +impl ScalarOrVectorComposite for [T; N] { + #[inline] + fn transform(self, f: &mut F) -> Self { + let mut out = [T::default(); N]; + for i in 0..N { + out[i] = self[i].transform(f); + } + out + } +} diff --git a/crates/spirv-std/src/vector.rs b/crates/spirv-std/src/vector.rs index 0389424df0..c5464adfc2 100644 --- a/crates/spirv-std/src/vector.rs +++ b/crates/spirv-std/src/vector.rs @@ -1,7 +1,7 @@ //! Traits related to vectors. use crate::sealed::Sealed; -use crate::{Scalar, ScalarOrVector}; +use crate::{Scalar, ScalarOrVector, ScalarOrVectorComposite, ScalarOrVectorTransform}; use core::num::NonZeroUsize; use glam::{Vec3Swizzles, Vec4Swizzles}; @@ -57,6 +57,12 @@ macro_rules! 
impl_vector { ($($ty:ty: [$scalar:ty; $n:literal];)+) => { $( impl Sealed for $ty {} + impl ScalarOrVectorComposite for $ty { + #[inline] + fn transform(self, f: &mut F) -> Self { + f.transform_vector(self) + } + } unsafe impl ScalarOrVector for $ty { type Scalar = $scalar; const N: NonZeroUsize = NonZeroUsize::new($n).unwrap(); diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr index c292b79934..bc6d3a980f 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr @@ -1,5 +1,5 @@ error[E0080]: evaluation panicked: `ClusterSize` must be at least 1 - --> $SPIRV_STD_SRC/arch/subgroup.rs:868:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. @@ -13,7 +13,7 @@ LL | | "); = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info) note: erroneous constant encountered - --> $SPIRV_STD_SRC/arch/subgroup.rs:868:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. 
diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr index 61d066c3fc..e254fb228b 100644 --- a/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr +++ b/tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr @@ -1,5 +1,5 @@ error[E0080]: evaluation panicked: `ClusterSize` must be a power of 2 - --> $SPIRV_STD_SRC/arch/subgroup.rs:868:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. @@ -13,7 +13,7 @@ LL | | "); = note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info) note: erroneous constant encountered - --> $SPIRV_STD_SRC/arch/subgroup.rs:868:1 + --> $SPIRV_STD_SRC/arch/subgroup.rs:939:1 | LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r" LL | | An integer add group operation of all `value` operands contributed by active invocations in the group. 
diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs new file mode 100644 index 0000000000..ca4829f521 --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite.rs @@ -0,0 +1,54 @@ +// build-pass +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+GroupNonUniformShuffle,+GroupNonUniformShuffleRelative,+ext:SPV_KHR_vulkan_memory_model +// normalize-stderr-test "OpLine .*\n" -> "" +// ignore-vulkan1.0 +// ignore-vulkan1.1 +// ignore-spv1.0 +// ignore-spv1.1 +// ignore-spv1.2 +// ignore-spv1.3 +// ignore-spv1.4 + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct MyStruct { + a: f32, + b: UVec3, + c: Nested, + d: Zst, +} + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Nested(i32); + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Zst; + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut UVec3, +) { + unsafe { + let my_struct = MyStruct { + a: 1., + b: inv_id, + c: Nested(-42), + d: Zst, + }; + + let mut out = UVec3::ZERO; + // before spv1.5 / vulkan1.2, this id = 19 must be a constant + out += subgroup_broadcast(my_struct, 19).b; + out += subgroup_broadcast_first(my_struct).b; + out += subgroup_shuffle(my_struct, 2).b; + out += subgroup_shuffle_xor(my_struct, 4).b; + out += subgroup_shuffle_up(my_struct, 5).b; + out += subgroup_shuffle_down(my_struct, 7).b; + *output = out; + } +} diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs new file mode 100644 index 0000000000..2c1c12f9aa --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.rs @@ -0,0 +1,46 @@ +// 
build-pass +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformVote,+ext:SPV_KHR_vulkan_memory_model +// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_all_equals::disassembly +// normalize-stderr-test "OpLine .*\n" -> "" + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct MyStruct { + a: f32, + b: UVec3, + c: Nested, + d: Zst, +} + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Nested(i32); + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Zst; + +/// this should be 3 `subgroup_all_equal` instructions, with all calls inlined +fn disassembly(my_struct: MyStruct) -> bool { + subgroup_all_equal(my_struct) +} + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut u32, +) { + unsafe { + let my_struct = MyStruct { + a: inv_id.x as f32, + b: inv_id, + c: Nested(5i32 - inv_id.x as i32), + d: Zst, + }; + + let bool = disassembly(my_struct); + *output = u32::from(bool); + } +} diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.stderr new file mode 100644 index 0000000000..d0167e9bed --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_all_equals.stderr @@ -0,0 +1,15 @@ +%1 = OpFunction %2 None %3 +%4 = OpFunctionParameter %5 +%6 = OpLabel +%8 = OpCompositeExtract %9 %4 0 +%11 = OpGroupNonUniformAllEqual %2 %12 %8 +%13 = OpLogicalAnd %2 %14 %11 +%15 = OpCompositeExtract %16 %4 1 +%17 = OpGroupNonUniformAllEqual %2 %12 %15 +%18 = OpLogicalAnd %2 %13 %17 +%19 = OpCompositeExtract %20 %4 2 +%21 = OpGroupNonUniformAllEqual %2 %12 %19 +%22 = OpLogicalAnd %2 %18 %21 +OpNoLine +OpReturnValue %22 +OpFunctionEnd diff --git 
a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs new file mode 100644 index 0000000000..ee4d039e72 --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.rs @@ -0,0 +1,53 @@ +// build-pass +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffle,+ext:SPV_KHR_vulkan_memory_model +// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_enum::disassembly +// normalize-stderr-test "OpLine .*\n" -> "" + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +#[repr(u32)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum MyEnum { + #[default] + A, + B, + C, +} + +impl From for MyEnum { + #[inline] + fn from(value: u32) -> Self { + match value { + 0 => Self::A, + 1 => Self::B, + 2 => Self::C, + _ => Self::default(), + } + } +} + +impl From for u32 { + #[inline] + fn from(value: MyEnum) -> Self { + value as u32 + } +} + +/// this should be a single `subgroup_shuffle` instruction, with all calls inlined +fn disassembly(my_struct: MyEnum, id: u32) -> MyEnum { + subgroup_shuffle(my_struct, id) +} + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(local_invocation_index)] inv_id: UVec3, + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyEnum, +) { + unsafe { + let my_enum = MyEnum::from(inv_id.x % 3); + *output = disassembly(my_enum, 5); + } +} diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.stderr new file mode 100644 index 0000000000..091689e0fc --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum.stderr @@ -0,0 +1,20 @@ +%1 = OpFunction %2 None %3 +%4 = OpFunctionParameter %2 +%5 = OpFunctionParameter %2 +%6 = OpLabel +%8 = OpGroupNonUniformShuffle %2 %9 %4 %5 +OpNoLine +OpSelectionMerge %10 None +OpSwitch %8 %11 0 %12 1 %13 2
%14 +%11 = OpLabel +OpBranch %10 +%12 = OpLabel +OpBranch %10 +%13 = OpLabel +OpBranch %10 +%14 = OpLabel +OpBranch %10 +%10 = OpLabel +%15 = OpPhi %2 %16 %11 %16 %12 %17 %13 %18 %14 +OpReturnValue %15 +OpFunctionEnd diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs new file mode 100644 index 0000000000..2706891ec2 --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.rs @@ -0,0 +1,89 @@ +// build-fail +// normalize-stderr-test "\S*/crates/spirv-std/src/" -> "$$SPIRV_STD_SRC/" +// normalize-stderr-test "\.rs:\d+:\d+" -> ".rs:" +// normalize-stderr-test "(\n)\d* *([ -])([\|\+\-\=])" -> "$1 $2$3" + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; +// Generates `From<$repr> for $ident` (values other than 0, 1, 2 fall back to `Self::default()`) and `From<$ident> for $repr`. +macro_rules! enum_repr_from { + ($ident:ident, $repr:ty) => { + impl From<$repr> for $ident { + #[inline] + fn from(value: $repr) -> Self { + match value { + 0 => Self::A, + 1 => Self::B, + 2 => Self::C, + _ => Self::default(), + } + } + } + + impl From<$ident> for $repr { + #[inline] + fn from(value: $ident) -> Self { + value as $repr + } + } + }; +} +// Error case: no `#[repr(...)]` at all; the derive is expected to report "Missing #[repr(...)] attribute". +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum NoRepr { + #[default] + A, + B, + C, +} +// Error case: two `#[repr(...)]` attributes; the derive is expected to report "Multiple #[repr(...)] attributes found" (rustc adds E0566). +#[repr(u32)] +#[repr(u16)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum TwoRepr { + #[default] + A, + B, + C, +} +// Error case: `repr(C)` does not name an integer type; expected to fail with E0412 (cannot find type `C`). +#[repr(C)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum CRepr { + #[default] + A, + B, + C, +} +// Error case: `i32` repr but no `From` conversions between the enum and `i32`; expected to fail with E0277. +#[repr(i32)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum NoFrom { + #[default] + A, + B, + C, +} +// Error case: `i32` repr, but the conversions generated below target `u32`; expected to fail with E0277. +#[repr(i32)] +#[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +pub enum WrongFrom { + #[default] + A, + B, + C, +} + +enum_repr_from!(WrongFrom, u32); +// Error case: conversions exist, but without `Default` the macro's `Self::default()` fallback cannot resolve; expected to fail with E0599. +#[repr(i32)] +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub enum NoDefault { + A, + B, + C, +} + +enum_repr_from!(NoDefault, i32); diff
--git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr new file mode 100644 index 0000000000..8665751b2a --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_enum_err.stderr @@ -0,0 +1,136 @@ +error: Missing #[repr(...)] attribute + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | / pub enum NoRepr { +LL | | #[default] +LL | | A, +LL | | B, +LL | | C, +LL | | } + | |_^ + +error: Multiple #[repr(...)] attributes found + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | / #[repr(u32)] +LL | | #[repr(u16)] +LL | | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] +LL | | pub enum TwoRepr { +... | +LL | | C, +LL | | } + | |_^ + +error[E0412]: cannot find type `C` in this scope + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[repr(C)] + | ^ +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ----------------------- similarly named type parameter `F` defined here + | +help: there is an enum variant `crate::CRepr::C` and 6 others; try using the variant's enum + | +LL - #[repr(C)] +LL + #[repr(crate::CRepr)] + | +LL - #[repr(C)] +LL + #[repr(crate::NoDefault)] + | +LL - #[repr(C)] +LL + #[repr(crate::NoFrom)] + | +LL - #[repr(C)] +LL + #[repr(crate::NoRepr)] + | + and 2 other candidates +help: a type parameter with a similar name exists + | +LL - #[repr(C)] +LL + #[repr(F)] + | + +error[E0566]: conflicting representation hints + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[repr(u32)] + | ^^^ +LL | #[repr(u16)] + | ^^^ + | + = warning: this was previously accepted by the compiler but is being phased out; it will become a hard error in a future release! 
+ = note: for more information, see issue #68585 + = note: `#[deny(conflicting_repr_hints)]` on by default + +error[E0277]: the trait bound `NoFrom: From` is not satisfied + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `NoFrom` + | + = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `i32: From` is not satisfied + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `i32` + | + = help: the following other types implement trait `From`: + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + = note: required for `NoFrom` to implement `Into` + = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `WrongFrom: From` is not satisfied + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `WrongFrom` + | + = help: the trait `From` is not implemented for `WrongFrom` + but trait `From` is implemented for it + = help: for that trait implementation, expected `u32`, found `i32` + = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `i32: From` is not satisfied + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | #[derive(Copy, Clone, Default, ScalarOrVectorComposite)] + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `From` is not implemented for `i32` + | + = help: 
the following other types implement trait `From`: + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + `i32` implements `From` + = note: required for `WrongFrom` to implement `Into` + = note: this error originates in the derive macro `ScalarOrVectorComposite` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0599]: no variant or associated item named `default` found for enum `NoDefault` in the current scope + --> $DIR/subgroup_composite_enum_err.rs: + | +LL | _ => Self::default(), + | ^^^^^^^ variant or associated item not found in `NoDefault` +... +LL | pub enum NoDefault { + | ------------------ variant or associated item `default` not found for this enum +... +LL | enum_repr_from!(NoDefault, i32); + | ------------------------------- in this macro invocation + | + = help: items from traits can only be used if the trait is implemented and in scope + = note: the following trait defines an item `default`, perhaps you need to implement it: + candidate #1: `Default` + = note: this error originates in the macro `enum_repr_from` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: aborting due to 9 previous errors + +Some errors have detailed explanations: E0277, E0412, E0566, E0599. +For more information about an error, try `rustc --explain E0277`. 
diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs new file mode 100644 index 0000000000..1009fb74ca --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.rs @@ -0,0 +1,45 @@ +// build-pass +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffle,+ext:SPV_KHR_vulkan_memory_model +// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_shuffle::disassembly +// normalize-stderr-test "OpLine .*\n" -> "" + +use glam::*; +use spirv_std::ScalarOrVectorComposite; +use spirv_std::arch::*; +use spirv_std::spirv; + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct MyStruct { + a: f32, + b: UVec3, + c: Nested, + d: Zst, +} + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Nested(i32); + +#[derive(Copy, Clone, ScalarOrVectorComposite)] +pub struct Zst; + +/// Expected disassembly: 3 `OpGroupNonUniformShuffle` instructions (fields `a`, `b`, `c`; the unit struct `d` contributes none), with all calls inlined. +fn disassembly(my_struct: MyStruct, id: u32) -> MyStruct { + subgroup_shuffle(my_struct, id) +} +// Entry point: builds a per-invocation `MyStruct` and stores the result of `disassembly` in the storage buffer. +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(local_invocation_id)] inv_id: UVec3, + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyStruct, +) { + unsafe { + let my_struct = MyStruct { + a: inv_id.x as f32, + b: inv_id, + c: Nested(5i32 - inv_id.x as i32), + d: Zst, + }; + + *output = disassembly(my_struct, 5); + } +} diff --git a/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.stderr b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.stderr new file mode 100644 index 0000000000..0127324087 --- /dev/null +++ b/tests/compiletests/ui/arch/subgroup/subgroup_composite_shuffle.stderr @@ -0,0 +1,16 @@ +%1 = OpFunction %2 None %3 +%4 = OpFunctionParameter %2 +%5 = OpFunctionParameter %6 +%7 = OpLabel +%9 = OpCompositeExtract %10 %4 0 +%12 = OpGroupNonUniformShuffle %10 %13 %9 %5 +%14 = OpCompositeExtract %15 %4 1
+%16 = OpGroupNonUniformShuffle %15 %13 %14 %5 +%17 = OpCompositeExtract %18 %4 2 +%19 = OpGroupNonUniformShuffle %18 %13 %17 %5 +%20 = OpCompositeInsert %2 %12 %21 0 +%22 = OpCompositeInsert %2 %16 %20 1 +%23 = OpCompositeInsert %2 %19 %22 2 +OpNoLine +OpReturnValue %23 +OpFunctionEnd