97 changes: 67 additions & 30 deletions llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -436,6 +436,71 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
}

/// Whenever an element is extracted from a vector, optionally shifted down, and
/// then truncated, canonicalize by converting it to a bitcast followed by an
/// extractelement.
///
/// Examples (little endian):
/// trunc (extractelement <4 x i64> %X, 0) to i32
/// --->
/// extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
///
/// trunc (lshr (extractelement <4 x i32> %X, 0), 8) to i8
/// --->
/// extractelement <16 x i8> (bitcast <4 x i32> %X to <16 x i8>), i32 1
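///
/// On big-endian targets the sub-elements are numbered from the most
/// significant end, so the index is mirrored. Example (big endian):
/// trunc (lshr (extractelement <4 x i32> %X, 1), 16) to i16
/// --->
/// extractelement <8 x i16> (bitcast <4 x i32> %X to <8 x i16>), i32 2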
static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc,
InstCombinerImpl &IC) {
Value *Src = Trunc.getOperand(0);
Type *SrcType = Src->getType();
Type *DstType = Trunc.getType();

// Only attempt this if we have simple aliasing of the vector elements.
// A destination size that does not evenly divide the source width would
// result in an invalid cast.
unsigned SrcBits = SrcType->getScalarSizeInBits();
unsigned DstBits = DstType->getScalarSizeInBits();
if ((SrcBits % DstBits) != 0)
return nullptr;
unsigned TruncRatio = SrcBits / DstBits;

Value *VecOp;
ConstantInt *Cst;
const APInt *ShiftAmount = nullptr;
if (!match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)))) &&
!match(Src,
m_OneUse(m_LShr(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)),
m_APInt(ShiftAmount)))))
return nullptr;

auto *VecOpTy = cast<VectorType>(VecOp->getType());
auto VecElts = VecOpTy->getElementCount();

uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
uint64_t VecOpIdx = Cst->getZExtValue();
uint64_t NewIdx = IC.getDataLayout().isBigEndian()
? (VecOpIdx + 1) * TruncRatio - 1
: VecOpIdx * TruncRatio;

// If the extracted value was shifted down first, adjust the index by the
// whole number of truncated elements covered by the shift.
if (ShiftAmount) {
// Check that the shift amount is in range and shifts by a whole number of
// truncated elements.
if (ShiftAmount->uge(SrcBits) || ShiftAmount->urem(DstBits) != 0)
return nullptr;

uint64_t IdxOfs = ShiftAmount->udiv(DstBits).getZExtValue();
NewIdx = IC.getDataLayout().isBigEndian() ? (NewIdx - IdxOfs)
: (NewIdx + IdxOfs);
}
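// For example (little endian):
// trunc (lshr (extractelement <2 x i64> %X, 0), 32) to i16
// has TruncRatio = 4 and IdxOfs = 32 / 16 = 2, so the fold selects
// element 2 of the <8 x i16> bitcast, i.e. bits [32,48) of %X's element 0.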

assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
NewIdx <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");

auto *BitCastTo =
VectorType::get(DstType, BitCastNumElts, VecElts.isScalable());
Value *BitCast = IC.Builder.CreateBitCast(VecOp, BitCastTo);
return ExtractElementInst::Create(BitCast, IC.Builder.getInt32(NewIdx));
}

/// Funnel/Rotate left/right may occur in a wider type than necessary because of
/// type promotion rules. Try to narrow the inputs and convert to funnel shift.
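/// E.g. a 16-bit rotate-by-8 that was promoted to i32:
/// trunc (or (shl (zext %x), 8), (lshr (zext %x), 8)) to i16
/// --> call i16 @llvm.fshl.i16(i16 %x, i16 %x, i16 8)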
Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
@@ -848,36 +913,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
if (Instruction *I = foldVecTruncToExtElt(Trunc, *this))
return I;

// Whenever an element is extracted from a vector, and then truncated,
// canonicalize by converting it to a bitcast followed by an
// extractelement.
//
// Example (little endian):
// trunc (extractelement <4 x i64> %X, 0) to i32
// --->
// extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
Value *VecOp;
ConstantInt *Cst;
if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
auto *VecOpTy = cast<VectorType>(VecOp->getType());
auto VecElts = VecOpTy->getElementCount();

// A badly fit destination size would result in an invalid cast.
if (SrcWidth % DestWidth == 0) {
uint64_t TruncRatio = SrcWidth / DestWidth;
uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
uint64_t VecOpIdx = Cst->getZExtValue();
uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
: VecOpIdx * TruncRatio;
assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
"overflow 32-bits");

auto *BitCastTo =
VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable());
Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx));
}
}
if (Instruction *I = foldVecExtTruncToExtElt(Trunc, *this))
return I;

// trunc (ctlz_i32(zext(A), B)) --> add(ctlz_i16(A, B), C)
if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ctlz>(m_ZExt(m_Value(A)),
103 changes: 103 additions & 0 deletions llvm/test/Transforms/InstCombine/trunc-extractelement-inseltpoison.ll
@@ -18,6 +18,23 @@ define i32 @shrinkExtractElt_i64_to_i32_0(<3 x i64> %x) {
ret i32 %t
}

define i32 @shrinkShiftExtractElt_i64_to_i32_0(<3 x i64> %x) {
; LE-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
; LE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>
; LE-NEXT: [[T:%.*]] = extractelement <6 x i32> [[TMP1]], i64 1
; LE-NEXT: ret i32 [[T]]
;
; BE-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
; BE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>
; BE-NEXT: [[T:%.*]] = extractelement <6 x i32> [[TMP1]], i64 0
; BE-NEXT: ret i32 [[T]]
;
%e = extractelement <3 x i64> %x, i32 0
%s = lshr i64 %e, 32
%t = trunc i64 %s to i32
ret i32 %t
}

define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
; LE-LABEL: @vscale_shrinkExtractElt_i64_to_i32_0(
; LE-NEXT: [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
@@ -34,6 +51,22 @@ define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
ret i32 %t
}

define i32 @vscale_shrinkShiftExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
; LE-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
; LE-NEXT: [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
; LE-NEXT: [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i64 1
; LE-NEXT: ret i32 [[T]]
;
; BE-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
; BE-NEXT: [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
; BE-NEXT: [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i64 0
; BE-NEXT: ret i32 [[T]]
;
%e = extractelement <vscale x 3 x i64> %x, i32 0
%s = lshr i64 %e, 32
%t = trunc i64 %s to i32
ret i32 %t
}

define i32 @shrinkExtractElt_i64_to_i32_1(<3 x i64> %x) {
; LE-LABEL: @shrinkExtractElt_i64_to_i32_1(
@@ -83,6 +116,23 @@ define i16 @shrinkExtractElt_i64_to_i16_0(<3 x i64> %x) {
ret i16 %t
}

define i16 @shrinkShiftExtractElt_i64_to_i16_0(<3 x i64> %x) {
; LE-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
; LE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
; LE-NEXT: [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 3
; LE-NEXT: ret i16 [[T]]
;
; BE-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
; BE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
; BE-NEXT: [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 0
; BE-NEXT: ret i16 [[T]]
;
%e = extractelement <3 x i64> %x, i16 0
%s = ashr i64 %e, 48
%t = trunc i64 %s to i16
ret i16 %t
}

define i16 @shrinkExtractElt_i64_to_i16_1(<3 x i64> %x) {
; LE-LABEL: @shrinkExtractElt_i64_to_i16_1(
; LE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
@@ -157,6 +207,20 @@ define i30 @shrinkExtractElt_i40_to_i30_1(<3 x i40> %x) {
ret i30 %t
}

; Do not optimize if the shift amount isn't a whole number of truncated elements.
define i16 @shrinkShiftExtractElt_i64_to_i16_0_badshift(<3 x i64> %x) {
; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_0_badshift(
; ANY-NEXT: [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
; ANY-NEXT: [[S:%.*]] = lshr i64 [[E]], 31
; ANY-NEXT: [[T:%.*]] = trunc i64 [[S]] to i16
; ANY-NEXT: ret i16 [[T]]
;
%e = extractelement <3 x i64> %x, i16 0
%s = lshr i64 %e, 31
%t = trunc i64 %s to i16
ret i16 %t
}

; Do not canonicalize if that would increase the instruction count.
declare void @use(i64)
define i16 @shrinkExtractElt_i64_to_i16_2_extra_use(<3 x i64> %x) {
@@ -172,6 +236,45 @@ define i16 @shrinkExtractElt_i64_to_i16_2_extra_use(<3 x i64> %x) {
ret i16 %t
}

; Do not canonicalize if that would increase the instruction count.
define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(<3 x i64> %x) {
; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(
; ANY-NEXT: [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
; ANY-NEXT: [[S:%.*]] = lshr i64 [[E]], 48
; ANY-NEXT: call void @use(i64 [[S]])
; ANY-NEXT: [[T:%.*]] = trunc nuw i64 [[S]] to i16
; ANY-NEXT: ret i16 [[T]]
;
%e = extractelement <3 x i64> %x, i64 2
%s = lshr i64 %e, 48
call void @use(i64 %s)
%t = trunc i64 %s to i16
ret i16 %t
}

; OK to reuse the extract if we remove the shift+trunc.
define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(<3 x i64> %x) {
; LE-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
; LE-NEXT: [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
; LE-NEXT: call void @use(i64 [[E]])
; LE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X]] to <12 x i16>
; LE-NEXT: [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 11
; LE-NEXT: ret i16 [[T]]
;
; BE-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
; BE-NEXT: [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
; BE-NEXT: call void @use(i64 [[E]])
; BE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X]] to <12 x i16>
; BE-NEXT: [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 8
; BE-NEXT: ret i16 [[T]]
;
%e = extractelement <3 x i64> %x, i64 2
call void @use(i64 %e)
%s = lshr i64 %e, 48
%t = trunc i64 %s to i16
ret i16 %t
}

; Check to ensure PR45314 remains fixed.
define <4 x i64> @PR45314(<4 x i64> %x) {
; LE-LABEL: @PR45314(
103 changes: 103 additions & 0 deletions llvm/test/Transforms/InstCombine/trunc-extractelement.ll
@@ -18,6 +18,23 @@ define i32 @shrinkExtractElt_i64_to_i32_0(<3 x i64> %x) {
ret i32 %t
}

define i32 @shrinkShiftExtractElt_i64_to_i32_0(<3 x i64> %x) {
; LE-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
; LE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>
; LE-NEXT: [[T:%.*]] = extractelement <6 x i32> [[TMP1]], i64 1
; LE-NEXT: ret i32 [[T]]
;
; BE-LABEL: @shrinkShiftExtractElt_i64_to_i32_0(
; BE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>
; BE-NEXT: [[T:%.*]] = extractelement <6 x i32> [[TMP1]], i64 0
; BE-NEXT: ret i32 [[T]]
;
%e = extractelement <3 x i64> %x, i32 0
%s = lshr i64 %e, 32
%t = trunc i64 %s to i32
ret i32 %t
}

define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
; LE-LABEL: @vscale_shrinkExtractElt_i64_to_i32_0(
; LE-NEXT: [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
@@ -34,6 +51,22 @@ define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
ret i32 %t
}

define i32 @vscale_shrinkShiftExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
; LE-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
; LE-NEXT: [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
; LE-NEXT: [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i64 1
; LE-NEXT: ret i32 [[T]]
;
; BE-LABEL: @vscale_shrinkShiftExtractElt_i64_to_i32_0(
; BE-NEXT: [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
; BE-NEXT: [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i64 0
; BE-NEXT: ret i32 [[T]]
;
%e = extractelement <vscale x 3 x i64> %x, i32 0
%s = lshr i64 %e, 32
%t = trunc i64 %s to i32
ret i32 %t
}

define i32 @shrinkExtractElt_i64_to_i32_1(<3 x i64> %x) {
; LE-LABEL: @shrinkExtractElt_i64_to_i32_1(
@@ -83,6 +116,23 @@ define i16 @shrinkExtractElt_i64_to_i16_0(<3 x i64> %x) {
ret i16 %t
}

define i16 @shrinkShiftExtractElt_i64_to_i16_0(<3 x i64> %x) {
; LE-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
; LE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
; LE-NEXT: [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 3
; LE-NEXT: ret i16 [[T]]
;
; BE-LABEL: @shrinkShiftExtractElt_i64_to_i16_0(
; BE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
; BE-NEXT: [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 0
; BE-NEXT: ret i16 [[T]]
;
%e = extractelement <3 x i64> %x, i16 0
%s = ashr i64 %e, 48
%t = trunc i64 %s to i16
ret i16 %t
}

define i16 @shrinkExtractElt_i64_to_i16_1(<3 x i64> %x) {
; LE-LABEL: @shrinkExtractElt_i64_to_i16_1(
; LE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <12 x i16>
@@ -157,6 +207,20 @@ define i30 @shrinkExtractElt_i40_to_i30_1(<3 x i40> %x) {
ret i30 %t
}

; Do not optimize if the shift amount isn't a whole number of truncated elements.
define i16 @shrinkShiftExtractElt_i64_to_i16_0_badshift(<3 x i64> %x) {
; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_0_badshift(
; ANY-NEXT: [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 0
; ANY-NEXT: [[S:%.*]] = lshr i64 [[E]], 31
; ANY-NEXT: [[T:%.*]] = trunc i64 [[S]] to i16
; ANY-NEXT: ret i16 [[T]]
;
%e = extractelement <3 x i64> %x, i16 0
%s = lshr i64 %e, 31
%t = trunc i64 %s to i16
ret i16 %t
}

; Do not canonicalize if that would increase the instruction count.
declare void @use(i64)
define i16 @shrinkExtractElt_i64_to_i16_2_extra_use(<3 x i64> %x) {
@@ -172,6 +236,45 @@ define i16 @shrinkExtractElt_i64_to_i16_2_extra_use(<3 x i64> %x) {
ret i16 %t
}

; Do not canonicalize if that would increase the instruction count.
define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(<3 x i64> %x) {
; ANY-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_shift_use(
; ANY-NEXT: [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
; ANY-NEXT: [[S:%.*]] = lshr i64 [[E]], 48
; ANY-NEXT: call void @use(i64 [[S]])
; ANY-NEXT: [[T:%.*]] = trunc nuw i64 [[S]] to i16
; ANY-NEXT: ret i16 [[T]]
;
%e = extractelement <3 x i64> %x, i64 2
%s = lshr i64 %e, 48
call void @use(i64 %s)
%t = trunc i64 %s to i16
ret i16 %t
}

; OK to reuse the extract if we remove the shift+trunc.
define i16 @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(<3 x i64> %x) {
; LE-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
; LE-NEXT: [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
; LE-NEXT: call void @use(i64 [[E]])
; LE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X]] to <12 x i16>
; LE-NEXT: [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 11
; LE-NEXT: ret i16 [[T]]
;
; BE-LABEL: @shrinkShiftExtractElt_i64_to_i16_2_extra_extract_use(
; BE-NEXT: [[E:%.*]] = extractelement <3 x i64> [[X:%.*]], i64 2
; BE-NEXT: call void @use(i64 [[E]])
; BE-NEXT: [[TMP1:%.*]] = bitcast <3 x i64> [[X]] to <12 x i16>
; BE-NEXT: [[T:%.*]] = extractelement <12 x i16> [[TMP1]], i64 8
; BE-NEXT: ret i16 [[T]]
;
%e = extractelement <3 x i64> %x, i64 2
call void @use(i64 %e)
%s = lshr i64 %e, 48
%t = trunc i64 %s to i16
ret i16 %t
}

; Check to ensure PR45314 remains fixed.
define <4 x i64> @PR45314(<4 x i64> %x) {
; LE-LABEL: @PR45314(