Skip to content

Commit 567b55f

Browse files
committed
[AMDGPU] Support D16 folding for image.sample with multiple extractelement and fptrunc users
1 parent 5c37840 commit 567b55f

File tree

2 files changed

+477
-2
lines changed

2 files changed

+477
-2
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -269,6 +269,68 @@ simplifyAMDGCNImageIntrinsic(const GCNSubtarget *ST,
269269
ArgTys[0] = User->getType();
270270
});
271271
}
272+
} else {
273+
// Only perform D16 folding if every user of the image sample is
274+
// an ExtractElementInst immediately followed by an FPTrunc to half.
275+
SmallVector<ExtractElementInst *, 4> Extracts;
276+
SmallVector<FPTruncInst *, 4> Truncs;
277+
bool AllHalfExtracts = true;
278+
279+
for (User *U : II.users()) {
280+
auto *Ext = dyn_cast<ExtractElementInst>(U);
281+
if (!Ext || !Ext->hasOneUse()) {
282+
AllHalfExtracts = false;
283+
break;
284+
}
285+
auto *Tr = dyn_cast<FPTruncInst>(*Ext->user_begin());
286+
if (!Tr || !Tr->getType()->getScalarType()->isHalfTy()) {
287+
AllHalfExtracts = false;
288+
break;
289+
}
290+
Extracts.push_back(Ext);
291+
Truncs.push_back(Tr);
292+
}
293+
294+
if (AllHalfExtracts && !Extracts.empty()) {
295+
auto *VecTy = cast<VectorType>(II.getType());
296+
unsigned NElts = VecTy->getElementCount().getKnownMinValue();
297+
Type *HalfVecTy =
298+
VectorType::get(Type::getHalfTy(II.getContext()), NElts, false);
299+
300+
// Obtain the original image sample intrinsic's signature
301+
// and replace its return type with the half-vector for D16 folding
302+
SmallVector<Type *, 8> SigTys;
303+
if (!Intrinsic::getIntrinsicSignature(II.getCalledFunction(), SigTys))
304+
return nullptr;
305+
SigTys[0] = HalfVecTy;
306+
307+
Module *M = II.getModule();
308+
Function *HalfDecl =
309+
Intrinsic::getOrInsertDeclaration(M, ImageDimIntr->Intr, SigTys);
310+
311+
II.mutateType(HalfVecTy);
312+
II.setCalledFunction(HalfDecl);
313+
314+
IRBuilder<> Builder(&II);
315+
for (auto [lane, Ext] : enumerate(Extracts)) {
316+
FPTruncInst *Tr = Truncs[lane];
317+
Value *Idx = Ext->getIndexOperand();
318+
319+
Builder.SetInsertPoint(Tr);
320+
321+
Value *HalfExtract = Builder.CreateExtractElement(&II, Idx);
322+
HalfExtract->takeName(Tr);
323+
324+
Tr->replaceAllUsesWith(HalfExtract);
325+
}
326+
327+
for (auto *T : Truncs)
328+
IC.eraseInstFromFunction(*T);
329+
for (auto *E : Extracts)
330+
IC.eraseInstFromFunction(*E);
331+
332+
return &II;
333+
}
272334
}
273335
}
274336
}

0 commit comments

Comments
 (0)