From 935e5aa5862b0e881703ceb74b43049318b644ab Mon Sep 17 00:00:00 2001 From: norci Date: Wed, 16 Sep 2020 21:36:57 +0800 Subject: [PATCH 1/2] updated send_to_device, for ElasticArray type. --- src/utils/device.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/utils/device.jl b/src/utils/device.jl index 8888e57..da41da3 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -1,5 +1,6 @@ export device, send_to_host, send_to_device +using ElasticArrays using Flux using CUDA using Adapt @@ -22,6 +23,8 @@ send_to_device( <:Any, <:Base.ReshapedArray{<:Any,<:Any,<:SubArray{<:Any,<:Any,<:CircularArrayBuffer}}, }, + SubArray{<:Any,<:Any,<:ElasticArray,<:Any,<:Any}, + ElasticArray, }, ) = CuArray(x) From 4dd8fe651ec6640e6ea583577e2ee01ae52e6911 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 21 Sep 2020 17:32:28 +0800 Subject: [PATCH 2/2] recoganize the device of ElasticArray --- src/utils/device.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils/device.jl b/src/utils/device.jl index da41da3..a4d727c 100644 --- a/src/utils/device.jl +++ b/src/utils/device.jl @@ -40,6 +40,7 @@ device(::CuArray) = Val(:gpu) device(::Array) = Val(:cpu) device(x::Tuple{}) = nothing device(x::NamedTuple{(),Tuple{}}) = nothing +device(x::ElasticArray) = device(x.data) function device(x::Random.AbstractRNG) if x isa CUDA.CURAND.RNG