Skip to content

Commit 23722aa

Browse files
committed
Add basic offload metadata
1 parent f12ac2b commit 23722aa

File tree

4 files changed

+84
-64
lines changed

4 files changed

+84
-64
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 7 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use llvm::Linkage::*;
44
use rustc_abi::Align;
55
use rustc_codegen_ssa::back::write::CodegenContext;
66
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7+
use rustc_middle::ty::offload_meta::OffloadMetadata;
78
use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
89

910
use crate::builder::SBuilder;
@@ -260,8 +261,7 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
260261
tcx: TyCtxt<'tcx>,
261262
kernel: &'ll llvm::Value,
262263
offload_entry_ty: &'ll llvm::Type,
263-
// TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need
264-
tt: Vec<Ty<'tcx>>,
264+
metadata: Vec<OffloadMetadata>,
265265
symbol: &str,
266266
) -> (&'ll llvm::Value, &'ll llvm::Value) {
267267
let types = cx.func_params_types(cx.get_type_of_global(kernel));
@@ -272,12 +272,11 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
272272
.filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
273273
.count();
274274

275-
// TODO(Sa4dUs): Add typetrees here
276275
let ptr_sizes = types
277276
.iter()
278-
.zip(tt)
279-
.filter_map(|(&x, ty)| match cx.type_kind(x) {
280-
rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)),
277+
.zip(metadata)
278+
.filter_map(|(&x, meta)| match cx.type_kind(x) {
279+
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta.payload_size),
281280
_ => None,
282281
})
283282
.collect::<Vec<u64>>();
@@ -332,56 +331,6 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>(
332331
(memtransfer_types, region_id)
333332
}
334333

335-
// TODO(Sa4dUs): move this to a proper place
336-
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
337-
match ty.kind() {
338-
/*
339-
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
340-
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
341-
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
342-
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
343-
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
344-
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
345-
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
346-
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
347-
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
348-
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
349-
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
350-
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
351-
*/
352-
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
353-
/*
354-
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
355-
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
356-
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
357-
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
358-
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
359-
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
360-
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
361-
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
362-
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
363-
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
364-
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
365-
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
366-
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
367-
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
368-
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
369-
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
370-
*/
371-
_ => {
372-
tcx
373-
// TODO(Sa4dUs): Maybe `.as_query_input()`?
374-
.layout_of(PseudoCanonicalInput {
375-
typing_env: TypingEnv::fully_monomorphized(),
376-
value: ty,
377-
})
378-
.unwrap()
379-
.size
380-
.bytes()
381-
}
382-
}
383-
}
384-
385334
fn declare_offload_fn<'ll>(
386335
cx: &'ll SimpleCx<'_>,
387336
name: &str,
@@ -420,7 +369,7 @@ fn declare_offload_fn<'ll>(
420369
pub(crate) fn gen_call_handling<'ll>(
421370
cx: &SimpleCx<'ll>,
422371
bb: &BasicBlock,
423-
kernels: &[&'ll llvm::Value],
372+
kernel: &'ll llvm::Value,
424373
memtransfer_types: &[&'ll llvm::Value],
425374
region_ids: &[&'ll llvm::Value],
426375
llfn: &'ll Value,
@@ -438,7 +387,7 @@ pub(crate) fn gen_call_handling<'ll>(
438387

439388
let mut builder = SBuilder::build(cx, bb);
440389

441-
let types = cx.func_params_types(cx.get_type_of_global(kernels[0]));
390+
let types = cx.func_params_types(cx.get_type_of_global(kernel));
442391
let num_args = types.len() as u64;
443392

444393
// Step 0)

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE;
1313
use rustc_hir::{self as hir};
1414
use rustc_middle::mir::BinOp;
1515
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
16+
use rustc_middle::ty::offload_meta::OffloadMetadata;
1617
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
1718
use rustc_middle::{bug, span_bug};
1819
use rustc_span::{Span, Symbol, sym};
@@ -1260,7 +1261,6 @@ fn codegen_offload<'ll, 'tcx>(
12601261
}
12611262
};
12621263

1263-
// TODO(Sa4dUs): Will need typetrees
12641264
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE);
12651265
let Some(kernel) = cx.get_function(&target_symbol) else {
12661266
bug!("could not find target function")
@@ -1272,26 +1272,26 @@ fn codegen_offload<'ll, 'tcx>(
12721272
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
12731273
let inputs = sig.inputs();
12741274

1275+
let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::<Vec<_>>();
1276+
12751277
// TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory
12761278
let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling(
12771279
cx,
12781280
tcx,
12791281
kernel,
12801282
offload_entry_ty,
1281-
inputs.to_vec(),
1283+
metadata,
12821284
&target_symbol,
12831285
);
12841286

1285-
let kernels = &[kernel];
1286-
12871287
let llfn = bx.llfn();
12881288

1289-
// TODO(Sa4dUs): this is a patch for delaying lifetime's issue fix
1289+
// TODO(Sa4dUs): this is just to a void lifetime's issues
12901290
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
12911291
crate::builder::gpu_offload::gen_call_handling(
12921292
cx,
12931293
bb,
1294-
kernels,
1294+
kernel,
12951295
&[memtransfer_type],
12961296
&[region_id],
12971297
llfn,

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ pub mod fast_reject;
131131
pub mod inhabitedness;
132132
pub mod layout;
133133
pub mod normalize_erasing_regions;
134+
pub mod offload_meta;
134135
pub mod pattern;
135136
pub mod print;
136137
pub mod relate;
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use crate::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
2+
3+
// TODO(Sa4dUs): it doesn't feel correct for me to place this on `rustc_ast::expand`, will look for a proper location
4+
pub struct OffloadMetadata {
5+
pub payload_size: u64,
6+
pub mode: TransferKind,
7+
}
8+
9+
pub enum TransferKind {
10+
FromGpu = 1,
11+
ToGpu = 2,
12+
Both = 3,
13+
}
14+
15+
impl OffloadMetadata {
16+
pub fn new(payload_size: u64, mode: TransferKind) -> Self {
17+
OffloadMetadata { payload_size, mode }
18+
}
19+
20+
pub fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
21+
OffloadMetadata { payload_size: get_payload_size(tcx, ty), mode: TransferKind::Both }
22+
}
23+
}
24+
25+
// TODO(Sa4dUs): WIP, rn we just have a naive logic for references
26+
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
27+
match ty.kind() {
28+
/*
29+
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
30+
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
31+
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
32+
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
33+
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
34+
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
35+
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
36+
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
37+
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
38+
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
39+
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
40+
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
41+
*/
42+
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
43+
/*
44+
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
45+
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
46+
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
47+
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
48+
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
49+
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
50+
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
51+
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
52+
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
53+
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
54+
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
55+
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
56+
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
57+
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
58+
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
59+
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
60+
*/
61+
_ => tcx
62+
.layout_of(PseudoCanonicalInput {
63+
typing_env: TypingEnv::fully_monomorphized(),
64+
value: ty,
65+
})
66+
.unwrap()
67+
.size
68+
.bytes(),
69+
}
70+
}

0 commit comments

Comments
 (0)