From 291f80f3e79c3cbdedf2cee7b655dee51acf1f55 Mon Sep 17 00:00:00 2001 From: YasInvolved Date: Thu, 5 Jun 2025 17:33:12 +0200 Subject: [PATCH 1/2] add counting sort example --- examples/hlsl/counting_sort.hlsl | 42 ++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 examples/hlsl/counting_sort.hlsl diff --git a/examples/hlsl/counting_sort.hlsl b/examples/hlsl/counting_sort.hlsl new file mode 100644 index 00000000000..abef6492bf8 --- /dev/null +++ b/examples/hlsl/counting_sort.hlsl @@ -0,0 +1,42 @@ +#include "nbl/builtin/hlsl/sort/counting.hlsl" + +[[vk::push_constant]] CountingPushData pushData; + +using DoublePtrAccessor = DoubleBdaAccessor; + +[numthreads(WorkgroupSize, 1, 1)] +void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID) +{ + sort::CountingParameters params; + params.dataElementCount = pushData.dataElementCount; + params.elementsPerWT = pushData.elementsPerWT; + params.minimum = pushData.minimum; + params.maximum = pushData.maximum; + + using Counter = sort::counting; + Counter counter = Counter::create(glsl::gl_WorkGroupID().x); + + const Ptr input_key_ptr = Ptr::create(pushData.inputKeyAddress); + const Ptr input_value_ptr = Ptr::create(pushData.inputValueAddress); + const Ptr histogram_ptr = Ptr::create(pushData.histogramAddress); + const Ptr output_key_ptr = Ptr::create(pushData.outputKeyAddress); + const Ptr output_value_ptr = Ptr::create(pushData.outputValueAddress); + + DoublePtrAccessor key_accessor = DoublePtrAccessor::create( + input_key_ptr, + output_key_ptr + ); + DoublePtrAccessor value_accessor = DoublePtrAccessor::create( + input_value_ptr, + output_value_ptr + ); + PtrAccessor histogram_accessor = PtrAccessor::create(histogram_ptr); + SharedAccessor shared_accessor; + counter.scatter( + key_accessor, + value_accessor, + histogram_accessor, + shared_accessor, + params + ); +} \ No newline at end of file From 122bf26757706b66c114ea02a1994d860cc5501e Mon Sep 17 00:00:00 2001 From: YasInvolved Date: Thu, 5 Jun 2025 18:43:36 +0200 Subject: [PATCH 2/2] manually include common.hlsl from example and add dummy data to make it compile --- examples/hlsl/counting_sort.hlsl | 82 +++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 18 deletions(-) diff --git a/examples/hlsl/counting_sort.hlsl b/examples/hlsl/counting_sort.hlsl index abef6492bf8..3911549ea6a 100644 --- a/examples/hlsl/counting_sort.hlsl +++ b/examples/hlsl/counting_sort.hlsl @@ -1,40 +1,86 @@ +// The entry point and target profile are needed to compile this example: +// -T ps_6_6 -E PSMain + #include "nbl/builtin/hlsl/sort/counting.hlsl" +#include "nbl/builtin/hlsl/bda/bda_accessor.hlsl" + +#define BucketCount 27 +#define WorkgroupSize 27 + +struct CountingPushData +{ + uint64_t inputKeyAddress; + uint64_t inputValueAddress; + uint64_t histogramAddress; + uint64_t outputKeyAddress; + uint64_t outputValueAddress; + uint32_t dataElementCount; + uint32_t elementsPerWT; + uint32_t minimum; + uint32_t maximum; +}; + +using namespace nbl::hlsl; + +using Ptr = bda::__ptr; +using PtrAccessor = BdaAccessor; + +groupshared uint32_t sdata[BucketCount]; + +struct SharedAccessor +{ + void get(const uint32_t index, NBL_REF_ARG(uint32_t) value) + { + value = sdata[index]; + } + + void set(const uint32_t index, const uint32_t value) + { + sdata[index] = value; + } + + uint32_t atomicAdd(const uint32_t index, const uint32_t value) + { + return glsl::atomicAdd(sdata[index], value); + } + + void workgroupExecutionAndMemoryBarrier() + { + glsl::barrier(); + } +}; + +uint32_t3 glsl::gl_WorkGroupSize() +{ + return uint32_t3(WorkgroupSize, 1, 1); +} [[vk::push_constant]] CountingPushData pushData; using DoublePtrAccessor = DoubleBdaAccessor; -[numthreads(WorkgroupSize, 1, 1)] +[[vk::push_constant]] CountingPushData pushData; + +[numthreads(WorkgroupSize,1,1)] void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID) { - sort::CountingParameters params; + sort::CountingParameters < uint32_t > params; params.dataElementCount = pushData.dataElementCount; params.elementsPerWT = pushData.elementsPerWT; params.minimum = pushData.minimum; params.maximum = pushData.maximum; - using Counter = sort::counting; + using Counter = sort::counting; Counter counter = Counter::create(glsl::gl_WorkGroupID().x); - const Ptr input_key_ptr = Ptr::create(pushData.inputKeyAddress); - const Ptr input_value_ptr = Ptr::create(pushData.inputValueAddress); + const Ptr input_ptr = Ptr::create(pushData.inputKeyAddress); const Ptr histogram_ptr = Ptr::create(pushData.histogramAddress); - const Ptr output_key_ptr = Ptr::create(pushData.outputKeyAddress); - const Ptr output_value_ptr = Ptr::create(pushData.outputValueAddress); - DoublePtrAccessor key_accessor = DoublePtrAccessor::create( - input_key_ptr, - output_key_ptr - ); - DoublePtrAccessor value_accessor = DoublePtrAccessor::create( - input_value_ptr, - output_value_ptr - ); + PtrAccessor input_accessor = PtrAccessor::create(input_ptr); PtrAccessor histogram_accessor = PtrAccessor::create(histogram_ptr); SharedAccessor shared_accessor; - counter.scatter( - key_accessor, - value_accessor, + counter.histogram( + input_accessor, histogram_accessor, shared_accessor, params