@@ -653,6 +653,87 @@ pub(crate) unsafe fn llvm_optimize(
653653        None 
654654    } ; 
655655
656+     fn  handle_offload ( m :  & llvm:: Module ,  llcx :  & llvm:: Context ,  old_fn :  & llvm:: Value )  { 
657+         unsafe  {  llvm:: LLVMRustOffloadWrapper ( m,  old_fn)  } ; 
658+         //unsafe {llvm::LLVMDumpModule(m);} 
659+         //unsafe { 
660+         //    // Get the old function type 
661+         //    let old_fn_ty = llvm::LLVMGlobalGetValueType(old_fn); 
662+         //    dbg!(&old_fn_ty); 
663+         //    let old_param_count = llvm::LLVMCountParamTypes(old_fn_ty); 
664+         //    dbg!(&old_param_count); 
665+ 
666+         //    // Get the old parameter types 
667+         //    let mut old_param_types = Vec::with_capacity(old_param_count as usize); 
668+         //    llvm::LLVMGetParamTypes(old_fn_ty, old_param_types.as_mut_ptr()); 
669+         //    old_param_types.set_len(old_param_count as usize); 
670+ 
671+         //    // Create the new parameter list, with ptr as the first argument 
672+         //    let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0); 
673+         //    let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1); 
674+         //    new_param_types.push(ptr_ty); 
675+         //    for old_param in old_param_types { 
676+         //        new_param_types.push(old_param); 
677+         //    } 
678+         //    dbg!(&new_param_types); 
679+ 
680+         //    // Create the new function type 
681+         //    let ret_ty = llvm::LLVMGetReturnType(old_fn_ty); 
682+         //    let new_fn_ty = llvm::LLVMFunctionType(ret_ty, new_param_types.as_mut_ptr(), new_param_types.len() as u32, 0); 
683+         //    dbg!(&new_fn_ty); 
684+ 
685+         //    // Create the new function 
686+         //    let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap(); 
687+         //    //let old_fn_name = std::ffi::CStr::from_ptr(llvm::LLVMGetValueName2(old_fn)).to_str().unwrap(); 
688+         //    let new_fn_name = format!("{}_with_dyn_ptr", old_fn_name); 
689+         //    let new_fn_cstr = CString::new(new_fn_name).unwrap(); 
690+         //    let new_fn = llvm::LLVMAddFunction(m, new_fn_cstr.as_ptr(), new_fn_ty); 
691+         //    dbg!(&new_fn); 
692+         //    let a0 = llvm::LLVMGetParam(new_fn, 0); 
693+         //    llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr().cast(), "dyn_ptr".len()); 
694+         //    dbg!(&new_fn); 
695+ 
696+         //    // Move basic blocks 
697+         //    let mut bb = llvm::LLVMGetFirstBasicBlock(old_fn); 
698+         //    //dbg!(&bb); 
699+         //    llvm::LLVMAppendExistingBasicBlock(new_fn, bb); 
700+         //    //while !bb.is_null() { 
701+         //    //    let next = llvm::LLVMGetNextBasicBlock(bb); 
702+         //    //    llvm::LLVMAppendExistingBasicBlock(new_fn, bb); 
703+         //    //    bb = next; 
704+         //    //}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ... 
705+         //    let old_n = llvm::LLVMCountParams(old_fn); 
706+         //    for i in 0..old_n { 
707+         //        let old_arg = llvm::LLVMGetParam(old_fn, i); 
708+         //        let new_arg = llvm::LLVMGetParam(new_fn, i + 1); 
709+         //        llvm::LLVMReplaceAllUsesWith(old_arg, new_arg); 
710+         //    } 
711+ 
712+         //    // Copy linkage and visibility 
713+         //    //llvm::LLVMSetLinkage(new_fn, llvm::LLVMGetLinkage(old_fn)); 
714+         //    //llvm::LLVMSetVisibility(new_fn, llvm::LLVMGetVisibility(old_fn)); 
715+ 
716+         //    // Replace all uses of old_fn with new_fn (RAUW) 
717+         //    llvm::LLVMReplaceAllUsesWith(old_fn, new_fn); 
718+ 
719+         //    // Optionally, remove the old function 
720+         //    llvm::LLVMDeleteFunction(old_fn); 
721+         //} 
722+     } 
723+ 
724+     let  consider_offload = config. offload . contains ( & config:: Offload :: Enable ) ; 
725+     if  consider_offload && ( cgcx. target_arch  == "amdgpu"  || cgcx. target_arch  == "nvptx64" )  { 
726+         for  num in  0 ..9  { 
727+             let  name = format ! ( "kernel_{num}" ) ; 
728+             let  c_name = CString :: new ( name) . unwrap ( ) ; 
729+             if  let  Some ( kernel)  =
730+                 unsafe  {  llvm:: LLVMGetNamedFunction ( module. module_llvm . llmod ( ) ,  c_name. as_ptr ( ) )  } 
731+             { 
732+                 handle_offload ( module. module_llvm . llmod ( ) ,  module. module_llvm . llcx ,  kernel) ; 
733+             } 
734+         } 
735+     } 
736+ 
656737    let  mut  llvm_profiler = cgcx
657738        . prof 
658739        . llvm_recording_enabled ( ) 
0 commit comments