Skip to content

Commit 6fdb20f

Browse files
Merge pull request #9 from shashi/s/program-sparsity
sparsity of the jacobian of a program
2 parents 6ddb94e + be2fc7b commit 6fdb20f

File tree

1 file changed

+203
-0
lines changed

1 file changed

+203
-0
lines changed

src/program_sparsity.jl

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
using Cassette
2+
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged
3+
4+
"""
5+
The sparsity pattern.
6+
7+
- `I`: Input index
8+
- `J`: Ouput index
9+
10+
`(i, j)` means the `j`th element of the output depends on
11+
the `i`th element of the input. Therefore `length(I) == length(J)`
12+
"""
13+
struct Sparsity
14+
m::Int
15+
n::Int
16+
I::Vector{Int} # Input
17+
J::Vector{Int} # Output
18+
end
19+
20+
using SparseArrays
21+
SparseArrays.sparse(s::Sparsity) = sparse(s.I, s.J, true, s.m, s.n)
22+
23+
Sparsity(m, n) = Sparsity(m, n, Int[], Int[])
24+
25+
function Base.push!(S::Sparsity, i::Int, j::Int)
26+
push!(S.I, i)
27+
push!(S.J, j)
28+
end
29+
30+
# Tags:
31+
struct Input end
32+
struct Output end
33+
34+
struct ProvinanceSet{T}
35+
set::T # Set, Array, Int, Tuple, anything!
36+
end
37+
38+
# note: this is not strictly set union, just some efficient way of concating
39+
Base.union(p::ProvinanceSet{<:Tuple},
40+
q::ProvinanceSet{<:Integer}) = ProvinanceSet((p.set..., q.set,))
41+
Base.union(p::ProvinanceSet{<:Integer},
42+
q::ProvinanceSet{<:Tuple}) = ProvinanceSet((p.set, q.set...,))
43+
Base.union(p::ProvinanceSet{<:Integer},
44+
q::ProvinanceSet{<:Integer}) = ProvinanceSet((p.set, q.set,))
45+
Base.union(p::ProvinanceSet{<:Tuple},
46+
q::ProvinanceSet{<:Tuple}) = ProvinanceSet((p.set..., q.set...,))
47+
Base.union(p::ProvinanceSet,
48+
q::ProvinanceSet) = ProvinanceSet(union(p.set, q.set))
49+
Base.union(p::ProvinanceSet,
50+
q::ProvinanceSet,
51+
rs::ProvinanceSet...) = union(union(p, q), rs...)
52+
Base.union(p::ProvinanceSet) = p
53+
54+
function Base.push!(S::Sparsity, i::Int, js::ProvinanceSet)
55+
for j in js.set
56+
push!(S, i, j)
57+
end
58+
end
59+
60+
Cassette.@context SparsityContext
61+
62+
const TagType = Union{Input, Output, ProvinanceSet}
63+
Cassette.metadatatype(::Type{<:SparsityContext}, ::DataType) = TagType
64+
function ismetatype(x, ctx, T)
65+
hasmetadata(x, ctx) && istagged(x, ctx) && (metadata(x, ctx) isa T)
66+
end
67+
68+
69+
"""
70+
`sparsity!(f, Y, X, S=Sparsity(length(X), length(Y)))`
71+
72+
Execute the program that figures out the sparsity pattern of
73+
the jacobian of the function `f`.
74+
75+
# Arguments:
76+
- `f`: the function
77+
- `Y`: the output array
78+
- `X`: the input array
79+
- `S`: (optional) the sparsity pattern
80+
81+
Returns a `Sparsity`
82+
"""
83+
function sparsity!(f!, Y, X, S=Sparsity(length(Y), length(X)))
84+
85+
ctx = SparsityContext(metadata=S)
86+
ctx = Cassette.enabletagging(ctx, f!)
87+
ctx = Cassette.disablehooks(ctx)
88+
89+
val = Cassette.overdub(ctx,
90+
f!,
91+
tag(Y, ctx, Output()),
92+
tag(X, ctx, Input()))
93+
untag(val, ctx), S
94+
end
95+
96+
# getindex on the input
97+
function Cassette.overdub(ctx::SparsityContext,
98+
f::typeof(getindex),
99+
X::Tagged,
100+
idx::Int...)
101+
if ismetatype(X, ctx, Input)
102+
val = Cassette.fallback(ctx, f, X, idx...)
103+
i = LinearIndices(untag(X, ctx))[idx...]
104+
tag(val, ctx, ProvinanceSet(i))
105+
else
106+
Cassette.recurse(ctx, f, X, idx...)
107+
end
108+
end
109+
110+
# setindex! on the output
111+
function Cassette.overdub(ctx::SparsityContext,
112+
f::typeof(setindex!),
113+
Y::Tagged,
114+
val::Tagged,
115+
idx::Int...)
116+
S = ctx.metadata
117+
if ismetatype(Y, ctx, Output)
118+
set = metadata(val, ctx)
119+
if set isa ProvinanceSet
120+
i = LinearIndices(untag(Y, ctx))[idx...]
121+
push!(S, i, set)
122+
end
123+
return Cassette.fallback(ctx, f, Y, val, idx...)
124+
else
125+
return Cassette.recurse(ctx, f, Y, val, idx...)
126+
end
127+
end
128+
129+
function get_provinance(ctx, arg::Tagged)
130+
if metadata(arg, ctx) isa ProvinanceSet
131+
metadata(arg, ctx)
132+
else
133+
ProvinanceSet(())
134+
end
135+
end
136+
137+
get_provinance(ctx, arg) = ProvinanceSet(())
138+
139+
# Any function acting on a value tagged with ProvinanceSet
140+
function _overdub_union_provinance(ctx::SparsityContext, f, args...)
141+
idxs = findall(x->ismetatype(x, ctx, ProvinanceSet), args)
142+
if isempty(idxs)
143+
Cassette.fallback(ctx, f, args...)
144+
else
145+
provinance = union(map(arg->get_provinance(ctx, arg), args[idxs])...)
146+
val = Cassette.fallback(ctx, f, args...)
147+
tag(val, ctx, provinance)
148+
end
149+
end
150+
151+
function Cassette.overdub(ctx::SparsityContext,
152+
f, args...) where {A, B, D<:Output}
153+
if any(x->ismetatype(x, ctx, ProvinanceSet), args)
154+
_overdub_union_provinance(ctx, f, args...)
155+
else
156+
Cassette.recurse(ctx, f, args...)
157+
end
158+
end
159+
160+
#=
161+
# Examples:
162+
#
163+
using UnicodePlots
164+
165+
sspy(s::Sparsity) = spy(sparse(s))
166+
167+
julia> sparsity!([0,0,0], [23,53,83]) do Y, X
168+
Y[:] .= X
169+
Y == X
170+
end
171+
(true, Sparsity([1, 2, 3], [1, 2, 3]))
172+
173+
julia> sparsity!([0,0,0], [23,53,83]) do Y, X
174+
for i=1:3
175+
for j=i:3
176+
Y[j] += X[i]
177+
end
178+
end; Y
179+
end
180+
([23, 76, 159], Sparsity(3, 3, [1, 2, 3, 2, 3, 3], [1, 1, 1, 2, 2, 3]))
181+
182+
julia> sspy(ans[2])
183+
Sparsity Pattern
184+
┌─────┐
185+
1 │⠀⠄⠀⠀⠀│ > 0
186+
3 │⠀⠅⠨⠠⠀│ < 0
187+
└─────┘
188+
1 3
189+
nz = 6
190+
191+
julia> sparsity!(f, zeros(Int, 3,3), [23,53,83])
192+
([23, 53, 83], Sparsity(9, 3, [2, 5, 8], [1, 2, 3]))
193+
194+
julia> sspy(ans[2])
195+
Sparsity Pattern
196+
┌─────┐
197+
1 │⠀⠄⠀⠀⠀│ > 0
198+
│⠀⠀⠠⠀⠀│ < 0
199+
9 │⠀⠀⠀⠐⠀│
200+
└─────┘
201+
1 3
202+
nz = 3
203+
=#

0 commit comments

Comments
 (0)