1818
1919#include " support/dtypes.h"
2020
21- #include < llvm/Pass.h>
2221#include < llvm/IR/IRBuilder.h>
2322#include < llvm/IR/LegacyPassManager.h>
2423#include < llvm/IR/PassManager.h>
@@ -29,193 +28,15 @@ using namespace llvm;
2928
3029namespace {
3130
// Version-compatibility shim around AttributeList: forwards to
// Attrs.getFnAttrs() when JL_LLVM_VERSION >= 140000 and to the older
// Attrs.getFnAttributes() spelling otherwise, returning the resulting
// AttributeSet by value.
32- inline AttributeSet getFnAttrs (const AttributeList &Attrs)
33- {
34- #if JL_LLVM_VERSION >= 140000
35- return Attrs.getFnAttrs ();
36- #else
37- return Attrs.getFnAttributes ();
38- #endif
39- }
40-
// Version-compatibility shim around AttributeList: forwards to
// Attrs.getRetAttrs() when JL_LLVM_VERSION >= 140000 and to the older
// Attrs.getRetAttributes() spelling otherwise, returning the resulting
// AttributeSet by value.
41- inline AttributeSet getRetAttrs (const AttributeList &Attrs)
42- {
43- #if JL_LLVM_VERSION >= 140000
44- return Attrs.getRetAttrs ();
45- #else
46- return Attrs.getRetAttributes ();
47- #endif
48- }
49-
// Creates a new call to the same intrinsic ID as `call`, re-declared for a
// different signature.
//
// call  - the intrinsic call being re-created; its ID must be non-zero.
// RetTy - return type of the new function type.
// args  - operand list for the new call. It must hold more entries than the
//         old function type has parameters, and its last entry is asserted to
//         be the called function of `call`; that last entry is dropped before
//         the new call is built.
//
// The new parameter types are taken from the first `nargs` entries of `args`.
// The calling convention, tail-call kind, and the function and return
// attributes are carried over from `call`; parameter attributes are dropped.
// `call` is passed to CallInst::Create as the insertion point. This function
// does not erase `call` or replace its uses.
//
// Returns the newly created call instruction.
50- static Instruction *replaceIntrinsicWith (IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
51- {
52- Intrinsic::ID ID = call->getIntrinsicID ();
53- assert (ID);
54- auto oldfType = call->getFunctionType ();
55- auto nargs = oldfType->getNumParams ();
56- assert (args.size () > nargs);
57- SmallVector<Type*, 8 > argTys (nargs);
58- for (unsigned i = 0 ; i < nargs; i++)
59- argTys[i] = args[i]->getType ();
60- auto newfType = FunctionType::get (RetTy, argTys, oldfType->isVarArg ());
61-
62- // Accumulate an array of overloaded types for the given intrinsic
63- // and compute the new name mangling schema
64- SmallVector<Type*, 4 > overloadTys;
65- {
66- SmallVector<Intrinsic::IITDescriptor, 8 > Table;
67- getIntrinsicInfoTableEntries (ID, Table);
68- ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
69- auto res = Intrinsic::matchIntrinsicSignature (newfType, TableRef, overloadTys);
70- assert (res == Intrinsic::MatchIntrinsicTypes_Match);
// The (void) casts keep `res` and `matchvararg` referenced when the asserts
// are compiled out.
71- (void )res;
72- bool matchvararg = !Intrinsic::matchIntrinsicVarArg (newfType->isVarArg (), TableRef);
73- assert (matchvararg);
74- (void )matchvararg;
75- }
76- auto newF = Intrinsic::getDeclaration (call->getModule (), ID, overloadTys);
77- assert (newF->getFunctionType () == newfType);
78- newF->setCallingConv (call->getCallingConv ());
79- assert (args.back () == call->getCalledFunction ());
80- auto newCall = CallInst::Create (newF, args.drop_back (), " " , call);
81- newCall->setTailCallKind (call->getTailCallKind ());
82- auto old_attrs = call->getAttributes ();
83- newCall->setAttributes (AttributeList::get (call->getContext (), getFnAttrs (old_attrs),
84- getRetAttrs (old_attrs), {})); // drop parameter attributes
85- return newCall;
86- }
87-
88-
// Emits the cast `opcode` of V to DestTy through `builder`, lowering
// half-precision conversions to calls of named runtime functions.
//
// opcode  - one of FPExt, FPTrunc, FPToSI, FPToUI, SIToFP, UIToFP; any other
//           opcode prints a diagnostic to errs() and hits llvm_unreachable.
// V       - the value to convert.
// DestTy  - the requested result type; must be a vector type exactly when the
//           type of V is one, with the same element count.
// builder - the IRBuilder used for every emitted instruction; the runtime
//           function is declared in the module of its insert block.
//
// Behaviour, in order:
//  * A Constant input is first handed to ConstantExpr::getCast and the result
//    is returned if non-null.
//  * A vector input is converted element by element via recursion, starting
//    from an UndefValue of DestTy.
//  * A scalar input is coerced to the runtime function's parameter type, the
//    function is called, and the result is coerced to DestTy. Half values are
//    passed to and returned from the runtime functions as i16 via bitcast.
//
// Returns the value holding the converted result.
89- static Value* CreateFPCast (Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
90- {
91- Type *SrcTy = V->getType ();
92- Type *RetTy = DestTy;
93- if (auto *VC = dyn_cast<Constant>(V)) {
94- // The input IR often has things of the form
95- // fcmp olt half %0, 0xH7C00
96- // and we would like to avoid turning that constant into a call here
97- // if we can simply constant fold it to the new type.
98- VC = ConstantExpr::getCast (opcode, VC, DestTy, true );
99- if (VC)
100- return VC;
101- }
102- assert (SrcTy->isVectorTy () == DestTy->isVectorTy ());
103- if (SrcTy->isVectorTy ()) {
104- unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements ();
105- assert (cast<FixedVectorType>(DestTy)->getNumElements () == NumElems && " Mismatched cast" );
106- Value *NewV = UndefValue::get (DestTy);
// From here on RetTy is the scalar element type used for each per-element
// recursive call.
107- RetTy = RetTy->getScalarType ();
108- for (unsigned i = 0 ; i < NumElems; ++i) {
109- Value *I = builder.getInt32 (i);
110- Value *Vi = builder.CreateExtractElement (V, I);
111- Vi = CreateFPCast (opcode, Vi, RetTy, builder);
112- NewV = builder.CreateInsertElement (NewV, Vi, I);
113- }
114- return NewV;
115- }
116- auto &M = *builder.GetInsertBlock ()->getModule ();
117- auto &ctx = M.getContext ();
118- // Pick the Function to call in the Julia runtime
119- StringRef Name;
120- switch (opcode) {
121- case Instruction::FPExt:
122- // this is exact, so we only need one conversion
123- assert (SrcTy->isHalfTy ());
124- Name = " julia__gnu_h2f_ieee" ;
125- RetTy = Type::getFloatTy (ctx);
126- break ;
127- case Instruction::FPTrunc:
128- assert (DestTy->isHalfTy ());
// Only float and double sources get a name here; any other source type
// leaves Name empty and is reported by the Name.empty() check below.
129- if (SrcTy->isFloatTy ())
130- Name = " julia__gnu_f2h_ieee" ;
131- else if (SrcTy->isDoubleTy ())
132- Name = " julia__truncdfhf2" ;
133- break ;
134- // All F16 fit exactly in Int32 (-65504 to 65504)
135- case Instruction::FPToSI: JL_FALLTHROUGH;
136- case Instruction::FPToUI:
// Half-to-integer: widen to float through the runtime call; the integer
// conversion is applied to the call result further down.
137- assert (SrcTy->isHalfTy ());
138- Name = " julia__gnu_h2f_ieee" ;
139- RetTy = Type::getFloatTy (ctx);
140- break ;
141- case Instruction::SIToFP: JL_FALLTHROUGH;
142- case Instruction::UIToFP:
// Integer-to-half: the integer is first converted to float (see the source
// coercion below), then narrowed by the runtime call.
143- assert (DestTy->isHalfTy ());
144- Name = " julia__gnu_f2h_ieee" ;
145- SrcTy = Type::getFloatTy (ctx);
146- break ;
147- default :
148- errs () << Instruction::getOpcodeName (opcode) << ' ' ;
149- V->getType ()->print (errs ());
150- errs () << " to " ;
151- DestTy->print (errs ());
152- errs () << " is an " ;
153- llvm_unreachable (" invalid cast" );
154- }
155- if (Name.empty ()) {
156- errs () << Instruction::getOpcodeName (opcode) << ' ' ;
157- V->getType ()->print (errs ());
158- errs () << " to " ;
159- DestTy->print (errs ());
160- errs () << " is an " ;
161- llvm_unreachable (" illegal cast" );
162- }
163- // Coerce the source to the required size and type
164- auto T_int16 = Type::getInt16Ty (ctx);
165- if (SrcTy->isHalfTy ())
166- SrcTy = T_int16;
167- if (opcode == Instruction::SIToFP)
168- V = builder.CreateSIToFP (V, SrcTy);
169- else if (opcode == Instruction::UIToFP)
170- V = builder.CreateUIToFP (V, SrcTy);
171- else
172- V = builder.CreateBitCast (V, SrcTy);
173- // Call our intrinsic
174- if (RetTy->isHalfTy ())
175- RetTy = T_int16;
176- auto FT = FunctionType::get (RetTy, {SrcTy}, false );
177- FunctionCallee F = M.getOrInsertFunction (Name, FT);
178- Value *I = builder.CreateCall (F, {V});
179- // Coerce the result to the expected type
180- if (opcode == Instruction::FPToSI)
181- I = builder.CreateFPToSI (I, DestTy);
182- else if (opcode == Instruction::FPToUI)
183- I = builder.CreateFPToUI (I, DestTy);
184- else if (opcode == Instruction::FPExt)
185- I = builder.CreateFPCast (I, DestTy);
186- else
187- I = builder.CreateBitCast (I, DestTy);
188- return I;
189- }
190-
19131static bool demoteFloat16 (Function &F)
19232{
19333 auto &ctx = F.getContext ();
34+ auto T_float16 = Type::getHalfTy (ctx);
19435 auto T_float32 = Type::getFloatTy (ctx);
19536
19637 SmallVector<Instruction *, 0 > erase;
19738 for (auto &BB : F) {
19839 for (auto &I : BB) {
199- // extend Float16 operands to Float32
200- bool Float16 = I.getType ()->getScalarType ()->isHalfTy ();
201- for (size_t i = 0 ; !Float16 && i < I.getNumOperands (); i++) {
202- Value *Op = I.getOperand (i);
203- if (Op->getType ()->getScalarType ()->isHalfTy ())
204- Float16 = true ;
205- }
206- if (!Float16)
207- continue ;
208-
209- if (auto CI = dyn_cast<CastInst>(&I)) {
210- if (CI->getOpcode () != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
211- IRBuilder<> builder (&I);
212- Value *NewI = CreateFPCast (CI->getOpcode (), I.getOperand (0 ), I.getType (), builder);
213- I.replaceAllUsesWith (NewI);
214- erase.push_back (&I);
215- }
216- continue ;
217- }
218-
21940 switch (I.getOpcode ()) {
22041 case Instruction::FNeg:
22142 case Instruction::FAdd:
@@ -226,9 +47,6 @@ static bool demoteFloat16(Function &F)
22647 case Instruction::FCmp:
22748 break ;
22849 default :
229- if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
230- if (intrinsic->getIntrinsicID ())
231- break ;
23250 continue ;
23351 }
23452
@@ -240,67 +58,61 @@ static bool demoteFloat16(Function &F)
24058 IRBuilder<> builder (&I);
24159
24260 // extend Float16 operands to Float32
243- // XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
61+ bool OperandsChanged = false ;
24462 SmallVector<Value *, 2 > Operands (I.getNumOperands ());
24563 for (size_t i = 0 ; i < I.getNumOperands (); i++) {
24664 Value *Op = I.getOperand (i);
247- if (Op->getType ()->getScalarType ()->isHalfTy ()) {
248- Op = CreateFPCast (Instruction::FPExt, Op, Op->getType ()->getWithNewType (T_float32), builder);
65+ if (Op->getType () == T_float16) {
66+ Op = builder.CreateFPExt (Op, T_float32);
67+ OperandsChanged = true ;
24968 }
25069 Operands[i] = (Op);
25170 }
25271
25372 // recreate the instruction if any operands changed,
25473 // truncating the result back to Float16
255- Value *NewI;
256- switch (I.getOpcode ()) {
257- case Instruction::FNeg:
258- assert (Operands.size () == 1 );
259- NewI = builder.CreateFNeg (Operands[0 ]);
260- break ;
261- case Instruction::FAdd:
262- assert (Operands.size () == 2 );
263- NewI = builder.CreateFAdd (Operands[0 ], Operands[1 ]);
264- break ;
265- case Instruction::FSub:
266- assert (Operands.size () == 2 );
267- NewI = builder.CreateFSub (Operands[0 ], Operands[1 ]);
268- break ;
269- case Instruction::FMul:
270- assert (Operands.size () == 2 );
271- NewI = builder.CreateFMul (Operands[0 ], Operands[1 ]);
272- break ;
273- case Instruction::FDiv:
274- assert (Operands.size () == 2 );
275- NewI = builder.CreateFDiv (Operands[0 ], Operands[1 ]);
276- break ;
277- case Instruction::FRem:
278- assert (Operands.size () == 2 );
279- NewI = builder.CreateFRem (Operands[0 ], Operands[1 ]);
280- break ;
281- case Instruction::FCmp:
282- assert (Operands.size () == 2 );
283- NewI = builder.CreateFCmp (cast<FCmpInst>(&I)->getPredicate (),
284- Operands[0 ], Operands[1 ]);
285- break ;
286- default :
287- if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
288- // XXX: this is not correct in general
289- // some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
290- Type *RetTy = I.getType ();
291- if (RetTy->getScalarType ()->isHalfTy ())
292- RetTy = RetTy->getWithNewType (T_float32);
293- NewI = replaceIntrinsicWith (intrinsic, RetTy, Operands);
74+ if (OperandsChanged) {
75+ Value *NewI;
76+ switch (I.getOpcode ()) {
77+ case Instruction::FNeg:
78+ assert (Operands.size () == 1 );
79+ NewI = builder.CreateFNeg (Operands[0 ]);
80+ break ;
81+ case Instruction::FAdd:
82+ assert (Operands.size () == 2 );
83+ NewI = builder.CreateFAdd (Operands[0 ], Operands[1 ]);
84+ break ;
85+ case Instruction::FSub:
86+ assert (Operands.size () == 2 );
87+ NewI = builder.CreateFSub (Operands[0 ], Operands[1 ]);
88+ break ;
89+ case Instruction::FMul:
90+ assert (Operands.size () == 2 );
91+ NewI = builder.CreateFMul (Operands[0 ], Operands[1 ]);
92+ break ;
93+ case Instruction::FDiv:
94+ assert (Operands.size () == 2 );
95+ NewI = builder.CreateFDiv (Operands[0 ], Operands[1 ]);
96+ break ;
97+ case Instruction::FRem:
98+ assert (Operands.size () == 2 );
99+ NewI = builder.CreateFRem (Operands[0 ], Operands[1 ]);
100+ break ;
101+ case Instruction::FCmp:
102+ assert (Operands.size () == 2 );
103+ NewI = builder.CreateFCmp (cast<FCmpInst>(&I)->getPredicate (),
104+ Operands[0 ], Operands[1 ]);
294105 break ;
106+ default :
107+ abort ();
295108 }
296- abort ();
109+ cast<Instruction>(NewI)->copyMetadata (I);
110+ cast<Instruction>(NewI)->copyFastMathFlags (&I);
111+ if (NewI->getType () != I.getType ())
112+ NewI = builder.CreateFPTrunc (NewI, I.getType ());
113+ I.replaceAllUsesWith (NewI);
114+ erase.push_back (&I);
297115 }
298- cast<Instruction>(NewI)->copyMetadata (I);
299- cast<Instruction>(NewI)->copyFastMathFlags (&I);
300- if (NewI->getType () != I.getType ())
301- NewI = CreateFPCast (Instruction::FPTrunc, NewI, I.getType (), builder);
302- I.replaceAllUsesWith (NewI);
303- erase.push_back (&I);
304116 }
305117 }
306118
0 commit comments