2323
2424class MyInt32Const ;
2525class MyFloatConst ;
26+ class MyConst ;
2627
2728using namespace sycl ;
2829
2930class KernelAAAi ;
3031class KernelBBBf ;
3132
32- int val = 10 ;
33+ int global_val = 10 ;
3334
3435// Fetch a value at runtime.
35- int get_value () { return val ; }
36+ int get_value () { return global_val ; }
3637
3738float foo (
3839 const cl::sycl::ONEAPI::experimental::spec_constant<float , MyFloatConst>
@@ -49,8 +50,22 @@ struct SCWrapper {
4950 cl::sycl::ONEAPI::experimental::spec_constant<int , class sc_name2 > SC2;
5051};
5152
53+ // MyKernel is used to test default constructor
54+ using AccT = sycl::accessor<int , 1 , sycl::access::mode::write>;
55+ using ScT = sycl::ONEAPI::experimental::spec_constant<int , MyConst>;
56+
57+ struct MyKernel {
58+ MyKernel (AccT &Acc) : Acc(Acc) {}
59+
60+ void setConst (ScT Sc) { this ->Sc = Sc; }
61+
62+ void operator ()() const { Acc[0 ] = Sc.get (); }
63+ AccT Acc;
64+ ScT Sc;
65+ };
66+
5267int main (int argc, char **argv) {
53- val = argc + 16 ;
68+ global_val = argc + 16 ;
5469
5570 cl::sycl::queue q (default_selector{}, [](exception_list l) {
5671 for (auto ep : l) {
@@ -68,10 +83,11 @@ int main(int argc, char **argv) {
6883
6984 std::cout << " Running on " << q.get_device ().get_info <info::device::name>()
7085 << " \n " ;
71- std::cout << " val = " << val << " \n " ;
86+ std::cout << " global_val = " << global_val << " \n " ;
7287 cl::sycl::program program1 (q.get_context ());
7388 cl::sycl::program program2 (q.get_context ());
7489 cl::sycl::program program3 (q.get_context ());
90+ cl::sycl::program program4 (q.get_context ());
7591
7692 int goldi = (int )get_value ();
7793 // TODO make this floating point once supported by the compiler
@@ -83,22 +99,30 @@ int main(int argc, char **argv) {
8399 cl::sycl::ONEAPI::experimental::spec_constant<float , MyFloatConst> f32 =
84100 program2.set_spec_constant <MyFloatConst>(goldf);
85101
102+ cl::sycl::ONEAPI::experimental::spec_constant<int , MyConst> sc =
103+ program4.set_spec_constant <MyConst>(goldi);
104+
86105 program1.build_with_kernel_type <KernelAAAi>();
87106 // Use an option (does not matter which exactly) to test different internal
88107 // SYCL RT execution path
89108 program2.build_with_kernel_type <KernelBBBf>(" -cl-fast-relaxed-math" );
90109
91110 SCWrapper W (program3);
92111 program3.build_with_kernel_type <class KernelWrappedSC >();
112+
113+ program4.build_with_kernel_type <MyKernel>();
114+
93115 int goldw = 6 ;
94116
95117 std::vector<int > veci (1 );
96118 std::vector<float > vecf (1 );
97119 std::vector<int > vecw (1 );
120+ std::vector<int > vec (1 );
98121 try {
99122 cl::sycl::buffer<int , 1 > bufi (veci.data (), veci.size ());
100123 cl::sycl::buffer<float , 1 > buff (vecf.data (), vecf.size ());
101124 cl::sycl::buffer<int , 1 > bufw (vecw.data (), vecw.size ());
125+ cl::sycl::buffer<int , 1 > buf (vec.data (), vec.size ());
102126
103127 q.submit ([&](cl::sycl::handler &cgh) {
104128 auto acci = bufi.get_access <cl::sycl::access::mode::write>(cgh);
@@ -123,6 +147,19 @@ int main(int argc, char **argv) {
123147 program3.get_kernel <KernelWrappedSC>(),
124148 [=]() { accw[0 ] = W.SC1 .get () + W.SC2 .get (); });
125149 });
150+ // Check spec_constant default construction with subsequent initialization
151+ q.submit ([&](cl::sycl::handler &cgh) {
152+ auto acc = buf.get_access <cl::sycl::access::mode::write>(cgh);
153+ // Specialization constants specification says:
154+ // cl::sycl::experimental::spec_constant is default constructible,
155+ // although the object is not considered initialized until the result of
156+ // the call to cl::sycl::program::set_spec_constant is assigned to it.
157+ MyKernel Kernel (acc); // default construct inside MyKernel instance
158+ Kernel.setConst (sc); // initialize to sc, returned by set_spec_constant
159+
160+ cgh.single_task <MyKernel>(program4.get_kernel <MyKernel>(), Kernel);
161+ });
162+
126163 } catch (cl::sycl::exception &e) {
127164 std::cout << " *** Exception caught: " << e.what () << " \n " ;
128165 return 1 ;
@@ -146,6 +183,12 @@ int main(int argc, char **argv) {
146183 std::cout << " *** ERROR: " << valw << " != " << goldw << " (gold)\n " ;
147184 passed = false ;
148185 }
186+ int val = vec[0 ];
187+
188+ if (val != goldi) {
189+ std::cout << " *** ERROR: " << val << " != " << goldi << " (gold)\n " ;
190+ passed = false ;
191+ }
149192 std::cout << (passed ? " passed\n " : " FAILED\n " );
150193 return passed ? 0 : 1 ;
151194}
0 commit comments