Summary of this patch (text before the first "diff --git" line is skipped by
patch tools):

* CommonPointerBase::compute: once RHS is found among the collected pointers,
  it is recorded as the common base even when LHS and RHS have different
  types; the early return on a type mismatch is removed.
* InstCombinerImpl::OptimizePointerDifference: the index type is now derived
  from LHS rather than from the common base pointer, and any GEP offset whose
  type differs from that index type is splatted to the index type's element
  count before being added to the running sum.
* sub-gep.ll: adds @splat_geps (both GEPs use vector indices on a scalar base)
  and @splat_geps_multiple (a scalar-index GEP feeding a vector-index GEP).

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 86d318967403d..0d91e7d77e4a7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2088,8 +2088,6 @@ CommonPointerBase CommonPointerBase::compute(Value *LHS, Value *RHS) {
   // Find common base and collect RHS GEPs.
   while (true) {
     if (Ptrs.contains(RHS)) {
-      if (LHS->getType() != RHS->getType())
-        return Base;
       Base.Ptr = RHS;
       break;
     }
@@ -2132,12 +2130,15 @@ Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
   // TODO: We should probably do this even if there is only one GEP.
   bool RewriteGEPs = !Base.LHSGEPs.empty() && !Base.RHSGEPs.empty();
 
-  Type *IdxTy = DL.getIndexType(Base.Ptr->getType());
+  Type *IdxTy = DL.getIndexType(LHS->getType());
   auto EmitOffsetFromBase = [&](ArrayRef<GEPOperator *> GEPs,
                                 GEPNoWrapFlags NW) -> Value * {
     Value *Sum = nullptr;
     for (GEPOperator *GEP : reverse(GEPs)) {
       Value *Offset = EmitGEPOffset(GEP, RewriteGEPs);
+      if (Offset->getType() != IdxTy)
+        Offset = Builder.CreateVectorSplat(
+            cast<VectorType>(IdxTy)->getElementCount(), Offset);
       if (Sum)
         Sum = Builder.CreateAdd(Sum, Offset, "", NW.hasNoUnsignedWrap(),
                                 NW.isInBounds());
diff --git a/llvm/test/Transforms/InstCombine/sub-gep.ll b/llvm/test/Transforms/InstCombine/sub-gep.ll
index 9444fef1887d3..375be8a3d69c3 100644
--- a/llvm/test/Transforms/InstCombine/sub-gep.ll
+++ b/llvm/test/Transforms/InstCombine/sub-gep.ll
@@ -995,3 +995,33 @@ define i64 @multiple_geps_inbounds_nuw(ptr %base, i64 %idx, i64 %idx2) {
   %d = sub i64 %i2, %i1
   ret i64 %d
 }
+
+define <2 x i64> @splat_geps(ptr %base, <2 x i64> %idx1, <2 x i64> %idx2) {
+; CHECK-LABEL: @splat_geps(
+; CHECK-NEXT:    [[D:%.*]] = sub nsw <2 x i64> [[IDX2:%.*]], [[IDX1:%.*]]
+; CHECK-NEXT:    ret <2 x i64> [[D]]
+;
+  %gep1 = getelementptr inbounds i8, ptr %base, <2 x i64> %idx1
+  %gep2 = getelementptr inbounds i8, ptr %base, <2 x i64> %idx2
+  %gep1.int = ptrtoint <2 x ptr> %gep1 to <2 x i64>
+  %gep2.int = ptrtoint <2 x ptr> %gep2 to <2 x i64>
+  %d = sub <2 x i64> %gep2.int, %gep1.int
+  ret <2 x i64> %d
+}
+
+define <2 x i64> @splat_geps_multiple(ptr %base, i64 %idx0, <2 x i64> %idx1, <2 x i64> %idx2) {
+; CHECK-LABEL: @splat_geps_multiple(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[IDX0:%.*]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i64> [[DOTSPLATINSERT]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw <2 x i64> [[DOTSPLAT]], [[IDX1:%.*]]
+; CHECK-NEXT:    [[D:%.*]] = sub nsw <2 x i64> [[IDX2:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret <2 x i64> [[D]]
+;
+  %gep0 = getelementptr inbounds i8, ptr %base, i64 %idx0
+  %gep1 = getelementptr inbounds i8, ptr %gep0, <2 x i64> %idx1
+  %gep2 = getelementptr inbounds i8, ptr %base, <2 x i64> %idx2
+  %gep1.int = ptrtoint <2 x ptr> %gep1 to <2 x i64>
+  %gep2.int = ptrtoint <2 x ptr> %gep2 to <2 x i64>
+  %d = sub <2 x i64> %gep2.int, %gep1.int
+  ret <2 x i64> %d
+}