@@ -37,6 +37,8 @@ TArray{T,N}(::UndefInitializer, d::Vararg{<:Integer,N}) where {T,N} = TArray{T,N
3737TArray {T,N} (dim:: NTuple{N,Int} ) where {T,N} = TArray (T, dim)
3838TArray (T:: Type , dim) = TArray (Array {T} (undef, dim))
3939
40+ localize (x) = x
41+ localize (x:: AbstractArray ) = TArray (x)
4042getdata (x:: TArray ) = x. data
4143tape_copy (x:: TArray ) = TArray (deepcopy (x. data))
4244
@@ -166,70 +168,70 @@ end
166168# Other methods from stdlib
167169
168170Base. view (x:: TArray , inds... ; kwargs... ) =
169- Base. view (getdata (x), inds... ; kwargs... ) |> TArray
170- Base.:- (x:: TArray ) = (- getdata (x)) |> TArray
171- Base. transpose (x:: TArray ) = transpose (getdata (x)) |> TArray
172- Base. adjoint (x:: TArray ) = adjoint (getdata (x)) |> TArray
173- Base. repeat (x:: TArray ; kw... ) = repeat (getdata (x); kw... ) |> TArray
171+ Base. view (getdata (x), inds... ; kwargs... ) |> localize
172+ Base.:- (x:: TArray ) = (- getdata (x)) |> localize
173+ Base. transpose (x:: TArray ) = transpose (getdata (x)) |> localize
174+ Base. adjoint (x:: TArray ) = adjoint (getdata (x)) |> localize
175+ Base. repeat (x:: TArray ; kw... ) = repeat (getdata (x); kw... ) |> localize
174176
175177Base. hcat (xs:: Union{TArray{T,1}, TArray{T,2}} ...) where T =
176- hcat (getdata .(xs)... ) |> TArray
178+ hcat (getdata .(xs)... ) |> localize
177179Base. vcat (xs:: Union{TArray{T,1}, TArray{T,2}} ...) where T =
178- vcat (getdata .(xs)... ) |> TArray
180+ vcat (getdata .(xs)... ) |> localize
179181Base. cat (xs:: Union{TArray{T,1}, TArray{T,2}} ...; dims) where T =
180- cat (getdata .(xs)... ; dims = dims) |> TArray
182+ cat (getdata .(xs)... ; dims = dims) |> localize
181183
182184
183- Base. reshape (x:: TArray , dims:: Union{Colon,Int} ...) = reshape (getdata (x), dims) |> TArray
185+ Base. reshape (x:: TArray , dims:: Union{Colon,Int} ...) = reshape (getdata (x), dims) |> localize
184186Base. reshape (x:: TArray , dims:: Tuple{Vararg{Union{Int,Colon}}} ) =
185- reshape (getdata (x), Base. _reshape_uncolon (getdata (x), dims)) |> TArray
186- Base. reshape (x:: TArray , dims:: Tuple{Vararg{Int}} ) = reshape (getdata (x), dims) |> TArray
187-
188- Base. permutedims (x:: TArray , perm) = permutedims (getdata (x), perm) |> TArray
189- Base. PermutedDimsArray (x:: TArray , perm) = PermutedDimsArray (getdata (x), perm) |> TArray
190- Base. reverse (x:: TArray ; dims) = reverse (getdata (x), dims = dims) |> TArray
191-
192- Base. sum (x:: TArray ; dims = :) = sum (getdata (x), dims = dims) |> TArray
193- Base. sum (f:: Union{Function,Type} ,x:: TArray ) = sum (f .(getdata (x))) |> TArray
194- Base. prod (x:: TArray ; dims= :) = prod (getdata (x); dims= dims) |> TArray
195- Base. prod (f:: Union{Function, Type} , x:: TArray ) = prod (f .(getdata (x))) |> TArray
196-
197- Base. findfirst (x:: TArray , args... ) = findfirst (getdata (x), args... ) |> TArray
198- Base. maximum (x:: TArray ; dims = :) = maximum (getdata (x), dims = dims) |> TArray
199- Base. minimum (x:: TArray ; dims = :) = minimum (getdata (x), dims = dims) |> TArray
200-
201- Base.:/ (x:: TArray , y:: TArray ) = getdata (x) / getdata (y) |> TArray
202- Base.:/ (x:: AbstractArray , y:: TArray ) = x / getdata (y) |> TArray
203- Base.:/ (x:: TArray , y:: AbstractArray ) = getdata (x) / y |> TArray
204- Base.:\ (x:: TArray , y:: TArray ) = getdata (x) \ getdata (y) |> TArray
205- Base.:\ (x:: AbstractArray , y:: TArray ) = x \ getdata (y) |> TArray
206- Base.:\ (x:: TArray , y:: AbstractArray ) = getdata (x) \ y |> TArray
207- Base.:* (x:: TArray , y:: TArray ) = getdata (x) * getdata (y) |> TArray
208- Base.:* (x:: AbstractArray , y:: TArray ) = x * getdata (y) |> TArray
209- Base.:* (x:: TArray , y:: AbstractArray ) = getdata (x) * y |> TArray
187+ reshape (getdata (x), Base. _reshape_uncolon (getdata (x), dims)) |> localize
188+ Base. reshape (x:: TArray , dims:: Tuple{Vararg{Int}} ) = reshape (getdata (x), dims) |> localize
189+
190+ Base. permutedims (x:: TArray , perm) = permutedims (getdata (x), perm) |> localize
191+ Base. PermutedDimsArray (x:: TArray , perm) = PermutedDimsArray (getdata (x), perm) |> localize
192+ Base. reverse (x:: TArray ; dims) = reverse (getdata (x), dims = dims) |> localize
193+
194+ Base. sum (x:: TArray ; dims = :) = sum (getdata (x), dims = dims) |> localize
195+ Base. sum (f:: Union{Function,Type} ,x:: TArray ) = sum (f .(getdata (x))) |> localize
196+ Base. prod (x:: TArray ; dims= :) = prod (getdata (x); dims= dims) |> localize
197+ Base. prod (f:: Union{Function, Type} , x:: TArray ) = prod (f .(getdata (x))) |> localize
198+
199+ Base. findfirst (x:: TArray , args... ) = findfirst (getdata (x), args... ) |> localize
200+ Base. maximum (x:: TArray ; dims = :) = maximum (getdata (x), dims = dims) |> localize
201+ Base. minimum (x:: TArray ; dims = :) = minimum (getdata (x), dims = dims) |> localize
202+
203+ Base.:/ (x:: TArray , y:: TArray ) = getdata (x) / getdata (y) |> localize
204+ Base.:/ (x:: AbstractArray , y:: TArray ) = x / getdata (y) |> localize
205+ Base.:/ (x:: TArray , y:: AbstractArray ) = getdata (x) / y |> localize
206+ Base.:\ (x:: TArray , y:: TArray ) = getdata (x) \ getdata (y) |> localize
207+ Base.:\ (x:: AbstractArray , y:: TArray ) = x \ getdata (y) |> localize
208+ Base.:\ (x:: TArray , y:: AbstractArray ) = getdata (x) \ y |> localize
209+ Base.:* (x:: TArray , y:: TArray ) = getdata (x) * getdata (y) |> localize
210+ Base.:* (x:: AbstractArray , y:: TArray ) = x * getdata (y) |> localize
211+ Base.:* (x:: TArray , y:: AbstractArray ) = getdata (x) * y |> localize
210212
211213# broadcast
212214Base. BroadcastStyle (:: Type{<:TArray} ) = Broadcast. ArrayStyle {TArray} ()
213- Broadcast. broadcasted (:: Broadcast.ArrayStyle{TArray} , f, args... ) = f .(getdata .(args)... ) |> TArray
215+ Broadcast. broadcasted (:: Broadcast.ArrayStyle{TArray} , f, args... ) = f .(getdata .(args)... ) |> localize
214216
215217import LinearAlgebra
216218import LinearAlgebra: \ , / , inv, det, logdet, logabsdet, norm
217219
218- LinearAlgebra. inv (x:: TArray ) = inv (getdata (x)) |> TArray
219- LinearAlgebra. det (x:: TArray ) = det (getdata (x)) |> TArray
220- LinearAlgebra. logdet (x:: TArray ) = logdet (getdata (x)) |> TArray
221- LinearAlgebra. logabsdet (x:: TArray ) = logabsdet (getdata (x)) |> TArray
220+ LinearAlgebra. inv (x:: TArray ) = inv (getdata (x)) |> localize
221+ LinearAlgebra. det (x:: TArray ) = det (getdata (x)) |> localize
222+ LinearAlgebra. logdet (x:: TArray ) = logdet (getdata (x)) |> localize
223+ LinearAlgebra. logabsdet (x:: TArray ) = logabsdet (getdata (x)) |> localize
222224LinearAlgebra. norm (x:: TArray , p:: Real = 2 ) =
223- LinearAlgebra. norm (getdata (x), p) |> TArray
225+ LinearAlgebra. norm (getdata (x), p) |> localize
224226
225227import LinearAlgebra: dot
226- dot (x:: TArray , ys:: TArray ) = dot (getdata (x), getdata (ys)) |> TArray
227- dot (x:: AbstractArray , ys:: TArray ) = dot (x, getdata (ys)) |> TArray
228- dot (x:: TArray , ys:: AbstractArray ) = dot (getdata (x), ys) |> TArray
228+ dot (x:: TArray , ys:: TArray ) = dot (getdata (x), getdata (ys)) |> localize
229+ dot (x:: AbstractArray , ys:: TArray ) = dot (x, getdata (ys)) |> localize
230+ dot (x:: TArray , ys:: AbstractArray ) = dot (getdata (x), ys) |> localize
229231
230232using Statistics
231- Statistics. mean (x:: TArray ; dims = :) = mean (getdata (x), dims = dims) |> TArray
232- Statistics. std (x:: TArray ; kw... ) = std (getdata (x), kw... ) |> TArray
233+ Statistics. mean (x:: TArray ; dims = :) = mean (getdata (x), dims = dims) |> localize
234+ Statistics. std (x:: TArray ; kw... ) = std (getdata (x), kw... ) |> localize
233235
234236# TODO
235237# * NNlib
0 commit comments