Skip to content

Commit 4496fb8

Browse files
committed
first definition of offload intrinsic (dirty code)
1 parent 28c4c7d commit 4496fb8

File tree

7 files changed

+214
-36
lines changed

7 files changed

+214
-36
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 98 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@ use llvm::Linkage::*;
44
use rustc_abi::Align;
55
use rustc_codegen_ssa::back::write::CodegenContext;
66
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7+
use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
78

89
use crate::builder::SBuilder;
9-
use crate::common::AsCCharPtr;
1010
use crate::llvm::AttributePlace::Function;
11-
use crate::llvm::{self, Linkage, Type, Value};
11+
use crate::llvm::{self, BasicBlock, Linkage, Type, Value};
1212
use crate::{LlvmCodegenBackend, SimpleCx, attributes};
1313

1414
pub(crate) fn handle_gpu_code<'ll>(
1515
_cgcx: &CodegenContext<LlvmCodegenBackend>,
16-
cx: &'ll SimpleCx<'_>,
16+
_cx: &'ll SimpleCx<'_>,
1717
) {
18+
/*
1819
// The offload memory transfer type for each kernel
1920
let mut o_types = vec![];
2021
let mut kernels = vec![];
@@ -28,6 +29,7 @@ pub(crate) fn handle_gpu_code<'ll>(
2829
}
2930
3031
gen_call_handling(&cx, &kernels, &o_types);
32+
*/
3133
}
3234

3335
// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
@@ -83,7 +85,7 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
8385
offload_entry_ty
8486
}
8587

86-
fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
88+
pub(crate) fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
8789
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
8890
let tptr = cx.type_ptr();
8991
let ti64 = cx.type_i64();
@@ -182,11 +184,14 @@ pub(crate) fn add_global<'ll>(
182184
llglobal
183185
}
184186

185-
fn gen_define_handling<'ll>(
186-
cx: &'ll SimpleCx<'_>,
187+
pub(crate) fn gen_define_handling<'ll, 'tcx>(
188+
cx: &SimpleCx<'ll>,
189+
tcx: TyCtxt<'tcx>,
187190
kernel: &'ll llvm::Value,
188191
offload_entry_ty: &'ll llvm::Type,
189-
num: i64,
192+
// TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need
193+
tt: Vec<Ty<'tcx>>,
194+
symbol: &str,
190195
) -> &'ll llvm::Value {
191196
let types = cx.func_params_types(cx.get_type_of_global(kernel));
192197
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
@@ -196,35 +201,47 @@ fn gen_define_handling<'ll>(
196201
.filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
197202
.count();
198203

204+
// TODO(Sa4dUs): Add typetrees here
205+
let ptr_sizes = types
206+
.iter()
207+
.zip(tt)
208+
.filter_map(|(&x, ty)| match cx.type_kind(x) {
209+
rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)),
210+
_ => None,
211+
})
212+
.collect::<Vec<u64>>();
213+
199214
// We do not know their size anymore at this level, so hardcode a placeholder.
200215
// A follow-up pr will track these from the frontend, where we still have Rust types.
201216
// Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes.
202217
// I decided that 1024 bytes is a great placeholder value for now.
203-
add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
218+
add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
204219
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
205220
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
206221
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
207222
// will be 2. For now, everything is 3, until we have our frontend set up.
223+
224+
// TODO(Sa4dUs): Check the way to figure out this
208225
let o_types =
209-
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
226+
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &vec![3; num_ptr_types]);
210227
// Next: For each function, generate these three entries. A weak constant,
211228
// the llvm.rodata entry name, and the omp_offloading_entries value
212229

213-
let name = format!(".kernel_{num}.region_id");
230+
let name = format!(".{symbol}.region_id");
214231
let initializer = cx.get_const_i8(0);
215232
let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);
216233

217-
let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
234+
let c_entry_name = CString::new(symbol).unwrap();
218235
let c_val = c_entry_name.as_bytes_with_nul();
219-
let offload_entry_name = format!(".offloading.entry_name.{num}");
236+
let offload_entry_name = format!(".offloading.entry_name.{symbol}");
220237

221238
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
222239
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
223240
llvm::set_alignment(llglobal, Align::ONE);
224241
llvm::set_section(llglobal, c".llvm.rodata.offloading");
225242

226243
// Not actively used yet, for calling real kernels
227-
let name = format!(".offloading.entry.kernel_{num}");
244+
let name = format!(".offloading.entry.{symbol}");
228245

229246
// See the __tgt_offload_entry documentation above.
230247
let reserved = cx.get_const_i64(0);
@@ -248,6 +265,56 @@ fn gen_define_handling<'ll>(
248265
o_types
249266
}
250267

268+
// TODO(Sa4dUs): move this to a proper place
269+
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
270+
match ty.kind() {
271+
/*
272+
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
273+
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
274+
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
275+
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
276+
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
277+
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
278+
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
279+
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
280+
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
281+
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
282+
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
283+
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
284+
*/
285+
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
286+
/*
287+
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
288+
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
289+
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
290+
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
291+
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
292+
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
293+
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
294+
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
295+
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
296+
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
297+
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
298+
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
299+
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
300+
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
301+
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
302+
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
303+
*/
304+
_ => {
305+
tcx
306+
// TODO(Sa4dUs): Maybe `.as_query_input()`?
307+
.layout_of(PseudoCanonicalInput {
308+
typing_env: TypingEnv::fully_monomorphized(),
309+
value: ty,
310+
})
311+
.unwrap()
312+
.size
313+
.bytes()
314+
}
315+
}
316+
}
317+
251318
fn declare_offload_fn<'ll>(
252319
cx: &'ll SimpleCx<'_>,
253320
name: &str,
@@ -283,10 +350,13 @@ fn declare_offload_fn<'ll>(
283350
// 4. set insert point after kernel call.
284351
// 5. generate all the GEPS and stores, to be used in 6)
285352
// 6. generate __tgt_target_data_end calls to move data from the GPU
286-
fn gen_call_handling<'ll>(
287-
cx: &'ll SimpleCx<'_>,
288-
_kernels: &[&'ll llvm::Value],
353+
pub(crate) fn gen_call_handling<'ll>(
354+
cx: &SimpleCx<'ll>,
355+
bb: &BasicBlock,
356+
kernels: &[&'ll llvm::Value],
289357
o_types: &[&'ll llvm::Value],
358+
llty: &'ll Type,
359+
llfn: &'ll Value,
290360
) {
291361
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
292362
let tptr = cx.type_ptr();
@@ -298,27 +368,14 @@ fn gen_call_handling<'ll>(
298368
gen_tgt_kernel_global(&cx);
299369
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
300370

301-
let main_fn = cx.get_function("main");
302-
let Some(main_fn) = main_fn else { return };
303-
let kernel_name = "kernel_1";
304-
let call = unsafe {
305-
llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
306-
};
307-
let Some(kernel_call) = call else {
308-
return;
309-
};
310-
let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
311-
let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
312-
let mut builder = SBuilder::build(cx, kernel_call_bb);
313-
314-
let types = cx.func_params_types(cx.get_type_of_global(called));
371+
let mut builder = SBuilder::build(cx, bb);
372+
373+
let types = cx.func_params_types(cx.get_type_of_global(kernels[0]));
315374
let num_args = types.len() as u64;
316375

317376
// Step 0)
318377
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
319378
// %6 = alloca %struct.__tgt_bin_desc, align 8
320-
unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
321-
322379
let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
323380

324381
let ty = cx.type_array(cx.type_ptr(), num_args);
@@ -335,7 +392,7 @@ fn gen_call_handling<'ll>(
335392
let i32_0 = cx.get_const_i32(0);
336393
for (index, in_ty) in types.iter().enumerate() {
337394
// get function arg, store it into the alloca, and read it.
338-
let p = llvm::get_param(called, index as u32);
395+
let p = llvm::get_param(kernels[0], index as u32);
339396
let name = llvm::get_value_name(p);
340397
let name = str::from_utf8(&name).unwrap();
341398
let arg_name = format!("{name}.addr");
@@ -349,7 +406,6 @@ fn gen_call_handling<'ll>(
349406
}
350407

351408
// Step 1)
352-
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
353409
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
354410

355411
let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
@@ -422,10 +478,16 @@ fn gen_call_handling<'ll>(
422478
// Step 3)
423479
// Here we will add code for the actual kernel launches in a follow-up PR.
424480
// FIXME(offload): launch kernels
481+
let nparams = llvm::LLVMCountParams(llfn);
482+
let mut args = Vec::with_capacity(nparams as usize);
483+
for i in 0..nparams {
484+
let param = unsafe { llvm::LLVMGetParam(llfn, i) };
485+
args.push(param);
486+
}
425487

426-
// Step 4)
427-
unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
488+
builder.call(llty, kernels[0], &args, None);
428489

490+
// Step 4)
429491
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
430492
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
431493

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
195195
codegen_autodiff(self, tcx, instance, args, result);
196196
return Ok(());
197197
}
198+
sym::offload => {
199+
codegen_offload(self, tcx, instance, args, result);
200+
return Ok(());
201+
}
198202
sym::is_val_statically_known => {
199203
if let OperandValue::Immediate(imm) = args[0].val {
200204
self.call_intrinsic(
@@ -1227,6 +1231,73 @@ fn codegen_autodiff<'ll, 'tcx>(
12271231
);
12281232
}
12291233

1234+
fn codegen_offload<'ll, 'tcx>(
1235+
bx: &mut Builder<'_, 'll, 'tcx>,
1236+
tcx: TyCtxt<'tcx>,
1237+
instance: ty::Instance<'tcx>,
1238+
_args: &[OperandRef<'tcx, &'ll Value>],
1239+
_result: PlaceRef<'tcx, &'ll Value>,
1240+
) {
1241+
let cx = bx.cx;
1242+
let fn_args = instance.args;
1243+
1244+
let (target_id, target_args) = match fn_args.into_type_list(tcx)[0].kind() {
1245+
ty::FnDef(def_id, params) => (def_id, params),
1246+
_ => bug!("invalid offload intrinsic arg"),
1247+
};
1248+
1249+
let fn_target = match Instance::try_resolve(tcx, cx.typing_env(), *target_id, target_args) {
1250+
Ok(Some(instance)) => instance,
1251+
Ok(None) => bug!(
1252+
"could not resolve ({:?}, {:?}) to a specific offload instance",
1253+
target_id,
1254+
target_args
1255+
),
1256+
Err(_) => {
1257+
// An error has already been emitted
1258+
return;
1259+
}
1260+
};
1261+
1262+
// TODO(Sa4dUs): Will need typetrees
1263+
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE);
1264+
let Some(kernel) = cx.get_function(&target_symbol) else {
1265+
bug!("could not find target function")
1266+
};
1267+
1268+
let offload_entry_ty = crate::builder::gpu_offload::add_tgt_offload_entry(cx);
1269+
1270+
// Build TypeTree (or something similar)
1271+
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
1272+
let inputs = sig.inputs();
1273+
1274+
// TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory
1275+
let o = crate::builder::gpu_offload::gen_define_handling(
1276+
cx,
1277+
tcx,
1278+
kernel,
1279+
offload_entry_ty,
1280+
inputs.to_vec(),
1281+
&target_symbol,
1282+
);
1283+
1284+
let kernels = &[kernel];
1285+
let o_types = &[o];
1286+
1287+
let llvm_args = inputs.iter().map(|ty| bx.layout_of(*ty).llvm_type(cx)).collect::<Vec<_>>();
1288+
let ret_ty = match sig.output().kind() {
1289+
// TODO(Sa4dUs): dunno if there's a better way of doing this
1290+
ty::Tuple(tys) if tys.is_empty() => bx.type_void(),
1291+
_ => bx.layout_of(sig.output()).llvm_type(cx),
1292+
};
1293+
let llty = bx.type_func(&llvm_args, ret_ty);
1294+
let llfn = bx.llfn();
1295+
1296+
// TODO(Sa4dUs): this is a patch for delaying lifetime's issue fix
1297+
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
1298+
crate::builder::gpu_offload::gen_call_handling(cx, bb, kernels, o_types, llty, llfn);
1299+
}
1300+
12301301
fn get_args_from_tuple<'ll, 'tcx>(
12311302
bx: &mut Builder<'_, 'll, 'tcx>,
12321303
tuple_op: OperandRef<'tcx, &'ll Value>,

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
//!
55
//! This API is completely unstable and subject to change.
66
7+
#![allow(unused)]
78
// tidy-alphabetical-start
89
#![allow(internal_features)]
910
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
163163
| sym::minnumf128
164164
| sym::mul_with_overflow
165165
| sym::needs_drop
166+
| sym::offload
166167
| sym::powf16
167168
| sym::powf32
168169
| sym::powf64
@@ -310,6 +311,7 @@ pub(crate) fn check_intrinsic_type(
310311
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
311312
(0, 0, vec![type_id, type_id], tcx.types.bool)
312313
}
314+
sym::offload => (2, 0, vec![param(0)], param(1)),
313315
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
314316
sym::arith_offset => (
315317
1,

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,6 +1558,7 @@ symbols! {
15581558
object_safe_for_dispatch,
15591559
of,
15601560
off,
1561+
offload,
15611562
offset,
15621563
offset_of,
15631564
offset_of_enum,

library/core/src/intrinsics/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3262,6 +3262,10 @@ pub const fn copysignf128(x: f128, y: f128) -> f128;
32623262
#[rustc_intrinsic]
32633263
pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
32643264

3265+
#[rustc_nounwind]
3266+
#[rustc_intrinsic]
3267+
pub const fn offload<F, R>(f: F) -> R;
3268+
32653269
/// Inform Miri that a given pointer definitely has a certain alignment.
32663270
#[cfg(miri)]
32673271
#[rustc_allow_const_fn_unstable(const_eval_select)]

0 commit comments

Comments
 (0)