Commit fe9d88a

Fix lowering of loads (and extending loads) from addrspace(1) globals
1 parent 4db3809 commit fe9d88a
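
For context, a minimal IR sketch of the pattern being fixed (hypothetical names, not taken from the commit's tests): on WebAssembly, address space 1 marks Wasm globals, so a load whose result is then extended must lower to a global.get plus an explicit extend, not an ordinary linear-memory load.

; Hypothetical example, assuming an i32 Wasm global @g.
@g = addrspace(1) global i32 0

define i64 @load_sext() {
  %v = load i32, ptr addrspace(1) @g  ; lowers to global.get, not a memory load
  %e = sext i32 %v to i64             ; folds into the SEXTLOAD this commit handles
  ret i64 %e
}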

File tree: 4 files changed, +353 −45 lines


llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 304 additions & 4 deletions
@@ -18,6 +18,7 @@
 #include "WebAssemblySubtarget.h"
 #include "WebAssemblyTargetMachine.h"
 #include "WebAssemblyUtilities.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -91,6 +92,19 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     setOperationAction(ISD::LOAD, T, Custom);
     setOperationAction(ISD::STORE, T, Custom);
   }
+
+  // Likewise, transform zext/sext/anyext extending loads from address space 1
+  // (WASM globals)
+  setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i32,
+                   {MVT::i8, MVT::i16}, Custom);
+  setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i64,
+                   {MVT::i8, MVT::i16, MVT::i32}, Custom);
+
+  // Compensate for the EXTLOADs being custom by reimplementing some combiner
+  // logic
+  setTargetDAGCombine(ISD::AND);
+  setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
+
   if (Subtarget->hasSIMD128()) {
     for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
                    MVT::v2f64}) {
@@ -1683,6 +1697,11 @@ static bool IsWebAssemblyGlobal(SDValue Op) {
   if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op))
     return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
 
+  if (Op->getOpcode() == WebAssemblyISD::Wrapper)
+    if (const GlobalAddressSDNode *GA =
+            dyn_cast<GlobalAddressSDNode>(Op->getOperand(0)))
+      return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
+
   return false;
 }
 
@@ -1740,16 +1759,115 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
   LoadSDNode *LN = cast<LoadSDNode>(Op.getNode());
   const SDValue &Base = LN->getBasePtr();
   const SDValue &Offset = LN->getOffset();
+  ISD::LoadExtType ExtType = LN->getExtensionType();
+  EVT ResultType = LN->getValueType(0);
 
   if (IsWebAssemblyGlobal(Base)) {
     if (!Offset->isUndef())
       report_fatal_error(
           "unexpected offset when loading from webassembly global", false);
 
-    SDVTList Tys = DAG.getVTList(LN->getValueType(0), MVT::Other);
-    SDValue Ops[] = {LN->getChain(), Base};
-    return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
-                                   LN->getMemoryVT(), LN->getMemOperand());
+    if (!ResultType.isInteger() && !ResultType.isFloatingPoint()) {
+      SDVTList Tys = DAG.getVTList(ResultType, MVT::Other);
+      SDValue Ops[] = {LN->getChain(), Base};
+      SDValue GlobalGetNode =
+          DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
+                                  LN->getMemoryVT(), LN->getMemOperand());
+      return GlobalGetNode;
+    }
+
+    EVT GT = MVT::INVALID_SIMPLE_VALUE_TYPE;
+
+    if (auto *GA = dyn_cast<GlobalAddressSDNode>(
+            Base->getOpcode() == WebAssemblyISD::Wrapper ? Base->getOperand(0)
+                                                         : Base))
+      GT = EVT::getEVT(GA->getGlobal()->getValueType());
+
+    if (GT != MVT::i8 && GT != MVT::i16 && GT != MVT::i32 && GT != MVT::i64 &&
+        GT != MVT::f32 && GT != MVT::f64)
+      report_fatal_error("encountered unexpected global type for Base when "
+                         "loading from webassembly global",
+                         false);
+
+    EVT PromotedGT = getTypeToTransformTo(*DAG.getContext(), GT);
+
+    if (ExtType == ISD::NON_EXTLOAD) {
+      // A normal, non-extending load may try to load more or less than the
+      // underlying global, which is invalid. We lower this to a load of the
+      // global (i32 or i64) then truncate or extend as needed
+
+      // Modify the MMO to load the full global
+      MachineMemOperand *OldMMO = LN->getMemOperand();
+      MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
+          OldMMO->getPointerInfo(), OldMMO->getFlags(),
+          LLT(PromotedGT.getSimpleVT()), OldMMO->getBaseAlign(),
+          OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
+          OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
+
+      SDVTList Tys = DAG.getVTList(PromotedGT, MVT::Other);
+      SDValue Ops[] = {LN->getChain(), Base};
+      SDValue GlobalGetNode = DAG.getMemIntrinsicNode(
+          WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, PromotedGT, NewMMO);
+
+      if (ResultType.bitsEq(PromotedGT)) {
+        return GlobalGetNode;
+      }
+
+      SDValue ValRes;
+      if (ResultType.isFloatingPoint())
+        ValRes = DAG.getFPExtendOrRound(GlobalGetNode, DL, ResultType);
+      else
+        ValRes = DAG.getAnyExtOrTrunc(GlobalGetNode, DL, ResultType);
+
+      return DAG.getMergeValues({ValRes, GlobalGetNode.getValue(1)}, DL);
+    }
+
+    if (ExtType == ISD::ZEXTLOAD || ExtType == ISD::SEXTLOAD) {
+      // Turn the unsupported load into an EXTLOAD followed by an
+      // explicit zero/sign extend inreg. Same as Expand
+
+      SDValue Result =
+          DAG.getExtLoad(ISD::EXTLOAD, DL, ResultType, LN->getChain(), Base,
+                         LN->getMemoryVT(), LN->getMemOperand());
+      SDValue ValRes;
+      if (ExtType == ISD::SEXTLOAD)
+        ValRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Result.getValueType(),
+                             Result, DAG.getValueType(LN->getMemoryVT()));
+      else
+        ValRes = DAG.getZeroExtendInReg(Result, DL, LN->getMemoryVT());
+
+      return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
+    }
+
+    if (ExtType == ISD::EXTLOAD) {
+      // Expand the EXTLOAD into a regular LOAD of the global, and if
+      // needed, a zero-extension
+
+      EVT OldLoadType = LN->getMemoryVT();
+      EVT NewLoadType = getTypeToTransformTo(*DAG.getContext(), OldLoadType);
+
+      // Modify the MMO to load a whole WASM "register"'s worth
+      MachineMemOperand *OldMMO = LN->getMemOperand();
+      MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
+          OldMMO->getPointerInfo(), OldMMO->getFlags(),
+          LLT(NewLoadType.getSimpleVT()), OldMMO->getBaseAlign(),
+          OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
+          OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
+
+      SDValue Result =
+          DAG.getLoad(NewLoadType, DL, LN->getChain(), Base, NewMMO);
+
+      if (NewLoadType != ResultType) {
+        SDValue ValRes = DAG.getNode(ISD::ANY_EXTEND, DL, ResultType, Result);
+        return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
+      }
+
+      return Result;
+    }
+
+    report_fatal_error(
+        "encountered unexpected ExtType when loading from webassembly global",
+        false);
   }
 
   if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) {
@@ -3525,6 +3643,184 @@ static SDValue performMulCombine(SDNode *N,
   }
 }
 
+static SDValue performANDCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  // Copied and modified from DAGCombiner::visitAND(SDNode *N)
+  // We have to do this because the original combiner doesn't work when ZEXTLOAD
+  // has custom lowering
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDLoc DL(N);
+
+  // fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
+  // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
+  // already be zero by virtue of the width of the base type of the load.
+  //
+  // the 'X' node here can either be nothing or an extract_vector_elt to catch
+  // more cases.
+  if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+       N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
+       N0.getOperand(0).getOpcode() == ISD::LOAD &&
+       N0.getOperand(0).getResNo() == 0) ||
+      (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
+    auto *Load =
+        cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));
+
+    // Get the constant (if applicable) the zero'th operand is being ANDed with.
+    // This can be a pure constant or a vector splat, in which case we treat the
+    // vector as a scalar and use the splat value.
+    APInt Constant = APInt::getZero(1);
+    if (const ConstantSDNode *C = isConstOrConstSplat(
+            N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
+      Constant = C->getAPIntValue();
+    } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
+      unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
+      APInt SplatValue, SplatUndef;
+      unsigned SplatBitSize;
+      bool HasAnyUndefs;
+      // Endianness should not matter here. Code below makes sure that we only
+      // use the result if the SplatBitSize is a multiple of the vector element
+      // size. And after that we AND all element sized parts of the splat
+      // together. So the end result should be the same regardless of in which
+      // order we do those operations.
+      const bool IsBigEndian = false;
+      bool IsSplat =
+          Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                                  HasAnyUndefs, EltBitWidth, IsBigEndian);
+
+      // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
+      // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
+      if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
+        // Undef bits can contribute to a possible optimisation if set, so
+        // set them.
+        SplatValue |= SplatUndef;
+
+        // The splat value may be something like "0x00FFFFFF", which means 0 for
+        // the first vector value and FF for the rest, repeating. We need a mask
+        // that will apply equally to all members of the vector, so AND all the
+        // lanes of the constant together.
+        Constant = APInt::getAllOnes(EltBitWidth);
+        for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
+          Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
+      }
+    }
+
+    // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
+    // actually legal and isn't going to get expanded, else this is a false
+    // optimisation.
+
+    /*bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
+                                                      Load->getValueType(0),
+                                                      Load->getMemoryVT());*/
+    // MODIFIED: this is the one difference in the logic; we allow ZEXT combine
+    // only in addrspace 0, where it's legal
+    bool CanZextLoadProfitably = Load->getAddressSpace() == 0;
+
+    // Resize the constant to the same size as the original memory access before
+    // extension. If it is still the AllOnesValue then this AND is completely
+    // unneeded.
+    Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
+
+    bool B;
+    switch (Load->getExtensionType()) {
+    default:
+      B = false;
+      break;
+    case ISD::EXTLOAD:
+      B = CanZextLoadProfitably;
+      break;
+    case ISD::ZEXTLOAD:
+    case ISD::NON_EXTLOAD:
+      B = true;
+      break;
+    }
+
+    if (B && Constant.isAllOnes()) {
+      // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
+      // preserve semantics once we get rid of the AND.
+      SDValue NewLoad(Load, 0);
+
+      // Fold the AND away. NewLoad may get replaced immediately.
+      DCI.CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
+
+      if (Load->getExtensionType() == ISD::EXTLOAD) {
+        NewLoad = DCI.DAG.getLoad(
+            Load->getAddressingMode(), ISD::ZEXTLOAD, Load->getValueType(0),
+            SDLoc(Load), Load->getChain(), Load->getBasePtr(),
+            Load->getOffset(), Load->getMemoryVT(), Load->getMemOperand());
+        // Replace uses of the EXTLOAD with the new ZEXTLOAD.
+        if (Load->getNumValues() == 3) {
+          // PRE/POST_INC loads have 3 values.
+          SDValue To[] = {NewLoad.getValue(0), NewLoad.getValue(1),
+                          NewLoad.getValue(2)};
+          DCI.CombineTo(Load, ArrayRef<SDValue>(To, 3), true);
+        } else {
+          DCI.CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
+        }
+      }
+
+      return SDValue(N, 0); // Return N so it doesn't get rechecked!
+    }
+  }
+  return SDValue();
+}
+
+static SDValue
+performSIGN_EXTEND_INREGCombine(SDNode *N,
+                                TargetLowering::DAGCombinerInfo &DCI) {
+  // Copied and modified from DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N)
+  // We have to do this because the original combiner doesn't work when SEXTLOAD
+  // has custom lowering
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT ExtVT = cast<VTSDNode>(N1)->getVT();
+  SDLoc DL(N);
+
+  // fold (sext_inreg (extload x)) -> (sextload x)
+  // If sextload is not supported by target, we can only do the combine when
+  // load has one use. Doing otherwise can block folding the extload with other
+  // extends that the target does support.
+
+  // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) with
+  // cast<LoadSDNode>(N0)->getAddressSpace() == 0)
+  if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
+      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
+      ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple() &&
+        N0.hasOneUse()) ||
+       cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
+    auto *LN0 = cast<LoadSDNode>(N0);
+    SDValue ExtLoad =
+        DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
+                           LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
+    DCI.CombineTo(N, ExtLoad);
+    DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    DCI.AddToWorklist(ExtLoad.getNode());
+    return SDValue(N, 0); // Return N so it doesn't get rechecked!
+  }
+
+  // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
+
+  // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) with
+  // cast<LoadSDNode>(N0)->getAddressSpace() == 0)
+  if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
+      N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
+      ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple()) &&
+       cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
+    auto *LN0 = cast<LoadSDNode>(N0);
+    SDValue ExtLoad =
+        DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
+                           LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
+    DCI.CombineTo(N, ExtLoad);
+    DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    return SDValue(N, 0); // Return N so it doesn't get rechecked!
+  }
+
+  return SDValue();
+}
+
 SDValue
 WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
@@ -3557,5 +3853,9 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
     return performAnyAllCombine(N, DCI.DAG);
   case ISD::MUL:
     return performMulCombine(N, DCI);
+  case ISD::AND:
+    return performANDCombine(N, DCI);
+  case ISD::SIGN_EXTEND_INREG:
+    return performSIGN_EXTEND_INREGCombine(N, DCI);
   }
 }
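
As a closing illustration of the promoted-type paths above, a second hypothetical sketch (again not from the commit's tests): an i8 global is promoted to an i32 Wasm global, so the custom lowering reads the whole promoted global with global.get and then masks (zext), sign-extends in-register (sext), or does nothing (anyext), rather than emitting a byte-sized load that has no global.get equivalent.

; Hypothetical example, assuming an i8 Wasm global @c promoted to i32.
@c = addrspace(1) global i8 7

define i32 @load_zext() {
  %v = load i8, ptr addrspace(1) @c  ; EXTLOAD/ZEXTLOAD from addrspace(1) is Custom
  %z = zext i8 %v to i32             ; becomes global.get plus an AND mask
  ret i32 %z
}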
