Skip to content

Commit 07d2ac7

Browse files
[ESIMD] Apply CRTP to simd_view_impl class (#4351)
CRTP - Curiously Recurring Template Pattern. The main goal of the change is that it would not be possible to generate objects of simd_view_impl class. To achieve that we need to force simd_view_impl APIs to return objects of derived class. This change is needed for implementing writeable subscript operator in simd_view.
1 parent 525d098 commit 07d2ac7

File tree

5 files changed

+91
-72
lines changed

5 files changed

+91
-72
lines changed

sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_view_impl.hpp

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ namespace detail {
2525
/// It is an internal class implementing basic functionality of simd_view.
2626
///
2727
/// \ingroup sycl_esimd
28-
template <typename BaseTy, typename RegionTy> class simd_view_impl {
28+
template <typename BaseTy, typename RegionTy, typename Derived>
29+
class simd_view_impl {
2930
template <typename, int> friend class simd;
30-
template <typename, typename> friend class simd_view_impl;
31+
template <typename, typename, typename> friend class simd_view_impl;
3132

3233
public:
3334
static_assert(!detail::is_simd_view_v<BaseTy>::value);
@@ -52,6 +53,9 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
5253
/// @{
5354
/// Constructors.
5455

56+
private:
57+
Derived &cast_this_to_derived() { return reinterpret_cast<Derived &>(*this); }
58+
5559
protected:
5660
simd_view_impl(BaseTy &Base, RegionTy Region)
5761
: M_base(Base), M_region(Region) {}
@@ -109,9 +113,9 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
109113
}
110114

111115
/// Write to this object.
112-
simd_view_impl &write(const value_type &Val) {
116+
Derived &write(const value_type &Val) {
113117
M_base.writeRegion(M_region, Val.data());
114-
return *this;
118+
return cast_this_to_derived();
115119
}
116120

117121
/// @{
@@ -129,9 +133,9 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
129133

130134
/// View this object in a different element type.
131135
template <typename EltTy> auto bit_cast_view() {
132-
using TopRegionTy = detail::compute_format_type_t<simd_view_impl, EltTy>;
136+
using TopRegionTy = detail::compute_format_type_t<Derived, EltTy>;
133137
using NewRegionTy = std::pair<TopRegionTy, RegionTy>;
134-
using RetTy = simd_view_impl<BaseTy, NewRegionTy>;
138+
using RetTy = simd_view<BaseTy, NewRegionTy>;
135139
TopRegionTy TopReg(0);
136140
return RetTy{this->M_base, std::make_pair(TopReg, M_region)};
137141
}
@@ -145,9 +149,9 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
145149
/// View as a 2-dimensional simd_view.
146150
template <typename EltTy, int Height, int Width> auto bit_cast_view() {
147151
using TopRegionTy =
148-
detail::compute_format_type_2d_t<simd_view_impl, EltTy, Height, Width>;
152+
detail::compute_format_type_2d_t<Derived, EltTy, Height, Width>;
149153
using NewRegionTy = std::pair<TopRegionTy, RegionTy>;
150-
using RetTy = simd_view_impl<BaseTy, NewRegionTy>;
154+
using RetTy = simd_view<BaseTy, NewRegionTy>;
151155
TopRegionTy TopReg(0, 0);
152156
return RetTy{this->M_base, std::make_pair(TopReg, M_region)};
153157
}
@@ -164,12 +168,12 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
164168
/// \tparam Stride is the element distance between two consecutive elements.
165169
/// \param Offset is the starting element offset.
166170
/// \return the representing region object.
167-
template <int Size, int Stride, typename T = simd_view_impl,
171+
template <int Size, int Stride, typename T = Derived,
168172
typename = sycl::detail::enable_if_t<T::is1D()>>
169173
auto select(uint16_t Offset = 0) {
170174
using TopRegionTy = region1d_t<element_type, Size, Stride>;
171175
using NewRegionTy = std::pair<TopRegionTy, RegionTy>;
172-
using RetTy = simd_view_impl<BaseTy, NewRegionTy>;
176+
using RetTy = simd_view<BaseTy, NewRegionTy>;
173177
TopRegionTy TopReg(Offset);
174178
return RetTy{this->M_base, std::make_pair(TopReg, M_region)};
175179
}
@@ -186,21 +190,20 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
186190
/// \param OffsetY is the starting element offset in Y-dimension.
187191
/// \return the representing region object.
188192
template <int SizeY, int StrideY, int SizeX, int StrideX,
189-
typename T = simd_view_impl,
193+
typename T = Derived,
190194
typename = sycl::detail::enable_if_t<T::is2D()>>
191195
auto select(uint16_t OffsetY = 0, uint16_t OffsetX = 0) {
192196
using TopRegionTy =
193197
region2d_t<element_type, SizeY, StrideY, SizeX, StrideX>;
194198
using NewRegionTy = std::pair<TopRegionTy, RegionTy>;
195-
using RetTy = simd_view_impl<BaseTy, NewRegionTy>;
199+
using RetTy = simd_view<BaseTy, NewRegionTy>;
196200
TopRegionTy TopReg(OffsetY, OffsetX);
197201
return RetTy{this->M_base, std::make_pair(TopReg, M_region)};
198202
}
199203

200204
#define DEF_BINOP(BINOP, OPASSIGN) \
201-
template <class T1 = simd_view_impl, \
202-
class = std::enable_if_t<T1::length != 1>> \
203-
ESIMD_INLINE friend auto operator BINOP(const simd_view_impl &X, \
205+
template <class T1 = Derived, class = std::enable_if_t<T1::length != 1>> \
206+
ESIMD_INLINE friend auto operator BINOP(const Derived &X, \
204207
const value_type &Y) { \
205208
using ComputeTy = detail::compute_type_t<value_type>; \
206209
auto V0 = \
@@ -209,31 +212,30 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
209212
auto V2 = V0 BINOP V1; \
210213
return ComputeTy(V2); \
211214
} \
212-
template <class T1 = simd_view_impl, \
213-
class = std::enable_if_t<T1::length != 1>> \
215+
template <class T1 = Derived, class = std::enable_if_t<T1::length != 1>> \
214216
ESIMD_INLINE friend auto operator BINOP(const value_type &X, \
215-
const simd_view_impl &Y) { \
217+
const Derived &Y) { \
216218
using ComputeTy = detail::compute_type_t<value_type>; \
217219
auto V0 = detail::convert<typename ComputeTy::vector_type>(X.data()); \
218220
auto V1 = \
219221
detail::convert<typename ComputeTy::vector_type>(Y.read().data()); \
220222
auto V2 = V0 BINOP V1; \
221223
return ComputeTy(V2); \
222224
} \
223-
ESIMD_INLINE friend auto operator BINOP(const simd_view_impl &X, \
224-
const simd_view_impl &Y) { \
225+
ESIMD_INLINE friend auto operator BINOP(const Derived &X, \
226+
const Derived &Y) { \
225227
return (X BINOP Y.read()); \
226228
} \
227-
simd_view_impl &operator OPASSIGN(const value_type &RHS) { \
229+
Derived &operator OPASSIGN(const value_type &RHS) { \
228230
using ComputeTy = detail::compute_type_t<value_type>; \
229231
auto V0 = detail::convert<typename ComputeTy::vector_type>(read().data()); \
230232
auto V1 = detail::convert<typename ComputeTy::vector_type>(RHS.data()); \
231233
auto V2 = V0 BINOP V1; \
232234
auto V3 = detail::convert<vector_type>(V2); \
233235
write(V3); \
234-
return *this; \
236+
return cast_this_to_derived(); \
235237
} \
236-
simd_view_impl &operator OPASSIGN(const simd_view_impl &RHS) { \
238+
Derived &operator OPASSIGN(const Derived &RHS) { \
237239
return (*this OPASSIGN RHS.read()); \
238240
}
239241

@@ -246,34 +248,32 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
246248
#undef DEF_BINOP
247249

248250
#define DEF_BITWISE_OP(BITWISE_OP, OPASSIGN) \
249-
template <class T1 = simd_view_impl, \
250-
class = std::enable_if_t<T1::length != 1>> \
251-
ESIMD_INLINE friend auto operator BITWISE_OP(const simd_view_impl &X, \
251+
template <class T1 = Derived, class = std::enable_if_t<T1::length != 1>> \
252+
ESIMD_INLINE friend auto operator BITWISE_OP(const Derived &X, \
252253
const value_type &Y) { \
253254
static_assert(std::is_integral<element_type>(), "not integral type"); \
254255
auto V2 = X.read().data() BITWISE_OP Y.data(); \
255256
return simd<element_type, length>(V2); \
256257
} \
257-
template <class T1 = simd_view_impl, \
258-
class = std::enable_if_t<T1::length != 1>> \
258+
template <class T1 = Derived, class = std::enable_if_t<T1::length != 1>> \
259259
ESIMD_INLINE friend auto operator BITWISE_OP(const value_type &X, \
260-
const simd_view_impl &Y) { \
260+
const Derived &Y) { \
261261
static_assert(std::is_integral<element_type>(), "not integral type"); \
262262
auto V2 = X.data() BITWISE_OP Y.read().data(); \
263263
return simd<element_type, length>(V2); \
264264
} \
265-
ESIMD_INLINE friend auto operator BITWISE_OP(const simd_view_impl &X, \
266-
const simd_view_impl &Y) { \
265+
ESIMD_INLINE friend auto operator BITWISE_OP(const Derived &X, \
266+
const Derived &Y) { \
267267
return (X BITWISE_OP Y.read()); \
268268
} \
269-
simd_view_impl &operator OPASSIGN(const value_type &RHS) { \
269+
Derived &operator OPASSIGN(const value_type &RHS) { \
270270
static_assert(std::is_integral<element_type>(), "not integeral type"); \
271271
auto V2 = read().data() BITWISE_OP RHS.data(); \
272272
auto V3 = detail::convert<vector_type>(V2); \
273273
write(V3); \
274-
return *this; \
274+
return cast_this_to_derived(); \
275275
} \
276-
simd_view_impl &operator OPASSIGN(const simd_view_impl &RHS) { \
276+
Derived &operator OPASSIGN(const Derived &RHS) { \
277277
return (*this OPASSIGN RHS.read()); \
278278
}
279279
DEF_BITWISE_OP(&, &=)
@@ -295,19 +295,22 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
295295

296296
#undef DEF_UNARY_OP
297297

298+
// negation operator
299+
auto operator!() { return cast_this_to_derived() == 0; }
300+
298301
// Operator ++, --
299-
simd_view_impl &operator++() {
302+
Derived &operator++() {
300303
*this += 1;
301-
return *this;
304+
return cast_this_to_derived();
302305
}
303306
value_type operator++(int) {
304307
value_type Ret(read());
305308
operator++();
306309
return Ret;
307310
}
308-
simd_view_impl &operator--() {
311+
Derived &operator--() {
309312
*this -= 1;
310-
return *this;
313+
return cast_this_to_derived();
311314
}
312315
value_type operator--(int) {
313316
value_type Ret(read());
@@ -317,7 +320,7 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
317320

318321
/// Reference a row from a 2D region.
319322
/// \return a 1D region.
320-
template <typename T = simd_view_impl,
323+
template <typename T = Derived,
321324
typename = sycl::detail::enable_if_t<T::is2D()>>
322325
auto row(int i) {
323326
return select<1, 0, getSizeX(), 1>(i, 0)
@@ -326,22 +329,22 @@ template <typename BaseTy, typename RegionTy> class simd_view_impl {
326329

327330
/// Reference a column from a 2D region.
328331
/// \return a 2D region.
329-
template <typename T = simd_view_impl,
332+
template <typename T = Derived,
330333
typename = sycl::detail::enable_if_t<T::is2D()>>
331334
auto column(int i) {
332335
return select<getSizeY(), 1, 1, 0>(0, i);
333336
}
334337

335338
/// Read a single element from a 1D region, by value only.
336-
template <typename T = simd_view_impl,
339+
template <typename T = Derived,
337340
typename = sycl::detail::enable_if_t<T::is1D()>>
338341
element_type operator[](int i) const {
339342
const auto v = read();
340343
return v[i];
341344
}
342345

343346
/// Read a single element from a 1D region, by value only.
344-
template <typename T = simd_view_impl,
347+
template <typename T = Derived,
345348
typename = sycl::detail::enable_if_t<T::is1D()>>
346349
__SYCL_DEPRECATED("use operator[] form.")
347350
element_type operator()(int i) const {

sycl/include/sycl/ext/intel/experimental/esimd/detail/types.hpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ namespace intel {
2525
namespace experimental {
2626
namespace esimd {
2727

28-
// simd and simd_view_impl forward declarations
28+
// simd and simd_view forward declarations
2929
template <typename Ty, int N> class simd;
30-
namespace detail {
31-
template <typename BaseTy, typename RegionTy> class simd_view_impl;
32-
} // namespace detail
30+
template <typename BaseTy, typename RegionTy> class simd_view;
3331

3432
namespace detail {
3533

@@ -90,7 +88,7 @@ struct compute_format_type<simd<Ty, N>, EltTy> {
9088
};
9189

9290
template <typename BaseTy, typename RegionTy, typename EltTy>
93-
struct compute_format_type<detail::simd_view_impl<BaseTy, RegionTy>, EltTy> {
91+
struct compute_format_type<simd_view<BaseTy, RegionTy>, EltTy> {
9492
using ShapeTy = typename shape_type<RegionTy>::type;
9593
static constexpr int Size = ShapeTy::Size_in_bytes / sizeof(EltTy);
9694
static constexpr int Stride = 1;
@@ -118,8 +116,8 @@ struct compute_format_type_2d<simd<Ty, N>, EltTy, Height, Width> {
118116

119117
template <typename BaseTy, typename RegionTy, typename EltTy, int Height,
120118
int Width>
121-
struct compute_format_type_2d<detail::simd_view_impl<BaseTy, RegionTy>, EltTy,
122-
Height, Width> {
119+
struct compute_format_type_2d<simd_view<BaseTy, RegionTy>, EltTy, Height,
120+
Width> {
123121
using ShapeTy = typename shape_type<RegionTy>::type;
124122
static constexpr int Prod = ShapeTy::Size_in_bytes / sizeof(EltTy);
125123
static_assert(Prod == Width * Height, "size mismatch");
@@ -138,10 +136,6 @@ using compute_format_type_2d_t =
138136
// Check if a type is simd_view type
139137
template <typename Ty> struct is_simd_view_type : std::false_type {};
140138

141-
template <typename BaseTy, typename RegionTy>
142-
struct is_simd_view_type<detail::simd_view_impl<BaseTy, RegionTy>>
143-
: std::true_type {};
144-
145139
template <typename BaseTy, typename RegionTy>
146140
struct is_simd_view_type<simd_view<BaseTy, RegionTy>> : std::true_type {};
147141

@@ -157,8 +151,7 @@ template <typename Ty, int N>
157151
struct is_simd_type<simd<Ty, N>> : std::true_type {};
158152

159153
template <typename BaseTy, typename RegionTy>
160-
struct is_simd_type<detail::simd_view_impl<BaseTy, RegionTy>> : std::true_type {
161-
};
154+
struct is_simd_type<simd_view<BaseTy, RegionTy>> : std::true_type {};
162155

163156
template <typename Ty>
164157
struct is_simd_v
@@ -170,7 +163,7 @@ template <typename Ty, int N> struct element_type<simd<Ty, N>> {
170163
using type = Ty;
171164
};
172165
template <typename BaseTy, typename RegionTy>
173-
struct element_type<detail::simd_view_impl<BaseTy, RegionTy>> {
166+
struct element_type<simd_view<BaseTy, RegionTy>> {
174167
using type = typename RegionTy::element_type;
175168
};
176169

sycl/include/sycl/ext/intel/experimental/esimd/simd_view.hpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ namespace esimd {
2525
///
2626
/// \ingroup sycl_esimd
2727
template <typename BaseTy, typename RegionTy>
28-
class simd_view : public detail::simd_view_impl<BaseTy, RegionTy> {
28+
class simd_view : public detail::simd_view_impl<BaseTy, RegionTy,
29+
simd_view<BaseTy, RegionTy>> {
2930
template <typename, int> friend class simd;
30-
// template <typename, typename> friend class simd_view;
31+
template <typename, typename, typename> friend class detail::simd_view_impl;
3132

3233
public:
33-
using BaseClass = detail::simd_view_impl<BaseTy, RegionTy>;
34+
using BaseClass =
35+
detail::simd_view_impl<BaseTy, RegionTy, simd_view<BaseTy, RegionTy>>;
3436
using ShapeTy = typename shape_type<RegionTy>::type;
3537
static constexpr int length = ShapeTy::Size_x * ShapeTy::Size_y;
3638

@@ -90,9 +92,6 @@ class simd_view : public detail::simd_view_impl<BaseTy, RegionTy> {
9092
DEF_RELOP(!=)
9193

9294
#undef DEF_RELOP
93-
94-
// negation operator
95-
auto operator!() { return *this == 0; }
9695
};
9796

9897
/// This is a specialization of simd_view class with a single element.
@@ -107,12 +106,15 @@ class simd_view : public detail::simd_view_impl<BaseTy, RegionTy> {
107106
template <typename BaseTy>
108107
class simd_view<BaseTy, region_base_1<typename BaseTy::element_type>>
109108
: public detail::simd_view_impl<
110-
BaseTy, region_base_1<typename BaseTy::element_type>> {
109+
BaseTy, region_base_1<typename BaseTy::element_type>,
110+
simd_view<BaseTy, region_base_1<typename BaseTy::element_type>>> {
111111
template <typename, int> friend class simd;
112+
template <typename, typename, typename> friend class detail::simd_view_impl;
112113

113114
public:
114115
using RegionTy = region_base_1<typename BaseTy::element_type>;
115-
using BaseClass = detail::simd_view_impl<BaseTy, RegionTy>;
116+
using BaseClass =
117+
detail::simd_view_impl<BaseTy, RegionTy, simd_view<BaseTy, RegionTy>>;
116118
using ShapeTy = typename shape_type<RegionTy>::type;
117119
static constexpr int length = ShapeTy::Size_x * ShapeTy::Size_y;
118120
static_assert(1 == length, "length of this view is not equal to 1");
@@ -121,10 +123,8 @@ class simd_view<BaseTy, region_base_1<typename BaseTy::element_type>>
121123
using element_type = typename ShapeTy::element_type;
122124

123125
private:
124-
simd_view(BaseTy &Base, RegionTy Region)
125-
: detail::simd_view_impl<BaseTy, RegionTy>(Base, Region) {}
126-
simd_view(BaseTy &&Base, RegionTy Region)
127-
: detail::simd_view_impl<BaseTy, RegionTy>(Base, Region) {}
126+
simd_view(BaseTy &Base, RegionTy Region) : BaseClass(Base, Region) {}
127+
simd_view(BaseTy &&Base, RegionTy Region) : BaseClass(Base, Region) {}
128128

129129
public:
130130
operator element_type() const { return (*this)[0]; }
@@ -145,9 +145,6 @@ class simd_view<BaseTy, region_base_1<typename BaseTy::element_type>>
145145
DEF_RELOP(!=)
146146

147147
#undef DEF_RELOP
148-
149-
// negation operator
150-
auto operator!() { return *this == 0; }
151148
};
152149

153150
} // namespace esimd

sycl/test/esimd/esimd-util-compiler-eval.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ static_assert(log2<1024 * 1024>() == 20, "");
2020
using BaseTy = simd<float, 4>;
2121
using RegionTy = region1d_t<float, 2, 1>;
2222
using RegionTy1 = region_base_1<float>;
23-
static_assert(is_simd_view_v<simd_view_impl<BaseTy, RegionTy>>::value, "");
23+
static_assert(
24+
!is_simd_view_v<
25+
simd_view_impl<BaseTy, RegionTy, simd_view<BaseTy, RegionTy>>>::value,
26+
"");
2427
static_assert(is_simd_view_v<simd_view<BaseTy, RegionTy>>::value, "");
2528
static_assert(is_simd_view_v<simd_view<BaseTy, RegionTy1>>::value, "");
2629
static_assert(!is_simd_view_v<BaseTy>::value, "");

0 commit comments

Comments
 (0)