@@ -177,7 +177,6 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
177177 assert ((!isa<ConstantInt>(Stride) ||
178178 cast<ConstantInt>(Stride)->getZExtValue () >= NumElements) &&
179179 " Stride must be >= the number of elements in the result vector." );
180- unsigned AS = cast<PointerType>(BasePtr->getType ())->getAddressSpace ();
181180
182181 // Compute the start of the vector with index VecIdx as VecIdx * Stride.
183182 Value *VecStart = Builder.CreateMul (VecIdx, Stride, " vec.start" );
@@ -189,11 +188,7 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
189188 else
190189 VecStart = Builder.CreateGEP (EltType, BasePtr, VecStart, " vec.gep" );
191190
192- // Cast elementwise vector start pointer to a pointer to a vector
193- // (EltType x NumElements)*.
194- auto *VecType = FixedVectorType::get (EltType, NumElements);
195- Type *VecPtrType = PointerType::get (VecType, AS);
196- return Builder.CreatePointerCast (VecStart, VecPtrType, " vec.cast" );
191+ return VecStart;
197192}
198193
199194// / LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
@@ -1060,13 +1055,6 @@ class LowerMatrixIntrinsics {
10601055 return Changed;
10611056 }
10621057
1063- // / Turns \p BasePtr into an elementwise pointer to \p EltType.
1064- Value *createElementPtr (Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
1065- unsigned AS = cast<PointerType>(BasePtr->getType ())->getAddressSpace ();
1066- Type *EltPtrType = PointerType::get (EltType, AS);
1067- return Builder.CreatePointerCast (BasePtr, EltPtrType);
1068- }
1069-
10701058 // / Replace intrinsic calls
10711059 bool VisitCallInst (CallInst *Inst) {
10721060 if (!Inst->getCalledFunction () || !Inst->getCalledFunction ()->isIntrinsic ())
@@ -1118,7 +1106,7 @@ class LowerMatrixIntrinsics {
11181106 auto *VType = cast<VectorType>(Ty);
11191107 Type *EltTy = VType->getElementType ();
11201108 Type *VecTy = FixedVectorType::get (EltTy, Shape.getStride ());
1121- Value *EltPtr = createElementPtr ( Ptr, EltTy, Builder) ;
1109+ Value *EltPtr = Ptr;
11221110 MatrixTy Result;
11231111 for (unsigned I = 0 , E = Shape.getNumVectors (); I < E; ++I) {
11241112 Value *GEP = computeVectorAddr (
@@ -1144,17 +1132,11 @@ class LowerMatrixIntrinsics {
11441132 Value *Offset = Builder.CreateAdd (
11451133 Builder.CreateMul (J, Builder.getInt64 (MatrixShape.getStride ())), I);
11461134
1147- unsigned AS = cast<PointerType>(MatrixPtr->getType ())->getAddressSpace ();
1148- Value *EltPtr =
1149- Builder.CreatePointerCast (MatrixPtr, PointerType::get (EltTy, AS));
1150- Value *TileStart = Builder.CreateGEP (EltTy, EltPtr, Offset);
1135+ Value *TileStart = Builder.CreateGEP (EltTy, MatrixPtr, Offset);
11511136 auto *TileTy = FixedVectorType::get (EltTy, ResultShape.NumRows *
11521137 ResultShape.NumColumns );
1153- Type *TilePtrTy = PointerType::get (TileTy, AS);
1154- Value *TilePtr =
1155- Builder.CreatePointerCast (TileStart, TilePtrTy, " col.cast" );
11561138
1157- return loadMatrix (TileTy, TilePtr , Align,
1139+ return loadMatrix (TileTy, TileStart , Align,
11581140 Builder.getInt64 (MatrixShape.getStride ()), IsVolatile,
11591141 ResultShape, Builder);
11601142 }
@@ -1190,17 +1172,11 @@ class LowerMatrixIntrinsics {
11901172 Value *Offset = Builder.CreateAdd (
11911173 Builder.CreateMul (J, Builder.getInt64 (MatrixShape.getStride ())), I);
11921174
1193- unsigned AS = cast<PointerType>(MatrixPtr->getType ())->getAddressSpace ();
1194- Value *EltPtr =
1195- Builder.CreatePointerCast (MatrixPtr, PointerType::get (EltTy, AS));
1196- Value *TileStart = Builder.CreateGEP (EltTy, EltPtr, Offset);
1175+ Value *TileStart = Builder.CreateGEP (EltTy, MatrixPtr, Offset);
11971176 auto *TileTy = FixedVectorType::get (EltTy, StoreVal.getNumRows () *
11981177 StoreVal.getNumColumns ());
1199- Type *TilePtrTy = PointerType::get (TileTy, AS);
1200- Value *TilePtr =
1201- Builder.CreatePointerCast (TileStart, TilePtrTy, " col.cast" );
12021178
1203- storeMatrix (TileTy, StoreVal, TilePtr , MAlign,
1179+ storeMatrix (TileTy, StoreVal, TileStart , MAlign,
12041180 Builder.getInt64 (MatrixShape.getStride ()), IsVolatile, Builder);
12051181 }
12061182
@@ -1210,7 +1186,7 @@ class LowerMatrixIntrinsics {
12101186 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
12111187 IRBuilder<> &Builder) {
12121188 auto VType = cast<VectorType>(Ty);
1213- Value *EltPtr = createElementPtr ( Ptr, VType-> getElementType (), Builder) ;
1189+ Value *EltPtr = Ptr;
12141190 for (auto Vec : enumerate(StoreVal.vectors ())) {
12151191 Value *GEP = computeVectorAddr (
12161192 EltPtr,
0 commit comments