@@ -43,7 +43,7 @@ use crate::errors::{
4343use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind :: * ;
4444use crate :: llvm:: { self , DiagnosticInfo } ;
4545use crate :: type_:: llvm_type_ptr;
46- use crate :: { LlvmCodegenBackend , ModuleLlvm , base, common, llvm_util} ;
46+ use crate :: { LlvmCodegenBackend , ModuleLlvm , SimpleCx , base, common, llvm_util} ;
4747
4848pub ( crate ) fn llvm_err < ' a > ( dcx : DiagCtxtHandle < ' _ > , err : LlvmError < ' a > ) -> ! {
4949 match llvm:: last_error ( ) {
@@ -645,6 +645,74 @@ pub(crate) unsafe fn llvm_optimize(
645645 None
646646 } ;
647647
648+ fn handle_offload < ' ll > ( cx : & ' ll SimpleCx < ' _ > , old_fn : & llvm:: Value ) {
649+ let old_fn_ty = cx. get_type_of_global ( old_fn) ;
650+ let old_param_types = cx. func_params_types ( old_fn_ty) ;
651+ let old_param_count = old_param_types. len ( ) ;
652+ if old_param_count == 0 {
653+ return ;
654+ }
655+
656+ let first_param = llvm:: get_param ( old_fn, 0 ) ;
657+ let c_name = llvm:: get_value_name ( first_param) ;
658+ let first_arg_name = str:: from_utf8 ( & c_name) . unwrap ( ) ;
659+ // We might call llvm_optimize (and thus this code) multiple times on the same IR,
660+ // but we shouldn't add this helper ptr multiple times.
661+ // FIXME(offload): This could break if the user calls his first argument `dyn_ptr`.
662+ if first_arg_name == "dyn_ptr" {
663+ return ;
664+ }
665+
666+ // Create the new parameter list, with ptr as the first argument
667+ let mut new_param_types = Vec :: with_capacity ( old_param_count as usize + 1 ) ;
668+ new_param_types. push ( cx. type_ptr ( ) ) ;
669+ new_param_types. extend ( old_param_types) ;
670+
671+ // Create the new function type
672+ let ret_ty = unsafe { llvm:: LLVMGetReturnType ( old_fn_ty) } ;
673+ let new_fn_ty = cx. type_func ( & new_param_types, ret_ty) ;
674+
675+ // Create the new function, with a temporary .offload name to avoid a name collision.
676+ let old_fn_name = String :: from_utf8 ( llvm:: get_value_name ( old_fn) ) . unwrap ( ) ;
677+ let new_fn_name = format ! ( "{}.offload" , & old_fn_name) ;
678+ let new_fn = cx. add_func ( & new_fn_name, new_fn_ty) ;
679+ let a0 = llvm:: get_param ( new_fn, 0 ) ;
680+ llvm:: set_value_name ( a0, CString :: new ( "dyn_ptr" ) . unwrap ( ) . as_bytes ( ) ) ;
681+
682+ // Here we map the old arguments to the new arguments, with an offset of 1 to make sure
683+ // that we don't use the newly added `%dyn_ptr`.
684+ unsafe {
685+ llvm:: LLVMRustOffloadMapper ( cx. llmod ( ) , old_fn, new_fn) ;
686+ }
687+
688+ llvm:: set_linkage ( new_fn, llvm:: get_linkage ( old_fn) ) ;
689+ llvm:: set_visibility ( new_fn, llvm:: get_visibility ( old_fn) ) ;
690+
691+ // Replace all uses of old_fn with new_fn (RAUW)
692+ unsafe {
693+ llvm:: LLVMReplaceAllUsesWith ( old_fn, new_fn) ;
694+ }
695+ let name = llvm:: get_value_name ( old_fn) ;
696+ unsafe {
697+ llvm:: LLVMDeleteFunction ( old_fn) ;
698+ }
699+ // Now we can re-use the old name, without name collision.
700+ llvm:: set_value_name ( new_fn, & name) ;
701+ }
702+
703+ if cgcx. target_is_like_gpu && config. offload . contains ( & config:: Offload :: Enable ) {
704+ let cx =
705+ SimpleCx :: new ( module. module_llvm . llmod ( ) , module. module_llvm . llcx , cgcx. pointer_size ) ;
706+ // For now we only support up to 10 kernels named kernel_0 ... kernel_9, a follow-up PR is
707+ // introducing a proper offload intrinsic to solve this limitation.
708+ for num in 0 ..9 {
709+ let name = format ! ( "kernel_{num}" ) ;
710+ if let Some ( kernel) = cx. get_function ( & name) {
711+ handle_offload ( & cx, kernel) ;
712+ }
713+ }
714+ }
715+
648716 let mut llvm_profiler = cgcx
649717 . prof
650718 . llvm_recording_enabled ( )
0 commit comments