Skip to content

Commit af672b5

Browse files
author
KDr2
committed
fix a TArray bug
1 parent dd211e9 commit af672b5

File tree

1 file changed

+47
-45
lines changed

1 file changed

+47
-45
lines changed

src/tarray.jl

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ TArray{T,N}(::UndefInitializer, d::Vararg{<:Integer,N}) where {T,N} = TArray{T,N
3737
TArray{T,N}(dim::NTuple{N,Int}) where {T,N} = TArray(T, dim)
3838
TArray(T::Type, dim) = TArray(Array{T}(undef, dim))
3939

40+
localize(x) = x
41+
localize(x::AbstractArray) = TArray(x)
4042
getdata(x::TArray) = x.data
4143
tape_copy(x::TArray) = TArray(deepcopy(x.data))
4244

@@ -166,70 +168,70 @@ end
166168
# Other methods from stdlib
167169

168170
Base.view(x::TArray, inds...; kwargs...) =
169-
Base.view(getdata(x), inds...; kwargs...) |> TArray
170-
Base.:-(x::TArray) = (-getdata(x)) |> TArray
171-
Base.transpose(x::TArray) = transpose(getdata(x)) |> TArray
172-
Base.adjoint(x::TArray) = adjoint(getdata(x)) |> TArray
173-
Base.repeat(x::TArray; kw...) = repeat(getdata(x); kw...) |> TArray
171+
Base.view(getdata(x), inds...; kwargs...) |> localize
172+
Base.:-(x::TArray) = (-getdata(x)) |> localize
173+
Base.transpose(x::TArray) = transpose(getdata(x)) |> localize
174+
Base.adjoint(x::TArray) = adjoint(getdata(x)) |> localize
175+
Base.repeat(x::TArray; kw...) = repeat(getdata(x); kw...) |> localize
174176

175177
Base.hcat(xs::Union{TArray{T,1}, TArray{T,2}}...) where T =
176-
hcat(getdata.(xs)...) |> TArray
178+
hcat(getdata.(xs)...) |> localize
177179
Base.vcat(xs::Union{TArray{T,1}, TArray{T,2}}...) where T =
178-
vcat(getdata.(xs)...) |> TArray
180+
vcat(getdata.(xs)...) |> localize
179181
Base.cat(xs::Union{TArray{T,1}, TArray{T,2}}...; dims) where T =
180-
cat(getdata.(xs)...; dims = dims) |> TArray
182+
cat(getdata.(xs)...; dims = dims) |> localize
181183

182184

183-
Base.reshape(x::TArray, dims::Union{Colon,Int}...) = reshape(getdata(x), dims) |> TArray
185+
Base.reshape(x::TArray, dims::Union{Colon,Int}...) = reshape(getdata(x), dims) |> localize
184186
Base.reshape(x::TArray, dims::Tuple{Vararg{Union{Int,Colon}}}) =
185-
reshape(getdata(x), Base._reshape_uncolon(getdata(x), dims)) |> TArray
186-
Base.reshape(x::TArray, dims::Tuple{Vararg{Int}}) = reshape(getdata(x), dims) |> TArray
187-
188-
Base.permutedims(x::TArray, perm) = permutedims(getdata(x), perm) |> TArray
189-
Base.PermutedDimsArray(x::TArray, perm) = PermutedDimsArray(getdata(x), perm) |> TArray
190-
Base.reverse(x::TArray; dims) = reverse(getdata(x), dims = dims) |> TArray
191-
192-
Base.sum(x::TArray; dims = :) = sum(getdata(x), dims = dims) |> TArray
193-
Base.sum(f::Union{Function,Type},x::TArray) = sum(f.(getdata(x))) |> TArray
194-
Base.prod(x::TArray; dims=:) = prod(getdata(x); dims=dims) |> TArray
195-
Base.prod(f::Union{Function, Type}, x::TArray) = prod(f.(getdata(x))) |> TArray
196-
197-
Base.findfirst(x::TArray, args...) = findfirst(getdata(x), args...) |> TArray
198-
Base.maximum(x::TArray; dims = :) = maximum(getdata(x), dims = dims) |> TArray
199-
Base.minimum(x::TArray; dims = :) = minimum(getdata(x), dims = dims) |> TArray
200-
201-
Base.:/(x::TArray, y::TArray) = getdata(x) / getdata(y) |> TArray
202-
Base.:/(x::AbstractArray, y::TArray) = x / getdata(y) |> TArray
203-
Base.:/(x::TArray, y::AbstractArray) = getdata(x) / y |> TArray
204-
Base.:\(x::TArray, y::TArray) = getdata(x) \ getdata(y) |> TArray
205-
Base.:\(x::AbstractArray, y::TArray) = x \ getdata(y) |> TArray
206-
Base.:\(x::TArray, y::AbstractArray) = getdata(x) \ y |> TArray
207-
Base.:*(x::TArray, y::TArray) = getdata(x) * getdata(y) |> TArray
208-
Base.:*(x::AbstractArray, y::TArray) = x * getdata(y) |> TArray
209-
Base.:*(x::TArray, y::AbstractArray) = getdata(x) * y |> TArray
187+
reshape(getdata(x), Base._reshape_uncolon(getdata(x), dims)) |> localize
188+
Base.reshape(x::TArray, dims::Tuple{Vararg{Int}}) = reshape(getdata(x), dims) |> localize
189+
190+
Base.permutedims(x::TArray, perm) = permutedims(getdata(x), perm) |> localize
191+
Base.PermutedDimsArray(x::TArray, perm) = PermutedDimsArray(getdata(x), perm) |> localize
192+
Base.reverse(x::TArray; dims) = reverse(getdata(x), dims = dims) |> localize
193+
194+
Base.sum(x::TArray; dims = :) = sum(getdata(x), dims = dims) |> localize
195+
Base.sum(f::Union{Function,Type},x::TArray) = sum(f.(getdata(x))) |> localize
196+
Base.prod(x::TArray; dims=:) = prod(getdata(x); dims=dims) |> localize
197+
Base.prod(f::Union{Function, Type}, x::TArray) = prod(f.(getdata(x))) |> localize
198+
199+
Base.findfirst(x::TArray, args...) = findfirst(getdata(x), args...) |> localize
200+
Base.maximum(x::TArray; dims = :) = maximum(getdata(x), dims = dims) |> localize
201+
Base.minimum(x::TArray; dims = :) = minimum(getdata(x), dims = dims) |> localize
202+
203+
Base.:/(x::TArray, y::TArray) = getdata(x) / getdata(y) |> localize
204+
Base.:/(x::AbstractArray, y::TArray) = x / getdata(y) |> localize
205+
Base.:/(x::TArray, y::AbstractArray) = getdata(x) / y |> localize
206+
Base.:\(x::TArray, y::TArray) = getdata(x) \ getdata(y) |> localize
207+
Base.:\(x::AbstractArray, y::TArray) = x \ getdata(y) |> localize
208+
Base.:\(x::TArray, y::AbstractArray) = getdata(x) \ y |> localize
209+
Base.:*(x::TArray, y::TArray) = getdata(x) * getdata(y) |> localize
210+
Base.:*(x::AbstractArray, y::TArray) = x * getdata(y) |> localize
211+
Base.:*(x::TArray, y::AbstractArray) = getdata(x) * y |> localize
210212

211213
# broadcast
212214
Base.BroadcastStyle(::Type{<:TArray}) = Broadcast.ArrayStyle{TArray}()
213-
Broadcast.broadcasted(::Broadcast.ArrayStyle{TArray}, f, args...) = f.(getdata.(args)...) |> TArray
215+
Broadcast.broadcasted(::Broadcast.ArrayStyle{TArray}, f, args...) = f.(getdata.(args)...) |> localize
214216

215217
import LinearAlgebra
216218
import LinearAlgebra: \, /, inv, det, logdet, logabsdet, norm
217219

218-
LinearAlgebra.inv(x::TArray) = inv(getdata(x)) |> TArray
219-
LinearAlgebra.det(x::TArray) = det(getdata(x)) |> TArray
220-
LinearAlgebra.logdet(x::TArray) = logdet(getdata(x)) |> TArray
221-
LinearAlgebra.logabsdet(x::TArray) = logabsdet(getdata(x)) |> TArray
220+
LinearAlgebra.inv(x::TArray) = inv(getdata(x)) |> localize
221+
LinearAlgebra.det(x::TArray) = det(getdata(x)) |> localize
222+
LinearAlgebra.logdet(x::TArray) = logdet(getdata(x)) |> localize
223+
LinearAlgebra.logabsdet(x::TArray) = logabsdet(getdata(x)) |> localize
222224
LinearAlgebra.norm(x::TArray, p::Real = 2) =
223-
LinearAlgebra.norm(getdata(x), p) |> TArray
225+
LinearAlgebra.norm(getdata(x), p) |> localize
224226

225227
import LinearAlgebra: dot
226-
dot(x::TArray, ys::TArray) = dot(getdata(x), getdata(ys)) |> TArray
227-
dot(x::AbstractArray, ys::TArray) = dot(x, getdata(ys)) |> TArray
228-
dot(x::TArray, ys::AbstractArray) = dot(getdata(x), ys) |> TArray
228+
dot(x::TArray, ys::TArray) = dot(getdata(x), getdata(ys)) |> localize
229+
dot(x::AbstractArray, ys::TArray) = dot(x, getdata(ys)) |> localize
230+
dot(x::TArray, ys::AbstractArray) = dot(getdata(x), ys) |> localize
229231

230232
using Statistics
231-
Statistics.mean(x::TArray; dims = :) = mean(getdata(x), dims = dims) |> TArray
232-
Statistics.std(x::TArray; kw...) = std(getdata(x), kw...) |> TArray
233+
Statistics.mean(x::TArray; dims = :) = mean(getdata(x), dims = dims) |> localize
234+
Statistics.std(x::TArray; kw...) = std(getdata(x), kw...) |> localize
233235

234236
# TODO
235237
# * NNlib

0 commit comments

Comments
 (0)