diff --git a/src/utils/device.jl b/src/utils/device.jl index 8888e57..a4d727c 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -1,5 +1,6 @@ export device, send_to_host, send_to_device +using ElasticArrays using Flux using CUDA using Adapt @@ -22,6 +23,8 @@ send_to_device( <:Any, <:Base.ReshapedArray{<:Any,<:Any,<:SubArray{<:Any,<:Any,<:CircularArrayBuffer}}, }, + SubArray{<:Any,<:Any,<:ElasticArray,<:Any,<:Any}, + ElasticArray, }, ) = CuArray(x) @@ -37,6 +40,7 @@ device(::CuArray) = Val(:gpu) device(::Array) = Val(:cpu) device(x::Tuple{}) = nothing device(x::NamedTuple{(),Tuple{}}) = nothing +device(x::ElasticArray) = device(x.data) function device(x::Random.AbstractRNG) if x isa CUDA.CURAND.RNG