2525#include " llvm/IR/PassManager.h"
2626#include " llvm/IR/Type.h"
2727#include " llvm/Pass.h"
28+ #include " llvm/Support/Casting.h"
2829#include " llvm/Support/ErrorHandling.h"
2930#include " llvm/Support/MathExtras.h"
3031
@@ -70,15 +71,17 @@ static bool isIntrinsicExpansion(Function &F) {
7071 case Intrinsic::vector_reduce_add:
7172 case Intrinsic::vector_reduce_fadd:
7273 return true ;
73- case Intrinsic::dx_resource_load_typedbuffer:
74- // We need to handle doubles and vector of doubles.
75- return F.getReturnType ()
76- ->getStructElementType (0 )
77- ->getScalarType ()
78- ->isDoubleTy ();
79- case Intrinsic::dx_resource_store_typedbuffer:
80- // We need to handle doubles and vector of doubles.
81- return F.getFunctionType ()->getParamType (2 )->getScalarType ()->isDoubleTy ();
74+ case Intrinsic::dx_resource_load_typedbuffer: {
75+ // We need to handle i64, doubles, and vectors of them.
76+ Type *ScalarTy =
77+ F.getReturnType ()->getStructElementType (0 )->getScalarType ();
78+ return ScalarTy->isDoubleTy () || ScalarTy->isIntegerTy (64 );
79+ }
80+ case Intrinsic::dx_resource_store_typedbuffer: {
81+ // We need to handle i64 and doubles and vectors of i64 and doubles.
82+ Type *ScalarTy = F.getFunctionType ()->getParamType (2 )->getScalarType ();
83+ return ScalarTy->isDoubleTy () || ScalarTy->isIntegerTy (64 );
84+ }
8285 }
8386 return false ;
8487}
@@ -545,13 +548,15 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
545548 IRBuilder<> Builder (Orig);
546549
547550 Type *BufferTy = Orig->getType ()->getStructElementType (0 );
548- assert (BufferTy->getScalarType ()->isDoubleTy () &&
549- " Only expand double or double2" );
551+ Type *ScalarTy = BufferTy->getScalarType ();
552+ bool IsDouble = ScalarTy->isDoubleTy ();
553+ assert (IsDouble || ScalarTy->isIntegerTy (64 ) &&
554+ " Only expand double or int64 scalars or vectors" );
550555
551556 unsigned ExtractNum = 2 ;
552557 if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
553558 assert (VT->getNumElements () == 2 &&
554- " TypedBufferLoad double vector has wrong size" );
559+ " TypedBufferLoad vector must be size 2 " );
555560 ExtractNum = 4 ;
556561 }
557562
@@ -570,22 +575,42 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
570575 ExtractElements.push_back (
571576 Builder.CreateExtractElement (Extract, Builder.getInt32 (I)));
572577
573- // combine into double(s)
578+ // combine into double(s) or int64(s)
574579 Value *Result = PoisonValue::get (BufferTy);
575580 for (unsigned I = 0 ; I < ExtractNum; I += 2 ) {
576- Value *Dbl =
577- Builder.CreateIntrinsic (Builder.getDoubleTy (), Intrinsic::dx_asdouble,
578- {ExtractElements[I], ExtractElements[I + 1 ]});
581+ Value *Combined = nullptr ;
582+ if (IsDouble)
583+ // For doubles, use dx_asdouble intrinsic
584+ Combined =
585+ Builder.CreateIntrinsic (Builder.getDoubleTy (), Intrinsic::dx_asdouble,
586+ {ExtractElements[I], ExtractElements[I + 1 ]});
587+ else {
588+ // For int64, manually combine two int32s
589+ // First, zero-extend both values to i64
590+ Value *Lo = Builder.CreateZExt (ExtractElements[I], Builder.getInt64Ty ());
591+ Value *Hi =
592+ Builder.CreateZExt (ExtractElements[I + 1 ], Builder.getInt64Ty ());
593+ // Shift the high bits left by 32 bits
594+ Value *ShiftedHi = Builder.CreateShl (Hi, Builder.getInt64 (32 ));
595+ // OR the high and low bits together
596+ Combined = Builder.CreateOr (Lo, ShiftedHi);
597+ }
598+
579599 if (ExtractNum == 4 )
580- Result =
581- Builder. CreateInsertElement (Result, Dbl, Builder.getInt32 (I / 2 ));
600+ Result = Builder. CreateInsertElement (Result, Combined,
601+ Builder.getInt32 (I / 2 ));
582602 else
583- Result = Dbl ;
603+ Result = Combined ;
584604 }
585605
586606 Value *CheckBit = nullptr ;
587607 for (User *U : make_early_inc_range (Orig->users ())) {
588- auto *EVI = cast<ExtractValueInst>(U);
608+ // If it's not a ExtractValueInst, we don't know how to
609+ // handle it
610+ auto *EVI = dyn_cast<ExtractValueInst>(U);
611+ if (!EVI)
612+ llvm_unreachable (" Unexpected user of typedbufferload" );
613+
589614 ArrayRef<unsigned > Indices = EVI->getIndices ();
590615 assert (Indices.size () == 1 );
591616
@@ -609,38 +634,61 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
609634 IRBuilder<> Builder (Orig);
610635
611636 Type *BufferTy = Orig->getFunctionType ()->getParamType (2 );
612- assert (BufferTy->getScalarType ()->isDoubleTy () &&
613- " Only expand double or double2" );
614-
615- unsigned ExtractNum = 2 ;
616- if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
617- assert (VT->getNumElements () == 2 &&
618- " TypedBufferStore double vector has wrong size" );
619- ExtractNum = 4 ;
637+ Type *ScalarTy = BufferTy->getScalarType ();
638+ bool IsDouble = ScalarTy->isDoubleTy ();
639+ assert ((IsDouble || ScalarTy->isIntegerTy (64 )) &&
640+ " Only expand double or int64 scalars or vectors" );
641+
642+ // Determine if we're dealing with a vector or scalar
643+ bool IsVector = isa<FixedVectorType>(BufferTy);
644+ if (IsVector) {
645+ assert (cast<FixedVectorType>(BufferTy)->getNumElements () == 2 &&
646+ " TypedBufferStore vector must be size 2" );
620647 }
621648
622- Type *SplitElementTy = Builder.getInt32Ty ();
623- if (ExtractNum == 4 )
649+ // Create the appropriate vector type for the result
650+ Type *Int32Ty = Builder.getInt32Ty ();
651+ Type *ResultTy = VectorType::get (Int32Ty, IsVector ? 4 : 2 , false );
652+ Value *Val = PoisonValue::get (ResultTy);
653+
654+ Type *SplitElementTy = Int32Ty;
655+ if (IsVector)
624656 SplitElementTy = VectorType::get (SplitElementTy, 2 , false );
625657
626- // split our double(s)
627- auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
628- Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
629- Orig->getOperand (2 ));
630- // create our vector
631- Value *LowBits = Builder.CreateExtractValue (Split, 0 );
632- Value *HighBits = Builder.CreateExtractValue (Split, 1 );
633- Value *Val;
634- if (ExtractNum == 2 ) {
635- Val = PoisonValue::get (VectorType::get (SplitElementTy, 2 , false ));
658+ Value *LowBits = nullptr ;
659+ Value *HighBits = nullptr ;
660+ // Split the 64-bit values into 32-bit components
661+ if (IsDouble) {
662+ auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
663+ Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
664+ {Orig->getOperand (2 )});
665+ LowBits = Builder.CreateExtractValue (Split, 0 );
666+ HighBits = Builder.CreateExtractValue (Split, 1 );
667+ } else {
668+ // Handle int64 type(s)
669+ Value *InputVal = Orig->getOperand (2 );
670+ Constant *ShiftAmt = Builder.getInt64 (32 );
671+ if (IsVector)
672+ ShiftAmt = ConstantVector::getSplat (ElementCount::getFixed (2 ), ShiftAmt);
673+
674+ // Split into low and high 32-bit parts
675+ LowBits = Builder.CreateTrunc (InputVal, SplitElementTy);
676+ Value *ShiftedVal = Builder.CreateLShr (InputVal, ShiftAmt);
677+ HighBits = Builder.CreateTrunc (ShiftedVal, SplitElementTy);
678+ }
679+
680+ if (IsVector) {
681+ Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
682+ } else {
636683 Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
637684 Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
638- } else
639- Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
685+ }
640686
687+ // Create the final intrinsic call
641688 Builder.CreateIntrinsic (Builder.getVoidTy (),
642689 Intrinsic::dx_resource_store_typedbuffer,
643690 {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
691+
644692 Orig->eraseFromParent ();
645693 return true ;
646694}
0 commit comments