@@ -27,20 +27,24 @@ struct RestrictPtrTraits {
2727};
2828#endif 
2929
30- template <typename  T, size_t  N, template  <typename  U> class  PtrTraits  = DefaultPtrTraits, typename  index_t  = int64_t >
30+ template  <
31+     typename  T,
32+     size_t  N,
33+     template  <typename  U> class  PtrTraits  = DefaultPtrTraits,
34+     typename  index_t  = int64_t >
3135class  TensorAccessorBase  {
32- public: 
36+   public: 
3337  typedef  typename  PtrTraits<T>::PtrType PtrType;
3438
3539  C10_HOST_DEVICE TensorAccessorBase (
3640      PtrType data_,
3741      const  index_t * sizes_,
3842      const  index_t * strides_)
39-     : data_(data_) /* , sizes_(sizes_), strides_(strides_)*/   {
43+        : data_(data_) /* , sizes_(sizes_), strides_(strides_)*/   {
4044    //  Originally, TensorAccessor is a view of sizes and strides as
4145    //  these are ArrayRef instances. Until torch::stable supports
4246    //  ArrayRef-like features, we store copies of sizes and strides:
43-     for  (auto  i= 0 ; i < N; ++i) {
47+     for  (auto  i =  0 ; i < N; ++i) {
4448      this ->sizes_ [i] = sizes_[i];
4549      this ->strides_ [i] = strides_[i];
4650    }
@@ -52,7 +56,8 @@ class TensorAccessorBase {
5256  C10_HOST_DEVICE const  PtrType data () const  {
5357    return  data_;
5458  }
55- protected: 
59+ 
60+  protected: 
5661  PtrType data_;
5762  /* 
5863    const index_t* sizes_; 
@@ -64,48 +69,65 @@ class TensorAccessorBase {
6469  index_t  strides_[N];
6570};
6671
67- template <typename  T, size_t  N, template  <typename  U> class  PtrTraits  = DefaultPtrTraits, typename  index_t  = int64_t >
68- class  TensorAccessor  : public  TensorAccessorBase <T,N,PtrTraits,index_t > {
69- public: 
72+ template  <
73+     typename  T,
74+     size_t  N,
75+     template  <typename  U> class  PtrTraits  = DefaultPtrTraits,
76+     typename  index_t  = int64_t >
77+ class  TensorAccessor  : public  TensorAccessorBase <T, N, PtrTraits, index_t > {
78+  public: 
7079  typedef  typename  PtrTraits<T>::PtrType PtrType;
7180
7281  C10_HOST_DEVICE TensorAccessor (
7382      PtrType data_,
7483      const  index_t * sizes_,
7584      const  index_t * strides_)
76-       : TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
85+       : TensorAccessorBase<T, N, PtrTraits, index_t>(data_,  sizes_,  strides_) {}
7786
78-   C10_HOST_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t  i) {
79-     return  TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_  + this ->strides_ [0 ]*i,this ->sizes_ +1 ,this ->strides_ +1 );
87+   C10_HOST_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
88+       index_t  i) {
89+     return  TensorAccessor<T, N - 1 , PtrTraits, index_t >(
90+         this ->data_  + this ->strides_ [0 ] * i,
91+         this ->sizes_  + 1 ,
92+         this ->strides_  + 1 );
8093  }
8194
82-   C10_HOST_DEVICE const  TensorAccessor<T, N-1 , PtrTraits, index_t > operator [](index_t  i) const  {
83-     return  TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_  + this ->strides_ [0 ]*i,this ->sizes_ +1 ,this ->strides_ +1 );
95+   C10_HOST_DEVICE const  TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
96+       index_t  i) const  {
97+     return  TensorAccessor<T, N - 1 , PtrTraits, index_t >(
98+         this ->data_  + this ->strides_ [0 ] * i,
99+         this ->sizes_  + 1 ,
100+         this ->strides_  + 1 );
84101  }
85102};
86103
87- template <typename  T, template  <typename  U> class  PtrTraits , typename  index_t >
88- class  TensorAccessor <T,1 ,PtrTraits,index_t > : public TensorAccessorBase<T,1 ,PtrTraits,index_t > {
89- public: 
104+ template  <typename  T, template  <typename  U> class  PtrTraits , typename  index_t >
105+ class  TensorAccessor <T, 1 , PtrTraits, index_t >
106+     : public TensorAccessorBase<T, 1 , PtrTraits, index_t > {
107+  public: 
90108  typedef  typename  PtrTraits<T>::PtrType PtrType;
91109
92110  C10_HOST_DEVICE TensorAccessor (
93111      PtrType data_,
94112      const  index_t * sizes_,
95113      const  index_t * strides_)
96-       : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
97-   C10_HOST_DEVICE T  & operator [](index_t  i) {
114+       : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,  sizes_,  strides_) {}
115+   C10_HOST_DEVICE T& operator [](index_t  i) {
98116    //  NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
99-     return  this ->data_ [this ->strides_ [0 ]* i];
117+     return  this ->data_ [this ->strides_ [0 ] *  i];
100118  }
101-   C10_HOST_DEVICE const  T  & operator [](index_t  i) const  {
102-     return  this ->data_ [this ->strides_ [0 ]* i];
119+   C10_HOST_DEVICE const  T& operator [](index_t  i) const  {
120+     return  this ->data_ [this ->strides_ [0 ] *  i];
103121  }
104122};
105123
106- template <typename  T, size_t  N, template  <typename  U> class  PtrTraits  = DefaultPtrTraits, typename  index_t  = int64_t >
124+ template  <
125+     typename  T,
126+     size_t  N,
127+     template  <typename  U> class  PtrTraits  = DefaultPtrTraits,
128+     typename  index_t  = int64_t >
107129class  GenericPackedTensorAccessorBase  {
108- public: 
130+   public: 
109131  typedef  typename  PtrTraits<T>::PtrType PtrType;
110132  C10_HOST GenericPackedTensorAccessorBase (
111133      PtrType data_,
@@ -116,13 +138,15 @@ class GenericPackedTensorAccessorBase {
116138    std::copy (strides_, strides_ + N, std::begin (this ->strides_ ));
117139  }
118140
119-   template  <typename  source_index_t , class  = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
141+   template  <
142+       typename  source_index_t ,
143+       class  = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
120144  C10_HOST GenericPackedTensorAccessorBase (
121145      PtrType data_,
122146      const  source_index_t * sizes_,
123147      const  source_index_t * strides_)
124148      : data_(data_) {
125-     for  (auto  i= 0 ; i < N; ++i) {
149+     for  (auto  i =  0 ; i < N; ++i) {
126150      this ->sizes_ [i] = sizes_[i];
127151      this ->strides_ [i] = strides_[i];
128152    }
@@ -134,7 +158,8 @@ class GenericPackedTensorAccessorBase {
134158  C10_HOST_DEVICE const  PtrType data () const  {
135159    return  data_;
136160  }
137- protected: 
161+ 
162+  protected: 
138163  PtrType data_;
139164  //  NOLINTNEXTLINE(*c-arrays*)
140165  index_t  sizes_[N];
@@ -150,68 +175,101 @@ class GenericPackedTensorAccessorBase {
150175  }
151176};
152177
153- template <typename  T, size_t  N, template  <typename  U> class  PtrTraits  = DefaultPtrTraits, typename  index_t  = int64_t >
154- class  GenericPackedTensorAccessor  : public  GenericPackedTensorAccessorBase <T,N,PtrTraits,index_t > {
155- public: 
178+ template  <
179+     typename  T,
180+     size_t  N,
181+     template  <typename  U> class  PtrTraits  = DefaultPtrTraits,
182+     typename  index_t  = int64_t >
183+ class  GenericPackedTensorAccessor 
184+     : public GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t > {
185+  public: 
156186  typedef  typename  PtrTraits<T>::PtrType PtrType;
157187
158188  C10_HOST GenericPackedTensorAccessor (
159189      PtrType data_,
160190      const  index_t * sizes_,
161191      const  index_t * strides_)
162-       : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
192+       : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(
193+             data_,
194+             sizes_,
195+             strides_) {}
163196
164197  //  if index_t is not int64_t, we want to have an int64_t constructor
165-   template  <typename  source_index_t , class  = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
198+   template  <
199+       typename  source_index_t ,
200+       class  = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
166201  C10_HOST GenericPackedTensorAccessor (
167202      PtrType data_,
168203      const  source_index_t * sizes_,
169204      const  source_index_t * strides_)
170-       : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
205+       : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(
206+             data_,
207+             sizes_,
208+             strides_) {}
171209
172-   C10_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t  i) {
210+   C10_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
211+       index_t  i) {
173212    index_t * new_sizes = this ->sizes_  + 1 ;
174213    index_t * new_strides = this ->strides_  + 1 ;
175-     return  TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_  + this ->strides_ [0 ]*i, new_sizes, new_strides);
214+     return  TensorAccessor<T, N - 1 , PtrTraits, index_t >(
215+         this ->data_  + this ->strides_ [0 ] * i, new_sizes, new_strides);
176216  }
177217
178-   C10_DEVICE const  TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t  i) const  {
218+   C10_DEVICE const  TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
219+       index_t  i) const  {
179220    const  index_t * new_sizes = this ->sizes_  + 1 ;
180221    const  index_t * new_strides = this ->strides_  + 1 ;
181-     return  TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_  + this ->strides_ [0 ]*i, new_sizes, new_strides);
222+     return  TensorAccessor<T, N - 1 , PtrTraits, index_t >(
223+         this ->data_  + this ->strides_ [0 ] * i, new_sizes, new_strides);
182224  }
183225};
184226
185- template <typename  T, template  <typename  U> class  PtrTraits , typename  index_t >
186- class  GenericPackedTensorAccessor <T,1 ,PtrTraits,index_t > : public GenericPackedTensorAccessorBase<T,1 ,PtrTraits,index_t > {
187- public: 
227+ template  <typename  T, template  <typename  U> class  PtrTraits , typename  index_t >
228+ class  GenericPackedTensorAccessor <T, 1 , PtrTraits, index_t >
229+     : public GenericPackedTensorAccessorBase<T, 1 , PtrTraits, index_t > {
230+  public: 
188231  typedef  typename  PtrTraits<T>::PtrType PtrType;
189232  C10_HOST GenericPackedTensorAccessor (
190233      PtrType data_,
191234      const  index_t * sizes_,
192235      const  index_t * strides_)
193-       : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
236+       : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(
237+             data_,
238+             sizes_,
239+             strides_) {}
194240
195-   template  <typename  source_index_t , class  = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
241+   template  <
242+       typename  source_index_t ,
243+       class  = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
196244  C10_HOST GenericPackedTensorAccessor (
197245      PtrType data_,
198246      const  source_index_t * sizes_,
199247      const  source_index_t * strides_)
200-       : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
248+       : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(
249+             data_,
250+             sizes_,
251+             strides_) {}
201252
202-   C10_DEVICE T  & operator [](index_t  i) {
253+   C10_DEVICE T& operator [](index_t  i) {
203254    return  this ->data_ [this ->strides_ [0 ] * i];
204255  }
205256  C10_DEVICE const  T& operator [](index_t  i) const  {
206-     return  this ->data_ [this ->strides_ [0 ]* i];
257+     return  this ->data_ [this ->strides_ [0 ] *  i];
207258  }
208- 
209259};
210260
211- template  <typename  T, size_t  N, template  <typename  U> class  PtrTraits  = DefaultPtrTraits>
212- using  PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t >;
261+ template  <
262+     typename  T,
263+     size_t  N,
264+     template  <typename  U> class  PtrTraits  = DefaultPtrTraits>
265+ using  PackedTensorAccessor32 =
266+     GenericPackedTensorAccessor<T, N, PtrTraits, int32_t >;
213267
214- template  <typename  T, size_t  N, template  <typename  U> class  PtrTraits  = DefaultPtrTraits>
215- using  PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t >;
268+ template  <
269+     typename  T,
270+     size_t  N,
271+     template  <typename  U> class  PtrTraits  = DefaultPtrTraits>
272+ using  PackedTensorAccessor64 =
273+     GenericPackedTensorAccessor<T, N, PtrTraits, int64_t >;
216274
217- }   //  namespace torchaudio::stable
275+ } //  namespace torchaudio::stable
0 commit comments