Skip to content

Commit 36aa41a

Browse files
committed
Fix uniqueness guarantee for Stream using branded types
1 parent 139adce commit 36aa41a

File tree

18 files changed

+119
-95
lines changed

18 files changed

+119
-95
lines changed

rust-cuda-derive/src/rust_to_cuda/impl.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ pub fn rust_to_cuda_async_trait(
191191
unsafe fn borrow_async<'stream, CudaAllocType: #crate_path::alloc::CudaAlloc>(
192192
&self,
193193
alloc: CudaAllocType,
194-
stream: &'stream #crate_path::host::Stream,
194+
stream: #crate_path::host::Stream<'stream>,
195195
) -> #crate_path::deps::rustacuda::error::CudaResult<(
196196
#crate_path::utils::r#async::Async<
197197
'_, 'stream,
@@ -219,7 +219,7 @@ pub fn rust_to_cuda_async_trait(
219219
alloc: #crate_path::alloc::CombinedCudaAlloc<
220220
Self::CudaAllocationAsync, CudaAllocType
221221
>,
222-
stream: &'stream #crate_path::host::Stream,
222+
stream: #crate_path::host::Stream<'stream>,
223223
) -> #crate_path::deps::rustacuda::error::CudaResult<(
224224
#crate_path::utils::r#async::Async<
225225
'a, 'stream,

src/host/mod.rs

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,30 +25,54 @@ use crate::{
2525
},
2626
};
2727

28+
type InvariantLifetime<'brand> = PhantomData<fn(&'brand ()) -> &'brand ()>;
29+
30+
#[derive(Copy, Clone)]
2831
#[repr(transparent)]
29-
pub struct Stream {
30-
stream: rustacuda::stream::Stream,
32+
pub struct Stream<'stream> {
33+
stream: &'stream rustacuda::stream::Stream,
34+
_brand: InvariantLifetime<'stream>,
3135
}
3236

33-
impl Deref for Stream {
37+
impl<'stream> Deref for Stream<'stream> {
3438
type Target = rustacuda::stream::Stream;
3539

3640
fn deref(&self) -> &Self::Target {
37-
&self.stream
41+
self.stream
3842
}
3943
}
4044

41-
impl Stream {
45+
impl<'stream> Stream<'stream> {
46+
#[allow(clippy::needless_pass_by_ref_mut)]
47+
/// Create a new uniquely branded [`Stream`], which can bind async
48+
/// operations to the [`Stream`] that they are computed on.
49+
///
50+
/// The uniqueness guarantees are provided by using branded types,
51+
/// as inspired by the Ghost Cell paper by Yanovski, J., Dang, H.-H.,
52+
/// Jung, R., and Dreyer, D.: <https://doi.org/10.1145/3473597>.
53+
///
54+
/// # Examples
55+
///
56+
/// The following example shows that two [`Stream`]s with different
57+
/// `'stream` lifetime brands cannot be used interchangeably.
58+
///
59+
/// ```rust, compile_fail
60+
/// use rust_cuda::host::Stream;
61+
///
62+
/// fn check_same<'stream>(_stream_a: Stream<'stream>, _stream_b: Stream<'stream>) {}
63+
///
64+
/// fn two_streams<'stream_a, 'stream_b>(stream_a: Stream<'stream_a>, stream_b: Stream<'stream_b>) {
65+
/// check_same(stream_a, stream_b);
66+
/// }
67+
/// ```
4268
pub fn with<O>(
4369
stream: &mut rustacuda::stream::Stream,
44-
inner: impl for<'stream> FnOnce(&'stream Self) -> O,
70+
inner: impl for<'new_stream> FnOnce(Stream<'new_stream>) -> O,
4571
) -> O {
46-
// Safety:
47-
// - Stream is a newtype wrapper around rustacuda::stream::Stream
48-
// - we forge a unique lifetime for a unique reference
49-
let stream = unsafe { &*std::ptr::from_ref(stream).cast() };
50-
51-
inner(stream)
72+
inner(Stream {
73+
stream,
74+
_brand: InvariantLifetime::default(),
75+
})
5276
}
5377
}
5478

@@ -219,7 +243,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
219243
#[must_use]
220244
pub fn into_async<'b, 'stream>(
221245
self,
222-
stream: &'stream Stream,
246+
stream: Stream<'stream>,
223247
) -> Async<'b, 'stream, HostAndDeviceMutRef<'b, T>, NoCompletion>
224248
where
225249
'a: 'b,
@@ -312,7 +336,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceConstRef<'a, T>
312336
#[must_use]
313337
pub const fn as_async<'b, 'stream>(
314338
&'b self,
315-
stream: &'stream Stream,
339+
stream: Stream<'stream>,
316340
) -> Async<'b, 'stream, HostAndDeviceConstRef<'b, T>, NoCompletion>
317341
where
318342
'a: 'b,
@@ -370,7 +394,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceOwned<'a, T> {
370394
#[must_use]
371395
pub const fn into_async<'stream>(
372396
self,
373-
stream: &'stream Stream,
397+
stream: Stream<'stream>,
374398
) -> Async<'a, 'stream, Self, NoCompletion> {
375399
Async::ready(self, stream)
376400
}

src/kernel/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ pub trait CudaKernelParameter: sealed::Sealed {
110110
#[allow(clippy::missing_errors_doc)] // FIXME
111111
fn with_new_async<'stream, 'param, O, E: From<rustacuda::error::CudaError>>(
112112
param: Self::SyncHostType,
113-
stream: &'stream crate::host::Stream,
113+
stream: crate::host::Stream<'stream>,
114114
inner: impl WithNewAsync<'stream, Self, O, E>,
115115
) -> Result<O, E>
116116
where
@@ -156,7 +156,7 @@ pub trait CudaKernelParameter: sealed::Sealed {
156156

157157
#[cfg(feature = "host")]
158158
pub struct Launcher<'stream, 'kernel, Kernel> {
159-
pub stream: &'stream Stream,
159+
pub stream: Stream<'stream>,
160160
pub kernel: &'kernel mut TypedPtxKernel<Kernel>,
161161
pub config: LaunchConfig,
162162
}
@@ -366,7 +366,7 @@ macro_rules! impl_typed_kernel_launch {
366366
#[allow(clippy::too_many_arguments)] // func is defined for <= 12 args
367367
pub fn $launch<'kernel, 'stream, $($T: CudaKernelParameter),*>(
368368
&'kernel mut self,
369-
stream: &'stream Stream,
369+
stream: Stream<'stream>,
370370
config: &LaunchConfig,
371371
$($arg: $T::SyncHostType),*
372372
) -> CudaResult<()>
@@ -396,12 +396,12 @@ macro_rules! impl_typed_kernel_launch {
396396
$($T: CudaKernelParameter),*
397397
>(
398398
&'kernel mut self,
399-
stream: &'stream Stream,
399+
stream: Stream<'stream>,
400400
config: &LaunchConfig,
401401
$($arg: $T::SyncHostType,)*
402402
inner: impl FnOnce(
403403
&'kernel mut Self,
404-
&'stream Stream,
404+
Stream<'stream>,
405405
&LaunchConfig,
406406
$($T::AsyncHostType<'stream, '_>),*
407407
) -> Result<Ok, Err>,
@@ -419,7 +419,7 @@ macro_rules! impl_typed_kernel_launch {
419419
#[allow(clippy::too_many_arguments)] // func is defined for <= 12 args
420420
pub fn $launch_async<'kernel, 'stream, $($T: CudaKernelParameter),*>(
421421
&'kernel mut self,
422-
stream: &'stream Stream,
422+
stream: Stream<'stream>,
423423
config: &LaunchConfig,
424424
$($arg: $T::AsyncHostType<'stream, '_>),*
425425
) -> CudaResult<crate::utils::r#async::Async<

src/kernel/param.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ impl<
8181
#[cfg(feature = "host")]
8282
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
8383
param: Self::SyncHostType,
84-
_stream: &'stream crate::host::Stream,
84+
_stream: crate::host::Stream<'stream>,
8585
inner: impl super::WithNewAsync<'stream, Self, O, E>,
8686
) -> Result<O, E>
8787
where
@@ -167,7 +167,7 @@ impl<
167167
#[cfg(feature = "host")]
168168
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
169169
param: Self::SyncHostType,
170-
stream: &'stream crate::host::Stream,
170+
stream: crate::host::Stream<'stream>,
171171
inner: impl super::WithNewAsync<'stream, Self, O, E>,
172172
) -> Result<O, E>
173173
where
@@ -251,7 +251,7 @@ impl<
251251
#[cfg(feature = "host")]
252252
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
253253
param: Self::SyncHostType,
254-
stream: &'stream crate::host::Stream,
254+
stream: crate::host::Stream<'stream>,
255255
inner: impl super::WithNewAsync<'stream, Self, O, E>,
256256
) -> Result<O, E>
257257
where
@@ -373,7 +373,7 @@ impl<
373373
#[cfg(feature = "host")]
374374
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
375375
param: Self::SyncHostType,
376-
stream: &'stream crate::host::Stream,
376+
stream: crate::host::Stream<'stream>,
377377
inner: impl super::WithNewAsync<'stream, Self, O, E>,
378378
) -> Result<O, E>
379379
where
@@ -509,7 +509,7 @@ impl<
509509
#[cfg(feature = "host")]
510510
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
511511
param: Self::SyncHostType,
512-
stream: &'stream crate::host::Stream,
512+
stream: crate::host::Stream<'stream>,
513513
inner: impl super::WithNewAsync<'stream, Self, O, E>,
514514
) -> Result<O, E>
515515
where
@@ -595,7 +595,7 @@ impl<'a, T: Sync + RustToCuda> CudaKernelParameter for &'a DeepPerThreadBorrow<T
595595
#[cfg(feature = "host")]
596596
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
597597
param: Self::SyncHostType,
598-
stream: &'stream crate::host::Stream,
598+
stream: crate::host::Stream<'stream>,
599599
inner: impl super::WithNewAsync<'stream, Self, O, E>,
600600
) -> Result<O, E>
601601
where
@@ -678,7 +678,7 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter
678678
#[cfg(feature = "host")]
679679
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
680680
param: Self::SyncHostType,
681-
stream: &'stream crate::host::Stream,
681+
stream: crate::host::Stream<'stream>,
682682
inner: impl super::WithNewAsync<'stream, Self, O, E>,
683683
) -> Result<O, E>
684684
where
@@ -768,7 +768,7 @@ impl<
768768
#[cfg(feature = "host")]
769769
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
770770
param: Self::SyncHostType,
771-
stream: &'stream crate::host::Stream,
771+
stream: crate::host::Stream<'stream>,
772772
inner: impl super::WithNewAsync<'stream, Self, O, E>,
773773
) -> Result<O, E>
774774
where
@@ -851,7 +851,7 @@ impl<'a, T: Sync + RustToCuda> CudaKernelParameter for &'a PtxJit<DeepPerThreadB
851851
#[cfg(feature = "host")]
852852
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
853853
param: Self::SyncHostType,
854-
stream: &'stream crate::host::Stream,
854+
stream: crate::host::Stream<'stream>,
855855
inner: impl super::WithNewAsync<'stream, Self, O, E>,
856856
) -> Result<O, E>
857857
where
@@ -932,7 +932,7 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter
932932
#[cfg(feature = "host")]
933933
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
934934
param: Self::SyncHostType,
935-
stream: &'stream crate::host::Stream,
935+
stream: crate::host::Stream<'stream>,
936936
inner: impl super::WithNewAsync<'stream, Self, O, E>,
937937
) -> Result<O, E>
938938
where
@@ -1058,7 +1058,7 @@ impl<'a, T: 'static> CudaKernelParameter for &'a mut crate::utils::shared::Threa
10581058
#[cfg(feature = "host")]
10591059
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
10601060
param: Self::SyncHostType,
1061-
_stream: &'stream crate::host::Stream,
1061+
_stream: crate::host::Stream<'stream>,
10621062
inner: impl super::WithNewAsync<'stream, Self, O, E>,
10631063
) -> Result<O, E>
10641064
where
@@ -1135,7 +1135,7 @@ impl<'a, T: 'static + PortableBitSemantics + TypeGraphLayout> CudaKernelParamete
11351135
#[cfg(feature = "host")]
11361136
fn with_new_async<'stream, 'b, O, E: From<rustacuda::error::CudaError>>(
11371137
param: Self::SyncHostType,
1138-
_stream: &'stream crate::host::Stream,
1138+
_stream: crate::host::Stream<'stream>,
11391139
inner: impl super::WithNewAsync<'stream, Self, O, E>,
11401140
) -> Result<O, E>
11411141
where

src/lend/impls/box.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Box<T
9090
unsafe fn borrow_async<'stream, A: CudaAlloc>(
9191
&self,
9292
alloc: A,
93-
stream: &'stream crate::host::Stream,
93+
stream: crate::host::Stream<'stream>,
9494
) -> rustacuda::error::CudaResult<(
9595
Async<'_, 'stream, DeviceAccessible<Self::CudaRepresentation>>,
9696
CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
@@ -113,7 +113,7 @@ unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Box<T
113113
let mut device_box = CudaDropWrapper::from(DeviceBox::<
114114
DeviceCopyWithPortableBitSemantics<ManuallyDrop<T>>,
115115
>::uninitialized()?);
116-
device_box.async_copy_from(&*locked_box, stream)?;
116+
device_box.async_copy_from(&*locked_box, &stream)?;
117117

118118
Ok((
119119
Async::pending(
@@ -131,7 +131,7 @@ unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Box<T
131131
unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>(
132132
this: owning_ref::BoxRefMut<'a, O, Self>,
133133
alloc: CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
134-
stream: &'stream crate::host::Stream,
134+
stream: crate::host::Stream<'stream>,
135135
) -> CudaResult<(
136136
Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>,
137137
A,
@@ -141,7 +141,7 @@ unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Box<T
141141
let (alloc_front, alloc_tail) = alloc.split();
142142
let (mut locked_box, device_box) = alloc_front.split();
143143

144-
device_box.async_copy_to(&mut *locked_box, stream)?;
144+
device_box.async_copy_to(&mut *locked_box, &stream)?;
145145

146146
let r#async = crate::utils::r#async::Async::<_, CompletionFnMut<'a, Self>>::pending(
147147
this,

src/lend/impls/boxed_slice.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Box<[
9696
unsafe fn borrow_async<'stream, A: CudaAlloc>(
9797
&self,
9898
alloc: A,
99-
stream: &'stream crate::host::Stream,
99+
stream: crate::host::Stream<'stream>,
100100
) -> rustacuda::error::CudaResult<(
101101
Async<'_, 'stream, DeviceAccessible<Self::CudaRepresentation>>,
102102
CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
@@ -120,7 +120,7 @@ unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Box<[
120120
let mut device_buffer = CudaDropWrapper::from(DeviceBuffer::<
121121
DeviceCopyWithPortableBitSemantics<ManuallyDrop<T>>,
122122
>::uninitialized(self.len())?);
123-
device_buffer.async_copy_from(&*locked_buffer, stream)?;
123+
device_buffer.async_copy_from(&*locked_buffer, &stream)?;
124124

125125
Ok((
126126
Async::pending(
@@ -140,7 +140,7 @@ unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Box<[
140140
unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>(
141141
this: owning_ref::BoxRefMut<'a, O, Self>,
142142
alloc: CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
143-
stream: &'stream crate::host::Stream,
143+
stream: crate::host::Stream<'stream>,
144144
) -> CudaResult<(
145145
Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>,
146146
A,
@@ -150,7 +150,7 @@ unsafe impl<T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for Box<[
150150
let (alloc_front, alloc_tail) = alloc.split();
151151
let (mut locked_buffer, device_buffer) = alloc_front.split();
152152

153-
device_buffer.async_copy_to(&mut *locked_buffer, stream)?;
153+
device_buffer.async_copy_to(&mut *locked_buffer, &stream)?;
154154

155155
let r#async = crate::utils::r#async::Async::<_, CompletionFnMut<'a, Self>>::pending(
156156
this,

src/lend/impls/final.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ unsafe impl<T: RustToCudaAsync> RustToCudaAsync for Final<T> {
4949
unsafe fn borrow_async<'stream, A: crate::alloc::CudaAlloc>(
5050
&self,
5151
alloc: A,
52-
stream: &'stream crate::host::Stream,
52+
stream: crate::host::Stream<'stream>,
5353
) -> rustacuda::error::CudaResult<(
5454
crate::utils::r#async::Async<'_, 'stream, DeviceAccessible<Self::CudaRepresentation>>,
5555
crate::alloc::CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
@@ -76,7 +76,7 @@ unsafe impl<T: RustToCudaAsync> RustToCudaAsync for Final<T> {
7676
unsafe fn restore_async<'a, 'stream, A: crate::alloc::CudaAlloc, O>(
7777
this: owning_ref::BoxRefMut<'a, O, Self>,
7878
alloc: crate::alloc::CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
79-
stream: &'stream crate::host::Stream,
79+
stream: crate::host::Stream<'stream>,
8080
) -> rustacuda::error::CudaResult<(
8181
crate::utils::r#async::Async<
8282
'a,

src/lend/impls/option.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ unsafe impl<T: RustToCudaAsync> RustToCudaAsync for Option<T> {
8989
unsafe fn borrow_async<'stream, A: CudaAlloc>(
9090
&self,
9191
alloc: A,
92-
stream: &'stream crate::host::Stream,
92+
stream: crate::host::Stream<'stream>,
9393
) -> CudaResult<(
9494
Async<'_, 'stream, DeviceAccessible<Self::CudaRepresentation>>,
9595
CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
@@ -135,7 +135,7 @@ unsafe impl<T: RustToCudaAsync> RustToCudaAsync for Option<T> {
135135
unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>(
136136
mut this: owning_ref::BoxRefMut<'a, O, Self>,
137137
alloc: CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
138-
stream: &'stream crate::host::Stream,
138+
stream: crate::host::Stream<'stream>,
139139
) -> CudaResult<(
140140
Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>,
141141
A,

0 commit comments

Comments
 (0)