|
| 1 | +using Cassette |
| 2 | +import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged |
| 3 | + |
| 4 | +""" |
| 5 | +The sparsity pattern. |
| 6 | +
|
| 7 | +- `I`: Input index |
| 8 | +- `J`: Ouput index |
| 9 | +
|
| 10 | +`(i, j)` means the `j`th element of the output depends on |
| 11 | +the `i`th element of the input. Therefore `length(I) == length(J)` |
| 12 | +""" |
| 13 | +struct Sparsity |
| 14 | + m::Int |
| 15 | + n::Int |
| 16 | + I::Vector{Int} # Input |
| 17 | + J::Vector{Int} # Output |
| 18 | +end |
| 19 | + |
| 20 | +using SparseArrays |
| 21 | +SparseArrays.sparse(s::Sparsity) = sparse(s.I, s.J, true, s.m, s.n) |
| 22 | + |
| 23 | +Sparsity(m, n) = Sparsity(m, n, Int[], Int[]) |
| 24 | + |
| 25 | +function Base.push!(S::Sparsity, i::Int, j::Int) |
| 26 | + push!(S.I, i) |
| 27 | + push!(S.J, j) |
| 28 | +end |
| 29 | + |
| 30 | +# Tags: |
| 31 | +struct Input end |
| 32 | +struct Output end |
| 33 | + |
| 34 | +struct ProvinanceSet{T} |
| 35 | + set::T # Set, Array, Int, Tuple, anything! |
| 36 | +end |
| 37 | + |
| 38 | +# note: this is not strictly set union, just some efficient way of concating |
| 39 | +Base.union(p::ProvinanceSet{<:Tuple}, |
| 40 | + q::ProvinanceSet{<:Integer}) = ProvinanceSet((p.set..., q.set,)) |
| 41 | +Base.union(p::ProvinanceSet{<:Integer}, |
| 42 | + q::ProvinanceSet{<:Tuple}) = ProvinanceSet((p.set, q.set...,)) |
| 43 | +Base.union(p::ProvinanceSet{<:Integer}, |
| 44 | + q::ProvinanceSet{<:Integer}) = ProvinanceSet((p.set, q.set,)) |
| 45 | +Base.union(p::ProvinanceSet{<:Tuple}, |
| 46 | + q::ProvinanceSet{<:Tuple}) = ProvinanceSet((p.set..., q.set...,)) |
| 47 | +Base.union(p::ProvinanceSet, |
| 48 | + q::ProvinanceSet) = ProvinanceSet(union(p.set, q.set)) |
| 49 | +Base.union(p::ProvinanceSet, |
| 50 | + q::ProvinanceSet, |
| 51 | + rs::ProvinanceSet...) = union(union(p, q), rs...) |
| 52 | +Base.union(p::ProvinanceSet) = p |
| 53 | + |
| 54 | +function Base.push!(S::Sparsity, i::Int, js::ProvinanceSet) |
| 55 | + for j in js.set |
| 56 | + push!(S, i, j) |
| 57 | + end |
| 58 | +end |
| 59 | + |
| 60 | +Cassette.@context SparsityContext |
| 61 | + |
| 62 | +const TagType = Union{Input, Output, ProvinanceSet} |
| 63 | +Cassette.metadatatype(::Type{<:SparsityContext}, ::DataType) = TagType |
| 64 | +function ismetatype(x, ctx, T) |
| 65 | + hasmetadata(x, ctx) && istagged(x, ctx) && (metadata(x, ctx) isa T) |
| 66 | +end |
| 67 | + |
| 68 | + |
| 69 | +""" |
| 70 | +`sparsity!(f, Y, X, S=Sparsity(length(X), length(Y)))` |
| 71 | +
|
| 72 | +Execute the program that figures out the sparsity pattern of |
| 73 | +the jacobian of the function `f`. |
| 74 | +
|
| 75 | +# Arguments: |
| 76 | +- `f`: the function |
| 77 | +- `Y`: the output array |
| 78 | +- `X`: the input array |
| 79 | +- `S`: (optional) the sparsity pattern |
| 80 | +
|
| 81 | +Returns a `Sparsity` |
| 82 | +""" |
| 83 | +function sparsity!(f!, Y, X, S=Sparsity(length(Y), length(X))) |
| 84 | + |
| 85 | + ctx = SparsityContext(metadata=S) |
| 86 | + ctx = Cassette.enabletagging(ctx, f!) |
| 87 | + ctx = Cassette.disablehooks(ctx) |
| 88 | + |
| 89 | + val = Cassette.overdub(ctx, |
| 90 | + f!, |
| 91 | + tag(Y, ctx, Output()), |
| 92 | + tag(X, ctx, Input())) |
| 93 | + untag(val, ctx), S |
| 94 | +end |
| 95 | + |
| 96 | +# getindex on the input |
| 97 | +function Cassette.overdub(ctx::SparsityContext, |
| 98 | + f::typeof(getindex), |
| 99 | + X::Tagged, |
| 100 | + idx::Int...) |
| 101 | + if ismetatype(X, ctx, Input) |
| 102 | + val = Cassette.fallback(ctx, f, X, idx...) |
| 103 | + i = LinearIndices(untag(X, ctx))[idx...] |
| 104 | + tag(val, ctx, ProvinanceSet(i)) |
| 105 | + else |
| 106 | + Cassette.recurse(ctx, f, X, idx...) |
| 107 | + end |
| 108 | +end |
| 109 | + |
| 110 | +# setindex! on the output |
| 111 | +function Cassette.overdub(ctx::SparsityContext, |
| 112 | + f::typeof(setindex!), |
| 113 | + Y::Tagged, |
| 114 | + val::Tagged, |
| 115 | + idx::Int...) |
| 116 | + S = ctx.metadata |
| 117 | + if ismetatype(Y, ctx, Output) |
| 118 | + set = metadata(val, ctx) |
| 119 | + if set isa ProvinanceSet |
| 120 | + i = LinearIndices(untag(Y, ctx))[idx...] |
| 121 | + push!(S, i, set) |
| 122 | + end |
| 123 | + return Cassette.fallback(ctx, f, Y, val, idx...) |
| 124 | + else |
| 125 | + return Cassette.recurse(ctx, f, Y, val, idx...) |
| 126 | + end |
| 127 | +end |
| 128 | + |
| 129 | +function get_provinance(ctx, arg::Tagged) |
| 130 | + if metadata(arg, ctx) isa ProvinanceSet |
| 131 | + metadata(arg, ctx) |
| 132 | + else |
| 133 | + ProvinanceSet(()) |
| 134 | + end |
| 135 | +end |
| 136 | + |
| 137 | +get_provinance(ctx, arg) = ProvinanceSet(()) |
| 138 | + |
| 139 | +# Any function acting on a value tagged with ProvinanceSet |
| 140 | +function _overdub_union_provinance(ctx::SparsityContext, f, args...) |
| 141 | + idxs = findall(x->ismetatype(x, ctx, ProvinanceSet), args) |
| 142 | + if isempty(idxs) |
| 143 | + Cassette.fallback(ctx, f, args...) |
| 144 | + else |
| 145 | + provinance = union(map(arg->get_provinance(ctx, arg), args[idxs])...) |
| 146 | + val = Cassette.fallback(ctx, f, args...) |
| 147 | + tag(val, ctx, provinance) |
| 148 | + end |
| 149 | +end |
| 150 | + |
| 151 | +function Cassette.overdub(ctx::SparsityContext, |
| 152 | + f, args...) where {A, B, D<:Output} |
| 153 | + if any(x->ismetatype(x, ctx, ProvinanceSet), args) |
| 154 | + _overdub_union_provinance(ctx, f, args...) |
| 155 | + else |
| 156 | + Cassette.recurse(ctx, f, args...) |
| 157 | + end |
| 158 | +end |
| 159 | + |
| 160 | +#= |
| 161 | +# Examples: |
| 162 | +# |
| 163 | +using UnicodePlots |
| 164 | +
|
| 165 | +sspy(s::Sparsity) = spy(sparse(s)) |
| 166 | +
|
| 167 | +julia> sparsity!([0,0,0], [23,53,83]) do Y, X |
| 168 | + Y[:] .= X |
| 169 | + Y == X |
| 170 | + end |
| 171 | +(true, Sparsity([1, 2, 3], [1, 2, 3])) |
| 172 | +
|
| 173 | +julia> sparsity!([0,0,0], [23,53,83]) do Y, X |
| 174 | + for i=1:3 |
| 175 | + for j=i:3 |
| 176 | + Y[j] += X[i] |
| 177 | + end |
| 178 | + end; Y |
| 179 | + end |
| 180 | +([23, 76, 159], Sparsity(3, 3, [1, 2, 3, 2, 3, 3], [1, 1, 1, 2, 2, 3])) |
| 181 | +
|
| 182 | +julia> sspy(ans[2]) |
| 183 | + Sparsity Pattern |
| 184 | + ┌─────┐ |
| 185 | + 1 │⠀⠄⠀⠀⠀│ > 0 |
| 186 | + 3 │⠀⠅⠨⠠⠀│ < 0 |
| 187 | + └─────┘ |
| 188 | + 1 3 |
| 189 | + nz = 6 |
| 190 | +
|
| 191 | +julia> sparsity!(f, zeros(Int, 3,3), [23,53,83]) |
| 192 | +([23, 53, 83], Sparsity(9, 3, [2, 5, 8], [1, 2, 3])) |
| 193 | +
|
| 194 | +julia> sspy(ans[2]) |
| 195 | + Sparsity Pattern |
| 196 | + ┌─────┐ |
| 197 | + 1 │⠀⠄⠀⠀⠀│ > 0 |
| 198 | + │⠀⠀⠠⠀⠀│ < 0 |
| 199 | + 9 │⠀⠀⠀⠐⠀│ |
| 200 | + └─────┘ |
| 201 | + 1 3 |
| 202 | + nz = 3 |
| 203 | +=# |
0 commit comments