@@ -4,6 +4,7 @@ use llvm::Linkage::*;
44use rustc_abi:: Align ;
55use rustc_codegen_ssa:: back:: write:: CodegenContext ;
66use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
7+ use rustc_middle:: ty:: offload_meta:: OffloadMetadata ;
78use rustc_middle:: ty:: { self , PseudoCanonicalInput , Ty , TyCtxt , TypingEnv } ;
89
910use crate :: builder:: SBuilder ;
@@ -189,8 +190,7 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
189190 tcx : TyCtxt < ' tcx > ,
190191 kernel : & ' ll llvm:: Value ,
191192 offload_entry_ty : & ' ll llvm:: Type ,
192- // TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need
193- tt : Vec < Ty < ' tcx > > ,
193+ metadata : Vec < OffloadMetadata > ,
194194 symbol : & str ,
195195) -> & ' ll llvm:: Value {
196196 let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
@@ -201,12 +201,11 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
201201 . filter ( |& x| matches ! ( cx. type_kind( x) , rustc_codegen_ssa:: common:: TypeKind :: Pointer ) )
202202 . count ( ) ;
203203
204- // TODO(Sa4dUs): Add typetrees here
205204 let ptr_sizes = types
206205 . iter ( )
207- . zip ( tt )
208- . filter_map ( |( & x, ty ) | match cx. type_kind ( x) {
209- rustc_codegen_ssa:: common:: TypeKind :: Pointer => Some ( get_payload_size ( tcx , ty ) ) ,
206+ . zip ( metadata )
207+ . filter_map ( |( & x, meta ) | match cx. type_kind ( x) {
208+ rustc_codegen_ssa:: common:: TypeKind :: Pointer => Some ( meta . payload_size ) ,
210209 _ => None ,
211210 } )
212211 . collect :: < Vec < u64 > > ( ) ;
@@ -265,56 +264,6 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
265264 o_types
266265}
267266
268- // TODO(Sa4dUs): move this to a proper place
269- fn get_payload_size < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> u64 {
270- match ty. kind ( ) {
271- /*
272- rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
273- rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
274- rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
275- rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
276- rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
277- rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
278- rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
279- rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
280- rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
281- rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
282- rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
283- rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
284- */
285- ty:: Ref ( _, inner, _) => get_payload_size ( tcx, * inner) ,
286- /*
287- rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
288- rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
289- rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
290- rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
291- rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
292- rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
293- rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
294- rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
295- rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
296- rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
297- rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
298- rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
299- rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
300- rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
301- rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
302- rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
303- */
304- _ => {
305- tcx
306- // TODO(Sa4dUs): Maybe `.as_query_input()`?
307- . layout_of ( PseudoCanonicalInput {
308- typing_env : TypingEnv :: fully_monomorphized ( ) ,
309- value : ty,
310- } )
311- . unwrap ( )
312- . size
313- . bytes ( )
314- }
315- }
316- }
317-
318267fn declare_offload_fn < ' ll > (
319268 cx : & ' ll SimpleCx < ' _ > ,
320269 name : & str ,
@@ -353,8 +302,8 @@ fn declare_offload_fn<'ll>(
353302pub ( crate ) fn gen_call_handling < ' ll > (
354303 cx : & SimpleCx < ' ll > ,
355304 bb : & BasicBlock ,
356- kernels : & [ & ' ll llvm:: Value ] ,
357- o_types : & [ & ' ll llvm:: Value ] ,
305+ kernel : & ' ll llvm:: Value ,
306+ o_type : & ' ll llvm:: Value ,
358307 llty : & ' ll Type ,
359308 llfn : & ' ll Value ,
360309) {
@@ -370,7 +319,7 @@ pub(crate) fn gen_call_handling<'ll>(
370319
371320 let mut builder = SBuilder :: build ( cx, bb) ;
372321
373- let types = cx. func_params_types ( cx. get_type_of_global ( kernels [ 0 ] ) ) ;
322+ let types = cx. func_params_types ( cx. get_type_of_global ( kernel ) ) ;
374323 let num_args = types. len ( ) as u64 ;
375324
376325 // Step 0)
@@ -392,7 +341,7 @@ pub(crate) fn gen_call_handling<'ll>(
392341 let i32_0 = cx. get_const_i32 ( 0 ) ;
393342 for ( index, in_ty) in types. iter ( ) . enumerate ( ) {
394343 // get function arg, store it into the alloca, and read it.
395- let p = llvm:: get_param ( kernels [ 0 ] , index as u32 ) ;
344+ let p = llvm:: get_param ( kernel , index as u32 ) ;
396345 let name = llvm:: get_value_name ( p) ;
397346 let name = str:: from_utf8 ( & name) . unwrap ( ) ;
398347 let arg_name = format ! ( "{name}.addr" ) ;
@@ -471,7 +420,7 @@ pub(crate) fn gen_call_handling<'ll>(
471420
472421 // Step 2)
473422 let s_ident_t = generate_at_one ( & cx) ;
474- let o = o_types [ 0 ] ;
423+ let o = o_type ;
475424 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
476425 generate_mapper_call ( & mut builder, & cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t) ;
477426
@@ -485,7 +434,7 @@ pub(crate) fn gen_call_handling<'ll>(
485434 args. push ( param) ;
486435 }
487436
488- builder. call ( llty, kernels [ 0 ] , & args, None ) ;
437+ builder. call ( llty, kernel , & args, None ) ;
489438
490439 // Step 4)
491440 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
0 commit comments