|
18 | 18 | #include "WebAssemblySubtarget.h" |
19 | 19 | #include "WebAssemblyTargetMachine.h" |
20 | 20 | #include "WebAssemblyUtilities.h" |
| 21 | +#include "llvm/ADT/ArrayRef.h" |
21 | 22 | #include "llvm/CodeGen/CallingConvLower.h" |
22 | 23 | #include "llvm/CodeGen/MachineFrameInfo.h" |
23 | 24 | #include "llvm/CodeGen/MachineInstrBuilder.h" |
@@ -91,6 +92,19 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( |
91 | 92 | setOperationAction(ISD::LOAD, T, Custom); |
92 | 93 | setOperationAction(ISD::STORE, T, Custom); |
93 | 94 | } |
| 95 | + |
|      |   96 | +  // Likewise, custom-lower zext/sext/anyext extending loads from address |
|      |   97 | +  // space 1 (Wasm globals). |
| 98 | + setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i32, |
| 99 | + {MVT::i8, MVT::i16}, Custom); |
| 100 | + setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i64, |
| 101 | + {MVT::i8, MVT::i16, MVT::i32}, Custom); |
| 102 | + |
|      |  103 | +  // Compensate for the EXTLOADs being Custom by reimplementing some of the |
|      |  104 | +  // generic combiner logic (performANDCombine, performSIGN_EXTEND_INREGCombine). |
| 105 | + setTargetDAGCombine(ISD::AND); |
| 106 | + setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); |
| 107 | + |
94 | 108 | if (Subtarget->hasSIMD128()) { |
95 | 109 | for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64, |
96 | 110 | MVT::v2f64}) { |
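Aside: the grouped `setLoadExtAction` calls above take `ArrayRef`s of extension kinds and memory types, which is why the patch also pulls in `llvm/ADT/ArrayRef.h`. As a minimal sketch (assuming only that the `ArrayRef` overload iterates both lists, which is what it does internally), they are equivalent to the spelled-out loops:

```cpp
// Equivalent per-type form of the grouped setLoadExtAction calls (a sketch).
for (auto Ext : {ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}) {
  for (auto MemVT : {MVT::i8, MVT::i16})
    setLoadExtAction(Ext, MVT::i32, MemVT, Custom); // narrow loads into i32
  for (auto MemVT : {MVT::i8, MVT::i16, MVT::i32})
    setLoadExtAction(Ext, MVT::i64, MemVT, Custom); // narrow loads into i64
}
```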
@@ -1683,6 +1697,11 @@ static bool IsWebAssemblyGlobal(SDValue Op) { |
1683 | 1697 | if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op)) |
1684 | 1698 | return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace()); |
1685 | 1699 |
|
| 1700 | + if (Op->getOpcode() == WebAssemblyISD::Wrapper) |
| 1701 | + if (const GlobalAddressSDNode *GA = |
| 1702 | + dyn_cast<GlobalAddressSDNode>(Op->getOperand(0))) |
| 1703 | + return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace()); |
| 1704 | + |
1686 | 1705 | return false; |
1687 | 1706 | } |
1688 | 1707 |
|
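The new arm looks through `WebAssemblyISD::Wrapper`, which selection places around global addresses. A hypothetical refactoring expressing the same check (the `stripWrapper` helper below is not part of the patch):

```cpp
// Peel one Wrapper node, if present, then test for a GlobalAddress in a
// Wasm variable address space (hypothetical helper, same logic as above).
static const GlobalAddressSDNode *stripWrapper(SDValue Op) {
  if (Op->getOpcode() == WebAssemblyISD::Wrapper)
    Op = Op->getOperand(0);
  return dyn_cast<GlobalAddressSDNode>(Op.getNode());
}

static bool IsWebAssemblyGlobal(SDValue Op) {
  if (const GlobalAddressSDNode *GA = stripWrapper(Op))
    return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
  return false;
}
```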
@@ -1740,16 +1759,115 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op, |
1740 | 1759 | LoadSDNode *LN = cast<LoadSDNode>(Op.getNode()); |
1741 | 1760 | const SDValue &Base = LN->getBasePtr(); |
1742 | 1761 | const SDValue &Offset = LN->getOffset(); |
| 1762 | + ISD::LoadExtType ExtType = LN->getExtensionType(); |
| 1763 | + EVT ResultType = LN->getValueType(0); |
1743 | 1764 |
|
1744 | 1765 | if (IsWebAssemblyGlobal(Base)) { |
1745 | 1766 | if (!Offset->isUndef()) |
1746 | 1767 | report_fatal_error( |
1747 | 1768 | "unexpected offset when loading from webassembly global", false); |
1748 | 1769 |
|
1749 | | - SDVTList Tys = DAG.getVTList(LN->getValueType(0), MVT::Other); |
1750 | | - SDValue Ops[] = {LN->getChain(), Base}; |
1751 | | - return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, |
1752 | | - LN->getMemoryVT(), LN->getMemOperand()); |
|      | 1770 | +    if (!ResultType.isInteger() && !ResultType.isFloatingPoint()) { |
|      | 1771 | +      // Vector and reference-typed globals are always loaded whole, so the |
|      | 1772 | +      // original direct lowering to GLOBAL_GET applies. |
|      | 1773 | +      SDVTList Tys = DAG.getVTList(ResultType, MVT::Other); |
|      | 1774 | +      SDValue Ops[] = {LN->getChain(), Base}; |
|      | 1775 | +      return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, |
|      | 1776 | +                                     LN->getMemoryVT(), LN->getMemOperand()); |
|      | 1777 | +    } |
| 1778 | + |
|      | 1779 | +    EVT GT; // Value type of the underlying global, if we can determine it. |
| 1780 | + |
| 1781 | + if (auto *GA = dyn_cast<GlobalAddressSDNode>( |
| 1782 | + Base->getOpcode() == WebAssemblyISD::Wrapper ? Base->getOperand(0) |
| 1783 | + : Base)) |
| 1784 | + GT = EVT::getEVT(GA->getGlobal()->getValueType()); |
| 1785 | + |
| 1786 | + if (GT != MVT::i8 && GT != MVT::i16 && GT != MVT::i32 && GT != MVT::i64 && |
| 1787 | + GT != MVT::f32 && GT != MVT::f64) |
| 1788 | + report_fatal_error("encountered unexpected global type for Base when " |
| 1789 | + "loading from webassembly global", |
| 1790 | + false); |
| 1791 | + |
| 1792 | + EVT PromotedGT = getTypeToTransformTo(*DAG.getContext(), GT); |
| 1793 | + |
| 1794 | + if (ExtType == ISD::NON_EXTLOAD) { |
|      | 1795 | +      // A normal, non-extending load may try to load more or less than the |
|      | 1796 | +      // underlying global, which is invalid. Lower it to a GLOBAL_GET of the |
|      | 1797 | +      // whole promoted global, then truncate, extend, or round as needed. |
| 1798 | + |
| 1799 | + // Modify the MMO to load the full global |
| 1800 | + MachineMemOperand *OldMMO = LN->getMemOperand(); |
| 1801 | + MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand( |
| 1802 | + OldMMO->getPointerInfo(), OldMMO->getFlags(), |
| 1803 | + LLT(PromotedGT.getSimpleVT()), OldMMO->getBaseAlign(), |
| 1804 | + OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(), |
| 1805 | + OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering()); |
| 1806 | + |
| 1807 | + SDVTList Tys = DAG.getVTList(PromotedGT, MVT::Other); |
| 1808 | + SDValue Ops[] = {LN->getChain(), Base}; |
| 1809 | + SDValue GlobalGetNode = DAG.getMemIntrinsicNode( |
| 1810 | + WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, PromotedGT, NewMMO); |
| 1811 | + |
|      | 1812 | +      if (ResultType.bitsEq(PromotedGT)) |
|      | 1813 | +        return GlobalGetNode; |
| 1815 | + |
| 1816 | + SDValue ValRes; |
| 1817 | + if (ResultType.isFloatingPoint()) |
| 1818 | + ValRes = DAG.getFPExtendOrRound(GlobalGetNode, DL, ResultType); |
| 1819 | + else |
| 1820 | + ValRes = DAG.getAnyExtOrTrunc(GlobalGetNode, DL, ResultType); |
| 1821 | + |
| 1822 | + return DAG.getMergeValues({ValRes, GlobalGetNode.getValue(1)}, DL); |
| 1823 | + } |
| 1824 | + |
| 1825 | + if (ExtType == ISD::ZEXTLOAD || ExtType == ISD::SEXTLOAD) { |
|      | 1826 | +      // Turn the unsupported zext/sext load into an EXTLOAD followed by an |
|      | 1827 | +      // explicit zero- or sign-extend in-register, just as Expand would. |
| 1828 | + |
| 1829 | + SDValue Result = |
| 1830 | + DAG.getExtLoad(ISD::EXTLOAD, DL, ResultType, LN->getChain(), Base, |
| 1831 | + LN->getMemoryVT(), LN->getMemOperand()); |
| 1832 | + SDValue ValRes; |
| 1833 | + if (ExtType == ISD::SEXTLOAD) |
| 1834 | + ValRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Result.getValueType(), |
| 1835 | + Result, DAG.getValueType(LN->getMemoryVT())); |
| 1836 | + else |
| 1837 | + ValRes = DAG.getZeroExtendInReg(Result, DL, LN->getMemoryVT()); |
| 1838 | + |
| 1839 | + return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL); |
| 1840 | + } |
| 1841 | + |
| 1842 | + if (ExtType == ISD::EXTLOAD) { |
|      | 1843 | +      // Expand the EXTLOAD into a regular LOAD of the promoted global and, |
|      | 1844 | +      // if needed, an any-extend to the result type. |
| 1845 | + |
| 1846 | + EVT OldLoadType = LN->getMemoryVT(); |
| 1847 | + EVT NewLoadType = getTypeToTransformTo(*DAG.getContext(), OldLoadType); |
| 1848 | + |
|      | 1849 | +      // Modify the MMO to load a full Wasm value (the promoted load type). |
| 1850 | + MachineMemOperand *OldMMO = LN->getMemOperand(); |
| 1851 | + MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand( |
| 1852 | + OldMMO->getPointerInfo(), OldMMO->getFlags(), |
| 1853 | + LLT(NewLoadType.getSimpleVT()), OldMMO->getBaseAlign(), |
| 1854 | + OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(), |
| 1855 | + OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering()); |
| 1856 | + |
| 1857 | + SDValue Result = |
| 1858 | + DAG.getLoad(NewLoadType, DL, LN->getChain(), Base, NewMMO); |
| 1859 | + |
| 1860 | + if (NewLoadType != ResultType) { |
| 1861 | + SDValue ValRes = DAG.getNode(ISD::ANY_EXTEND, DL, ResultType, Result); |
| 1862 | + return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL); |
| 1863 | + } |
| 1864 | + |
| 1865 | + return Result; |
| 1866 | + } |
| 1867 | + |
| 1868 | + report_fatal_error( |
| 1869 | + "encountered unexpected ExtType when loading from webassembly global", |
| 1870 | + false); |
1753 | 1871 | } |
1754 | 1872 |
|
1755 | 1873 | if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) { |
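Both the `NON_EXTLOAD` and `EXTLOAD` paths above rebuild the machine memory operand with a widened `LLT` so it describes the whole-global `global.get` that is actually emitted. A sketch of that duplicated step as a shared helper (hypothetical name, not part of the patch; the body is the same call the diff makes twice):

```cpp
// Clone LN's MMO but describe a load of WideVT, matching the whole-global
// GLOBAL_GET that the lowering emits (hypothetical helper).
static MachineMemOperand *widenMMO(SelectionDAG &DAG, LoadSDNode *LN,
                                   EVT WideVT) {
  MachineMemOperand *OldMMO = LN->getMemOperand();
  return DAG.getMachineFunction().getMachineMemOperand(
      OldMMO->getPointerInfo(), OldMMO->getFlags(),
      LLT(WideVT.getSimpleVT()), OldMMO->getBaseAlign(),
      OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
      OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
}
```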
@@ -3525,6 +3643,184 @@ static SDValue performMulCombine(SDNode *N, |
3525 | 3643 | } |
3526 | 3644 | } |
3527 | 3645 |
|
| 3646 | +static SDValue performANDCombine(SDNode *N, |
| 3647 | + TargetLowering::DAGCombinerInfo &DCI) { |
|      | 3648 | +  // Copied and modified from DAGCombiner::visitAND(SDNode *N). We have to |
|      | 3649 | +  // reimplement it here because the generic combine does not fire when |
|      | 3650 | +  // ZEXTLOAD has custom lowering. |
| 3651 | + |
| 3652 | + SDValue N0 = N->getOperand(0); |
| 3653 | + SDValue N1 = N->getOperand(1); |
| 3654 | + SDLoc DL(N); |
| 3655 | + |
| 3656 | + // fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) -> |
| 3657 | + // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must |
| 3658 | + // already be zero by virtue of the width of the base type of the load. |
| 3659 | + // |
| 3660 | + // the 'X' node here can either be nothing or an extract_vector_elt to catch |
| 3661 | + // more cases. |
| 3662 | + if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT && |
| 3663 | + N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() && |
| 3664 | + N0.getOperand(0).getOpcode() == ISD::LOAD && |
| 3665 | + N0.getOperand(0).getResNo() == 0) || |
| 3666 | + (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) { |
| 3667 | + auto *Load = |
| 3668 | + cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0)); |
| 3669 | + |
| 3670 | + // Get the constant (if applicable) the zero'th operand is being ANDed with. |
| 3671 | + // This can be a pure constant or a vector splat, in which case we treat the |
| 3672 | + // vector as a scalar and use the splat value. |
| 3673 | + APInt Constant = APInt::getZero(1); |
| 3674 | + if (const ConstantSDNode *C = isConstOrConstSplat( |
| 3675 | + N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) { |
| 3676 | + Constant = C->getAPIntValue(); |
| 3677 | + } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) { |
| 3678 | + unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits(); |
| 3679 | + APInt SplatValue, SplatUndef; |
| 3680 | + unsigned SplatBitSize; |
| 3681 | + bool HasAnyUndefs; |
| 3682 | + // Endianness should not matter here. Code below makes sure that we only |
| 3683 | + // use the result if the SplatBitSize is a multiple of the vector element |
| 3684 | + // size. And after that we AND all element sized parts of the splat |
| 3685 | + // together. So the end result should be the same regardless of in which |
| 3686 | + // order we do those operations. |
| 3687 | + const bool IsBigEndian = false; |
| 3688 | + bool IsSplat = |
| 3689 | + Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, |
| 3690 | + HasAnyUndefs, EltBitWidth, IsBigEndian); |
| 3691 | + |
| 3692 | + // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a |
| 3693 | + // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value. |
| 3694 | + if (IsSplat && (SplatBitSize % EltBitWidth) == 0) { |
| 3695 | + // Undef bits can contribute to a possible optimisation if set, so |
| 3696 | + // set them. |
| 3697 | + SplatValue |= SplatUndef; |
| 3698 | + |
| 3699 | + // The splat value may be something like "0x00FFFFFF", which means 0 for |
| 3700 | + // the first vector value and FF for the rest, repeating. We need a mask |
| 3701 | + // that will apply equally to all members of the vector, so AND all the |
| 3702 | + // lanes of the constant together. |
| 3703 | + Constant = APInt::getAllOnes(EltBitWidth); |
| 3704 | + for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i) |
| 3705 | + Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth); |
| 3706 | + } |
| 3707 | + } |
| 3708 | + |
| 3709 | + // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is |
| 3710 | + // actually legal and isn't going to get expanded, else this is a false |
| 3711 | + // optimisation. |
| 3712 | + |
| 3713 | + /*bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD, |
| 3714 | + Load->getValueType(0), |
| 3715 | + Load->getMemoryVT());*/ |
| 3716 | + // MODIFIED: this is the one difference in the logic; we allow ZEXT combine |
| 3717 | + // only in addrspace 0, where it's legal |
| 3718 | + bool CanZextLoadProfitably = Load->getAddressSpace() == 0; |
| 3719 | + |
| 3720 | + // Resize the constant to the same size as the original memory access before |
| 3721 | + // extension. If it is still the AllOnesValue then this AND is completely |
| 3722 | + // unneeded. |
| 3723 | + Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits()); |
| 3724 | + |
| 3725 | + bool B; |
| 3726 | + switch (Load->getExtensionType()) { |
| 3727 | + default: |
| 3728 | + B = false; |
| 3729 | + break; |
| 3730 | + case ISD::EXTLOAD: |
| 3731 | + B = CanZextLoadProfitably; |
| 3732 | + break; |
| 3733 | + case ISD::ZEXTLOAD: |
| 3734 | + case ISD::NON_EXTLOAD: |
| 3735 | + B = true; |
| 3736 | + break; |
| 3737 | + } |
| 3738 | + |
| 3739 | + if (B && Constant.isAllOnes()) { |
| 3740 | + // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to |
| 3741 | + // preserve semantics once we get rid of the AND. |
| 3742 | + SDValue NewLoad(Load, 0); |
| 3743 | + |
| 3744 | + // Fold the AND away. NewLoad may get replaced immediately. |
| 3745 | + DCI.CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0); |
| 3746 | + |
| 3747 | + if (Load->getExtensionType() == ISD::EXTLOAD) { |
| 3748 | + NewLoad = DCI.DAG.getLoad( |
| 3749 | + Load->getAddressingMode(), ISD::ZEXTLOAD, Load->getValueType(0), |
| 3750 | + SDLoc(Load), Load->getChain(), Load->getBasePtr(), |
| 3751 | + Load->getOffset(), Load->getMemoryVT(), Load->getMemOperand()); |
| 3752 | + // Replace uses of the EXTLOAD with the new ZEXTLOAD. |
| 3753 | + if (Load->getNumValues() == 3) { |
| 3754 | + // PRE/POST_INC loads have 3 values. |
| 3755 | + SDValue To[] = {NewLoad.getValue(0), NewLoad.getValue(1), |
| 3756 | + NewLoad.getValue(2)}; |
| 3757 | + DCI.CombineTo(Load, ArrayRef<SDValue>(To, 3), true); |
| 3758 | + } else { |
| 3759 | + DCI.CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1)); |
| 3760 | + } |
| 3761 | + } |
| 3762 | + |
| 3763 | + return SDValue(N, 0); // Return N so it doesn't get rechecked! |
| 3764 | + } |
| 3765 | + } |
| 3766 | + return SDValue(); |
| 3767 | +} |
| 3768 | + |
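The decisive test in `performANDCombine` is `Constant.isAllOnes()` after resizing the mask to the load's memory width: the `and` is redundant exactly when its mask keeps every bit the narrow load can produce. A standalone sketch of that predicate, assuming only `APInt`:

```cpp
#include "llvm/ADT/APInt.h"
using namespace llvm;

// True when an AND mask clears only bits that a zero-extending load of
// MemBits already guarantees are zero, e.g. (and (extload i8->i32), 0xFF).
static bool maskMakesANDRedundant(const APInt &Mask, unsigned MemBits) {
  return Mask.zextOrTrunc(MemBits).isAllOnes();
}
```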
| 3769 | +static SDValue |
| 3770 | +performSIGN_EXTEND_INREGCombine(SDNode *N, |
| 3771 | + TargetLowering::DAGCombinerInfo &DCI) { |
|      | 3772 | +  // Copied and modified from DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N). |
|      | 3773 | +  // We have to reimplement it here because the generic combine does not fire |
|      | 3774 | +  // when SEXTLOAD has custom lowering. |
| 3775 | + |
| 3776 | + SDValue N0 = N->getOperand(0); |
| 3777 | + SDValue N1 = N->getOperand(1); |
| 3778 | + EVT VT = N->getValueType(0); |
| 3779 | + EVT ExtVT = cast<VTSDNode>(N1)->getVT(); |
| 3780 | + SDLoc DL(N); |
| 3781 | + |
| 3782 | + // fold (sext_inreg (extload x)) -> (sextload x) |
| 3783 | + // If sextload is not supported by target, we can only do the combine when |
| 3784 | + // load has one use. Doing otherwise can block folding the extload with other |
| 3785 | + // extends that the target does support. |
| 3786 | + |
| 3787 | + // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) with |
| 3788 | + // cast<LoadSDNode>(N0)->getAddressSpace() == 0) |
| 3789 | + if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && |
| 3790 | + ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() && |
| 3791 | + ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple() && |
| 3792 | + N0.hasOneUse()) || |
| 3793 | + cast<LoadSDNode>(N0)->getAddressSpace() == 0)) { |
| 3794 | + auto *LN0 = cast<LoadSDNode>(N0); |
| 3795 | + SDValue ExtLoad = |
| 3796 | + DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(), |
| 3797 | + LN0->getBasePtr(), ExtVT, LN0->getMemOperand()); |
| 3798 | + DCI.CombineTo(N, ExtLoad); |
| 3799 | + DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); |
| 3800 | + DCI.AddToWorklist(ExtLoad.getNode()); |
| 3801 | + return SDValue(N, 0); // Return N so it doesn't get rechecked! |
| 3802 | + } |
| 3803 | + |
| 3804 | + // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use |
| 3805 | + |
| 3806 | + // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) with |
| 3807 | + // cast<LoadSDNode>(N0)->getAddressSpace() == 0) |
| 3808 | + if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && |
| 3809 | + N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() && |
| 3810 | + ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple()) && |
| 3811 | + cast<LoadSDNode>(N0)->getAddressSpace() == 0)) { |
| 3812 | + auto *LN0 = cast<LoadSDNode>(N0); |
| 3813 | + SDValue ExtLoad = |
| 3814 | + DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(), |
| 3815 | + LN0->getBasePtr(), ExtVT, LN0->getMemOperand()); |
| 3816 | + DCI.CombineTo(N, ExtLoad); |
| 3817 | + DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); |
| 3818 | + return SDValue(N, 0); // Return N so it doesn't get rechecked! |
| 3819 | + } |
| 3820 | + |
| 3821 | + return SDValue(); |
| 3822 | +} |
| 3823 | + |
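As an arithmetic sanity check on the first fold, `(sign_extend_inreg (extload i8), i8)` must agree with `(sextload i8)` for every loaded byte, since the in-reg extension depends only on the low bits. A small self-contained sketch with `APInt`:

```cpp
#include "llvm/ADT/APInt.h"
using namespace llvm;

// Check that sign_extend_inreg applied to an (any-)extended i8 matches
// sign-extending the i8 directly, which is what the fold relies on.
static bool sextInRegFoldIsSound(uint8_t Loaded) {
  APInt AnyExt(32, Loaded);                  // one possible EXTLOAD result
  APInt InReg = AnyExt.trunc(8).sext(32);    // sign_extend_inreg ... i8
  APInt Direct = APInt(8, Loaded).sext(32);  // SEXTLOAD i8 -> i32
  return InReg == Direct;
}
```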
3528 | 3824 | SDValue |
3529 | 3825 | WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, |
3530 | 3826 | DAGCombinerInfo &DCI) const { |
@@ -3557,5 +3853,9 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, |
3557 | 3853 | return performAnyAllCombine(N, DCI.DAG); |
3558 | 3854 | case ISD::MUL: |
3559 | 3855 | return performMulCombine(N, DCI); |
| 3856 | + case ISD::AND: |
| 3857 | + return performANDCombine(N, DCI); |
| 3858 | + case ISD::SIGN_EXTEND_INREG: |
| 3859 | + return performSIGN_EXTEND_INREGCombine(N, DCI); |
3560 | 3860 | } |
3561 | 3861 | } |