Skip to content

Commit 0b355cc

Browse files
committed
Try without ref proj
1 parent 36aa41a commit 0b355cc

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

src/kernel/param.rs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ impl<
156156
type AsyncHostType<'stream, 'b> = crate::utils::r#async::AsyncProj<
157157
'b,
158158
'stream,
159-
&'b crate::host::HostAndDeviceConstRef<'b, T>,
159+
crate::host::HostAndDeviceConstRef<'b, T>,
160160
> where Self: 'b;
161161
#[cfg(any(feature = "device", doc))]
162162
type DeviceType<'b> = &'b T where Self: 'b;
@@ -173,8 +173,9 @@ impl<
173173
where
174174
Self: 'b,
175175
{
176+
let _ = stream;
176177
crate::host::HostAndDeviceConstRef::with_new(param, |const_ref| {
177-
inner.with(const_ref.as_async(stream).as_ref())
178+
inner.with(unsafe { crate::utils::r#async::AsyncProj::new(const_ref, None) })
178179
})
179180
}
180181

@@ -257,9 +258,10 @@ impl<
257258
where
258259
Self: 'b,
259260
{
261+
let _ = stream;
260262
// FIXME: forward impl
261263
crate::host::HostAndDeviceConstRef::with_new(param, |const_ref| {
262-
inner.with(const_ref.as_async(stream).as_ref())
264+
inner.with(unsafe { crate::utils::r#async::AsyncProj::new(const_ref, None) })
263265
})
264266
}
265267

@@ -272,7 +274,8 @@ impl<
272274
where
273275
Self: 'b,
274276
{
275-
let param = unsafe { param.unwrap_ref_unchecked() };
277+
let param_ref = param.proj_ref();
278+
let param = unsafe { param_ref.unwrap_ref_unchecked() };
276279
inner(Some(&param_as_raw_bytes(param.for_host())))
277280
}
278281

@@ -360,7 +363,7 @@ impl<
360363
type AsyncHostType<'stream, 'b> = crate::utils::r#async::AsyncProj<
361364
'b,
362365
'stream,
363-
&'b crate::host::HostAndDeviceConstRef<'b, T>
366+
crate::host::HostAndDeviceConstRef<'b, T>
364367
> where Self: 'b;
365368
#[cfg(any(feature = "device", doc))]
366369
type DeviceType<'b> = &'b T where Self: 'b;
@@ -379,8 +382,9 @@ impl<
379382
where
380383
Self: 'b,
381384
{
382-
crate::host::HostAndDeviceMutRef::with_new(param, |const_ref| {
383-
inner.with(const_ref.as_ref().as_async(stream).as_ref())
385+
let _ = stream;
386+
crate::host::HostAndDeviceMutRef::with_new(param, |mut_ref| {
387+
inner.with(unsafe { crate::utils::r#async::AsyncProj::new(mut_ref.as_ref(), None) })
384388
})
385389
}
386390

@@ -580,7 +584,7 @@ impl<'a, T: Sync + RustToCuda> CudaKernelParameter for &'a DeepPerThreadBorrow<T
580584
type AsyncHostType<'stream, 'b> = crate::utils::r#async::AsyncProj<
581585
'b,
582586
'stream,
583-
&'b crate::host::HostAndDeviceConstRef<
587+
crate::host::HostAndDeviceConstRef<
584588
'b,
585589
DeviceAccessible<<T as RustToCuda>::CudaRepresentation>,
586590
>,
@@ -601,8 +605,9 @@ impl<'a, T: Sync + RustToCuda> CudaKernelParameter for &'a DeepPerThreadBorrow<T
601605
where
602606
Self: 'b,
603607
{
608+
let _ = stream;
604609
crate::lend::LendToCuda::lend_to_cuda(param, |param| {
605-
inner.with(param.as_async(stream).as_ref())
610+
inner.with(unsafe { crate::utils::r#async::AsyncProj::new(param, None) })
606611
})
607612
}
608613

@@ -663,7 +668,7 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter
663668
type AsyncHostType<'stream, 'b> = crate::utils::r#async::AsyncProj<
664669
'b,
665670
'stream,
666-
&'b mut crate::host::HostAndDeviceMutRef<
671+
crate::host::HostAndDeviceMutRef<
667672
'b,
668673
DeviceAccessible<<T as RustToCuda>::CudaRepresentation>,
669674
>,
@@ -690,7 +695,7 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter
690695
inner.with({
691696
// Safety: this projection cannot be moved to a different stream
692697
// without first exiting lend_to_cuda_mut and synchronizing
693-
unsafe { crate::utils::r#async::AsyncProj::new(&mut param.into_mut(), None) }
698+
unsafe { crate::utils::r#async::AsyncProj::new(param.into_mut(), None) }
694699
})
695700
})
696701
}
@@ -727,7 +732,7 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter
727732
Self: 'b,
728733
{
729734
param.record_mut_use()?;
730-
let param = unsafe { param.unwrap_unchecked() };
735+
let mut param = unsafe { param.unwrap_unchecked() };
731736
Ok(param.for_device())
732737
}
733738

@@ -858,8 +863,9 @@ impl<'a, T: Sync + RustToCuda> CudaKernelParameter for &'a PtxJit<DeepPerThreadB
858863
Self: 'b,
859864
{
860865
// FIXME: forward impl
866+
let _ = stream;
861867
crate::lend::LendToCuda::lend_to_cuda(param, |param| {
862-
inner.with(param.as_async(stream).as_ref())
868+
inner.with(unsafe { crate::utils::r#async::AsyncProj::new(param, None) })
863869
})
864870
}
865871

@@ -872,7 +878,8 @@ impl<'a, T: Sync + RustToCuda> CudaKernelParameter for &'a PtxJit<DeepPerThreadB
872878
where
873879
Self: 'b,
874880
{
875-
let param = unsafe { param.unwrap_ref_unchecked() };
881+
let param_ref = param.proj_ref();
882+
let param = unsafe { param_ref.unwrap_unchecked() };
876883
inner(Some(&param_as_raw_bytes(param.for_host())))
877884
}
878885

@@ -945,7 +952,7 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter
945952
inner.with({
946953
// Safety: this projection cannot be moved to a different stream
947954
// without first exiting lend_to_cuda_mut and synchronizing
948-
unsafe { crate::utils::r#async::AsyncProj::new(&mut param.into_mut(), None) }
955+
unsafe { crate::utils::r#async::AsyncProj::new(param.into_mut(), None) }
949956
})
950957
})
951958
}
@@ -959,7 +966,8 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter
959966
where
960967
Self: 'b,
961968
{
962-
let param = unsafe { param.as_ref().unwrap_unchecked() };
969+
let param_ref = param.proj_ref();
970+
let param = unsafe { param_ref.unwrap_unchecked() };
963971
inner(Some(&param_as_raw_bytes(param.for_host())))
964972
}
965973

src/utils/exchange/wrapper.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ impl<
340340
) -> AsyncProj<
341341
'_,
342342
'stream,
343-
HostAndDeviceMutRef<DeviceAccessible<<T as RustToCuda>::CudaRepresentation>>,
343+
HostAndDeviceMutRef<'_, DeviceAccessible<<T as RustToCuda>::CudaRepresentation>>,
344344
>
345345
where
346346
T: SafeMutableAliasing,

0 commit comments

Comments
 (0)