@@ -16,23 +16,41 @@ pub(crate) fn handle_gpu_code<'ll>(
     cx: &'ll SimpleCx<'_>,
 ) {
     // The offload memory transfer type for each kernel
-    let mut o_types = vec![];
-    let mut kernels = vec![];
+    let mut memtransfer_types = vec![];
+    let mut region_ids = vec![];
     let offload_entry_ty = add_tgt_offload_entry(&cx);
     for num in 0..9 {
         let kernel = cx.get_function(&format!("kernel_{num}"));
         if let Some(kernel) = kernel {
-            o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
-            kernels.push(kernel);
+            let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num);
+            memtransfer_types.push(o);
+            region_ids.push(k);
         }
     }
 
-    gen_call_handling(&cx, &kernels, &o_types);
+    gen_call_handling(&cx, &memtransfer_types, &region_ids);
+}
+
+// ; Function Attrs: nounwind
+// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
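+// The launcher's arguments are, in order: a pointer to the source location (%struct.ident_t), the
+// device id (i64, -1 meaning the default device), the number of teams, the thread limit, a pointer
+// to the kernel's region id, and a pointer to the %struct.__tgt_kernel_arguments filled in below.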
+fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
+    let tptr = cx.type_ptr();
+    let ti64 = cx.type_i64();
+    let ti32 = cx.type_i32();
+    let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
+    let tgt_fn_ty = cx.type_func(&args, ti32);
+    let name = "__tgt_target_kernel";
+    let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
+    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
+    attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
+    (tgt_decl, tgt_fn_ty)
 }

 // What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
 // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
 // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
+// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
+// offloaded was defined.
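+// The fields of %struct.ident_t are reserved/flag words followed by a pointer to the location
+// string @0; this mirrors the ident_t struct that the OpenMP runtime expects.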
 fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
     // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
     let unknown_txt = ";unknown;unknown;0;0;;";
@@ -83,7 +101,7 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
     offload_entry_ty
 }
 
-fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
+fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
     let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
     let tptr = cx.type_ptr();
     let ti64 = cx.type_i64();
@@ -107,7 +125,7 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
     //   uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
     //   uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
     //   uint64_t Unused : 62;
-    // } Flags = {0, 0, 0};
+    // } Flags = {0, 0, 0}; // totals 64 bits (8 bytes)
     // // The number of teams (for x,y,z dimension).
     // uint32_t NumTeams[3] = {0, 0, 0};
     // // The number of threads (for x,y,z dimension).
@@ -118,9 +136,7 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
         vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
 
     cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
-    // For now we don't handle kernels, so for now we just add a global dummy
-    // to make sure that the __tgt_offload_entry is defined and handled correctly.
-    cx.declare_global("my_struct_global2", kernel_arguments_ty);
+    kernel_arguments_ty
 }
 
 fn gen_tgt_data_mappers<'ll>(
@@ -187,7 +203,7 @@ fn gen_define_handling<'ll>(
     kernel: &'ll llvm::Value,
     offload_entry_ty: &'ll llvm::Type,
     num: i64,
-) -> &'ll llvm::Value {
+) -> (&'ll llvm::Value, &'ll llvm::Value) {
     let types = cx.func_params_types(cx.get_type_of_global(kernel));
     // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
     // reference) types.
@@ -205,10 +221,14 @@ fn gen_define_handling<'ll>(
     // or both to and from the gpu (=3). Other values shouldn't affect us for now.
     // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
     // will be 2. For now, everything is 3, until we have our frontend set up.
-    let o_types =
-        add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
+    // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
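+    // These correspond to libomptarget's OMP_TGT_MAPTYPE_TO, OMP_TGT_MAPTYPE_FROM and
+    // OMP_TGT_MAPTYPE_TARGET_PARAM map-type bits.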
+    let memtransfer_types = add_priv_unnamed_arr(
+        &cx,
+        &format!(".offload_maptypes.{num}"),
+        &vec![1 + 2 + 32; num_ptr_types],
+    );
     // Next: For each function, generate these three entries. A weak constant,
-    // the llvm.rodata entry name, and the omp_offloading_entries value
+    // the llvm.rodata entry name, and the llvm_offload_entries value
 
     let name = format!(".kernel_{num}.region_id");
     let initializer = cx.get_const_i8(0);
@@ -242,13 +262,13 @@ fn gen_define_handling<'ll>(
     llvm::set_global_constant(llglobal, true);
     llvm::set_linkage(llglobal, WeakAnyLinkage);
     llvm::set_initializer(llglobal, initializer);
-    llvm::set_alignment(llglobal, Align::ONE);
-    let c_section_name = CString::new(".omp_offloading_entries").unwrap();
+    llvm::set_alignment(llglobal, Align::EIGHT);
+    let c_section_name = CString::new("llvm_offload_entries").unwrap();
     llvm::set_section(llglobal, &c_section_name);
-    o_types
+    (memtransfer_types, region_id)
 }
 
-fn declare_offload_fn<'ll>(
+pub(crate) fn declare_offload_fn<'ll>(
     cx: &'ll SimpleCx<'_>,
     name: &str,
     ty: &'ll llvm::Type,
@@ -285,17 +305,18 @@ fn declare_offload_fn<'ll>(
 // 6. generate __tgt_target_data_end calls to move data from the GPU
 fn gen_call_handling<'ll>(
     cx: &'ll SimpleCx<'_>,
-    _kernels: &[&'ll llvm::Value],
-    o_types: &[&'ll llvm::Value],
+    memtransfer_types: &[&'ll llvm::Value],
+    region_ids: &[&'ll llvm::Value],
 ) {
+    let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
     // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
     let tptr = cx.type_ptr();
     let ti32 = cx.type_i32();
     let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
     let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
     cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
 
-    gen_tgt_kernel_global(&cx);
+    let tgt_kernel_decl = gen_tgt_kernel_global(&cx);
     let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
 
     let main_fn = cx.get_function("main");
@@ -329,35 +350,32 @@ fn gen_call_handling<'ll>(
     // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
     let ty2 = cx.type_array(cx.type_i64(), num_args);
     let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
+
+    // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
+    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
+
+    // Step 1)
+    unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
+    builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
+
     // Now we allocate once per function param, a copy to be passed to one of our maps.
     let mut vals = vec![];
     let mut geps = vec![];
     let i32_0 = cx.get_const_i32(0);
-    for (index, in_ty) in types.iter().enumerate() {
-        // get function arg, store it into the alloca, and read it.
-        let p = llvm::get_param(called, index as u32);
-        let name = llvm::get_value_name(p);
-        let name = str::from_utf8(&name).unwrap();
-        let arg_name = format!("{name}.addr");
-        let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);
-
-        builder.store(p, alloca, Align::EIGHT);
-        let val = builder.load(in_ty, alloca, Align::EIGHT);
-        let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
-        vals.push(val);
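+    // Instead of copying each argument into a fresh alloca, we now read the operands of the
+    // original kernel call directly.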
+    for index in 0..types.len() {
+        let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
+        let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
+        vals.push(v);
         geps.push(gep);
     }
 
-    // Step 1)
-    unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
-    builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
-
     let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
     let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
     let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
     let init_ty = cx.type_func(&[], cx.type_void());
     let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
 
+    // FIXME(offload): Later we want to emit these registration calls in the wrapper code, rather
+    // than in our main function.
     // call void @__tgt_register_lib(ptr noundef %6)
     builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
     // call void @__tgt_init_all_rtls()
@@ -415,22 +433,63 @@ fn gen_call_handling<'ll>(

     // Step 2)
     let s_ident_t = generate_at_one(&cx);
-    let o = o_types[0];
+    let o = memtransfer_types[0];
     let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
     generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
 
     // Step 3)
-    // Here we will add code for the actual kernel launches in a follow-up PR.
-    // FIXME(offload): launch kernels
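+    // Fill the %kernel_args struct: each entry below is an (alignment in bytes, value) pair, and
+    // the loop that follows stores them field by field into the a5 alloca.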
+    let mut values = vec![];
+    let offload_version = cx.get_const_i32(3);
+    values.push((4, offload_version));
+    values.push((4, cx.get_const_i32(num_args)));
+    values.push((8, geps.0));
+    values.push((8, geps.1));
+    values.push((8, geps.2));
+    values.push((8, memtransfer_types[0]));
+    // The next two are debug info pointers. FIXME(offload): set them
+    values.push((8, cx.const_null(cx.type_ptr())));
+    values.push((8, cx.const_null(cx.type_ptr())));
+    values.push((8, cx.get_const_i64(0)));
+    values.push((8, cx.get_const_i64(0)));
+    let ti32 = cx.type_i32();
+    let ci32_0 = cx.get_const_i32(0);
+    values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
+    values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
+    values.push((4, cx.get_const_i32(0)));
+
+    for (i, value) in values.iter().enumerate() {
+        let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
+        builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
+    }
+
+    let args = vec![
+        s_ident_t,
+        // MAX == -1
+        cx.get_const_i64(u64::MAX),
+        cx.get_const_i32(2097152),
+        cx.get_const_i32(256),
+        region_ids[0],
+        a5,
+    ];
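+    // The i64 -1 device id asks the runtime for the default device; 2097152 and 256 repeat the
+    // NumTeams/ThreadLimit values stored into %kernel_args above.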
+    let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
+    // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
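+    // Everything so far was inserted in front of the original kernel call, so the instruction
+    // following the launcher call is that host-side call; position behind it and erase it, since
+    // the offload launch now replaces it.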
+    unsafe {
+        let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
+        llvm::LLVMRustPositionAfter(builder.llbuilder, next);
+        llvm::LLVMInstructionEraseFromParent(next);
+    }
 
     // Step 4)
-    unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
+    // unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
 
     let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
     generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
 
     builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
 
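+    // Drop the builder before deleting `called` (the host-side kernel function), which should have
+    // no remaining users now that the call to it was erased above.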
+    drop(builder);
+    unsafe { llvm::LLVMDeleteFunction(called) };
+
     // With this we generated the following begin and end mappers. We could easily generate the
     // update mapper in an update.
     // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)