Skip to content

Commit 8ec2226

Browse files
author
Marc Rasi
committed
Make Values.TangentVector == VectorValues
1 parent c8ebec6 commit 8ec2226

File tree

7 files changed

+112
-57
lines changed

7 files changed

+112
-57
lines changed

Examples/Pose2SLAMG2O/main.swift

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct G2OFactorGraph: G2OReader {
3636
var graph: NonlinearFactorGraph = NonlinearFactorGraph()
3737

3838
public mutating func addInitialGuess(index: Int, pose: Pose2) {
39-
initialGuess.insert(index, AnyDifferentiable(pose))
39+
initialGuess.insert(index, pose)
4040
}
4141

4242
public mutating func addMeasurement(frameIndex: Int, measuredIndex: Int, pose: Pose2) {
@@ -72,11 +72,7 @@ func main() {
7272
dx.insert(i, Vector(zeros: 3))
7373
}
7474
optimizer.optimize(gfg: gfg, initial: &dx)
75-
for i in 0..<val.count {
76-
var p = val[i].baseAs(Pose2.self)
77-
p.move(along: Vector3(dx[i]))
78-
val[i] = AnyDifferentiable(p)
79-
}
75+
val.move(along: dx)
8076
print("Current error: \(problem.graph.error(val))")
8177
}
8278
print("Final error: \(problem.graph.error(val))")

Sources/SwiftFusion/Inference/BetweenFactor.swift

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public struct BetweenFactor<T: LieGroup>: NonlinearFactor where T.TangentVector:
6767
/// Returns the `error` of the factor.
6868
@differentiable(wrt: values)
6969
public func error(_ values: Values) -> Double {
70-
let actual = values[key1].baseAs(T.self).inverse() * values[key2].baseAs(T.self)
70+
let actual = values[key1, as: T.self].inverse() * values[key2, as: T.self]
7171
let error = difference.local(actual)
7272
// TODO: It would be faster to call `error.squaredNorm` because then we don't have to pay
7373
// the cost of a conversion to `Vector`. To do this, we need a protocol
@@ -78,19 +78,13 @@ public struct BetweenFactor<T: LieGroup>: NonlinearFactor where T.TangentVector:
7878
@differentiable(wrt: values)
7979
public func errorVector(_ values: Values) -> T.Coordinate.LocalCoordinate {
8080
let error = difference.local(
81-
values[key1].baseAs(T.self).inverse() * values[key2].baseAs(T.self)
81+
values[key1, as: T.self].inverse() * values[key2, as: T.self]
8282
)
8383

8484
return error
8585
}
8686

8787
public func linearize(_ values: Values) -> JacobianFactor {
88-
let j = jacobian(of: self.errorVector, at: values)
89-
90-
let j1 = Matrix(stacking: (0..<j.count).map { i in (j[i]._values[values._indices[key1]!].base as! T.TangentVector).vector } )
91-
let j2 = Matrix(stacking: (0..<j.count).map { i in (j[i]._values[values._indices[key2]!].base as! T.TangentVector).vector } )
92-
93-
// TODO: remove this negative sign
94-
return JacobianFactor(keys, [j1, j2], errorVector(values).vector.scaled(by: -1))
88+
return JacobianFactor(of: self.errorVector, at: values)
9589
}
9690
}

Sources/SwiftFusion/Inference/JacobianFactor.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,40 @@ public struct JacobianFactor: LinearFactor {
104104
return result
105105
}
106106
}
107+
108+
extension JacobianFactor {
109+
/// Creates a `JacobianFactor` by linearizing the error function `f` at `p`.
110+
public init<R: VectorConvertible & TangentStandardBasis>(
111+
of f: @differentiable (Values) -> R,
112+
at p: Values
113+
) {
114+
// Compute the rows of the jacobian.
115+
let (value, pb) = valueWithPullback(at: p, in: f)
116+
let rows = R.tangentStandardBasis.map { pb($0) }
117+
118+
// Construct empty matrices with the correct shape.
119+
assert(rows.count > 0)
120+
var matrices = Dictionary<Int, Matrix>(uniqueKeysWithValues: rows[0].keys.map { key in
121+
let row = rows[0][key]
122+
var matrix = Matrix([], rowCount: 0, columnCount: row.dimension)
123+
matrix.reserveCapacity(rows.count * row.dimension)
124+
return (key, matrix)
125+
})
126+
127+
// Fill in the matrix entries.
128+
for row in rows {
129+
for key in row.keys {
130+
matrices[key]!.append(row: row[key])
131+
}
132+
}
133+
134+
// Return the jacobian factor with the matrices and value.
135+
let orderedKeys = Array(matrices.keys)
136+
self = JacobianFactor(
137+
orderedKeys,
138+
orderedKeys.map { matrices[$0]! },
139+
// TODO: remove this negative sign
140+
value.vector.scaled(by: -1)
141+
)
142+
}
143+
}

Sources/SwiftFusion/Inference/PriorFactor.swift

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public struct PriorFactor<T: LieGroup>: NonlinearFactor where T.TangentVector: V
4141
/// Returns the `error` of the factor.
4242
@differentiable(wrt: values)
4343
public func error(_ values: Values) -> Double {
44-
let error = difference.local(values[keys[0]].baseAs(T.self))
44+
let error = difference.local(values[keys[0], as: T.self])
4545
// TODO: It would be faster to call `error.squaredNorm` because then we don't have to pay
4646
// the cost of a conversion to `Vector`. To do this, we need a protocol
4747
// with a `squaredNorm` requirement.
@@ -50,14 +50,12 @@ public struct PriorFactor<T: LieGroup>: NonlinearFactor where T.TangentVector: V
5050

5151
@differentiable(wrt: values)
5252
public func errorVector(_ values: Values) -> T.Coordinate.LocalCoordinate {
53-
let val = values[keys[0]].baseAs(T.self)
53+
let val = values[keys[0], as: T.self]
5454
let error = difference.local(val)
5555
return error
5656
}
57-
57+
5858
public func linearize(_ values: Values) -> JacobianFactor {
59-
let j = jacobian(of: self.errorVector, at: values)
60-
let j1 = Matrix(stacking: (0..<j.count).map { i in (j[i]._values[values._indices[keys[0]]!].base as! T.TangentVector).vector } )
61-
return JacobianFactor(keys, [j1], errorVector(values).vector.scaled(by: -1))
59+
return JacobianFactor(of: self.errorVector, at: values)
6260
}
6361
}

Sources/SwiftFusion/Inference/Values.swift

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,64 @@ public struct Values: Differentiable & KeyPathIterable {
3434
public var count: Int {
3535
return _values.count
3636
}
37+
38+
/// MARK: - Differentiable conformance and related properties and helpers.
39+
40+
/// The product space of the tangent spaces of all the values.
41+
public typealias TangentVector = VectorValues
42+
43+
/// `makeTangentVector[i]` produces a type-erased tangent vector for `values[i]`.
44+
private var makeTangentVector: [(Vector) -> AnyDerivative] = []
45+
46+
public mutating func move(along direction: VectorValues) {
47+
for key in direction.keys {
48+
let index = self._indices[key]!
49+
self._values[index].move(along: makeTangentVector[index](direction[key]))
50+
}
51+
}
3752

38-
/// The subscript operator, with some indirection
39-
/// Should be replaced after Dictionary is in
53+
/// MARK: - Value manipulation methods.
54+
55+
/// Access the value at `key`, with type `type`.
56+
///
57+
/// Precondition: The value actually has type `type`.
4058
@differentiable
41-
public subscript(key: Int) -> AnyDifferentiable {
59+
public subscript<T: Differentiable>(key: Int, as type: T.Type) -> T
60+
where T.TangentVector: VectorConvertible
61+
{
4262
get {
43-
_values[_indices[key]!]
63+
return _values[_indices[key]!].baseAs(type)
4464
}
45-
set(newVal) {
46-
_values[_indices[key]!] = newVal
65+
set(newValue) {
66+
_values[_indices[key]!] = AnyDifferentiable(newValue)
4767
}
4868
}
49-
69+
70+
@derivative(of: subscript)
71+
@usableFromInline
72+
func vjpSubscript<T: Differentiable>(key: Int, as type: T.Type)
73+
-> (value: T, pullback: (T.TangentVector) -> VectorValues)
74+
where T.TangentVector: VectorConvertible
75+
{
76+
return (
77+
self._values[self._indices[key]!].baseAs(type),
78+
{ (t: T.TangentVector) in
79+
var vectorValues = VectorValues()
80+
vectorValues.insert(key, t.vector)
81+
return vectorValues
82+
}
83+
)
84+
}
85+
5086
/// Insert a key value pair
51-
public mutating func insert(_ key: Int, _ val: AnyDifferentiable) {
87+
public mutating func insert<T: Differentiable>(_ key: Int, _ val: T)
88+
where T.TangentVector: VectorConvertible
89+
{
5290
assert(_indices[key] == nil)
5391

5492
self._indices[key] = self._values.count
55-
self._values.append(val)
93+
self._values.append(AnyDifferentiable(val))
94+
self.makeTangentVector.append({ AnyDerivative(T.TangentVector($0)) })
5695
}
5796

5897
}

Tests/SwiftFusionTests/Geometry/Pose3Tests.swift

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ final class Pose3Tests: XCTestCase {
4545
let prior_factor = PriorFactor(0, t1)
4646

4747
var vals = Values()
48-
vals.insert(0, AnyDifferentiable(t1)) // should be identity matrix
48+
vals.insert(0, t1) // should be identity matrix
4949
// Change this to t2, still zero in upper left block
5050

5151
let actual = prior_factor.linearize(vals).jacobians[0]
@@ -74,7 +74,7 @@ final class Pose3Tests: XCTestCase {
7474
let gti = Vector3(radius * cos(theta), radius * sin(theta), 0)
7575
let oRi = Rot3.fromTangent(Vector3(0, 0, -theta)) // negative yaw goes counterclockwise, with Z down !
7676
let gTi = Pose3(gRo * oRi, gti)
77-
values.insert(key, AnyDifferentiable(gTi))
77+
values.insert(key, gTi)
7878
theta = theta + dtheta
7979
}
8080
return values
@@ -83,8 +83,8 @@ final class Pose3Tests: XCTestCase {
8383
func testGtsamPose3SLAMExample() {
8484
// Create a hexagon of poses
8585
let hexagon = circlePose3(numPoses: 6, radius: 1.0)
86-
let p0 = hexagon[0].baseAs(Pose3.self)
87-
let p1 = hexagon[1].baseAs(Pose3.self)
86+
let p0 = hexagon[0, as: Pose3.self]
87+
let p1 = hexagon[1, as: Pose3.self]
8888

8989
// create a Pose graph with one equality constraint and one measurement
9090
var fg = NonlinearFactorGraph()
@@ -101,12 +101,12 @@ final class Pose3Tests: XCTestCase {
101101
// Create initial config
102102
var val = Values()
103103
let s = 0.10
104-
val.insert(0, AnyDifferentiable(p0))
105-
val.insert(1, AnyDifferentiable(hexagon[1].baseAs(Pose3.self).global(Vector6(s * Tensor<Double>(randomNormal: [6])))))
106-
val.insert(2, AnyDifferentiable(hexagon[2].baseAs(Pose3.self).global(Vector6(s * Tensor<Double>(randomNormal: [6])))))
107-
val.insert(3, AnyDifferentiable(hexagon[3].baseAs(Pose3.self).global(Vector6(s * Tensor<Double>(randomNormal: [6])))))
108-
val.insert(4, AnyDifferentiable(hexagon[4].baseAs(Pose3.self).global(Vector6(s * Tensor<Double>(randomNormal: [6])))))
109-
val.insert(5, AnyDifferentiable(hexagon[5].baseAs(Pose3.self).global(Vector6(s * Tensor<Double>(randomNormal: [6])))))
104+
val.insert(0, p0)
105+
val.insert(1, hexagon[1, as: Pose3.self].global(Vector6(s * Tensor<Double>(randomNormal: [6]))))
106+
val.insert(2, hexagon[2, as: Pose3.self].global(Vector6(s * Tensor<Double>(randomNormal: [6]))))
107+
val.insert(3, hexagon[3, as: Pose3.self].global(Vector6(s * Tensor<Double>(randomNormal: [6]))))
108+
val.insert(4, hexagon[4, as: Pose3.self].global(Vector6(s * Tensor<Double>(randomNormal: [6]))))
109+
val.insert(5, hexagon[5, as: Pose3.self].global(Vector6(s * Tensor<Double>(randomNormal: [6]))))
110110

111111
// optimize
112112
for _ in 0..<16 {
@@ -122,15 +122,10 @@ final class Pose3Tests: XCTestCase {
122122

123123
optimizer.optimize(gfg: gfg, initial: &dx)
124124

125-
126-
for i in 0..<6 {
127-
var p = val[i].baseAs(Pose3.self)
128-
p.move(along: Vector6(dx[i]))
129-
val[i] = AnyDifferentiable(p)
130-
}
125+
val.move(along: dx)
131126
}
132127

133-
let pose_1 = val[1].baseAs(Pose3.self)
128+
let pose_1 = val[1, as: Pose3.self]
134129
assertAllKeyPathEqual(pose_1, p1, accuracy: 1e-2)
135130
}
136131
}

Tests/SwiftFusionTests/Inference/NonlinearFactorGraphTests.swift

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ final class NonlinearFactorGraphTests: XCTestCase {
1212
fg += bf1
1313

1414
var val = Values()
15-
val.insert(0, AnyDifferentiable(Pose2(1.0, 1.0, 0.0)))
16-
val.insert(1, AnyDifferentiable(Pose2(1.0, 1.0, .pi)))
15+
val.insert(0, Pose2(1.0, 1.0, 0.0))
16+
val.insert(1, Pose2(1.0, 1.0, .pi))
1717

1818
let gfg = fg.linearize(val)
1919

@@ -51,7 +51,7 @@ final class NonlinearFactorGraphTests: XCTestCase {
5151
var val = Values()
5252

5353
for i in 0..<5 {
54-
val.insert(i, AnyDifferentiable(map[i]))
54+
val.insert(i, map[i])
5555
}
5656

5757
for _ in 0..<3 {
@@ -67,11 +67,7 @@ final class NonlinearFactorGraphTests: XCTestCase {
6767

6868
optimizer.optimize(gfg: gfg, initial: &dx)
6969

70-
for i in 0..<5 {
71-
var p = val[i].baseAs(Pose2.self)
72-
p.move(along: Vector3(dx[i]))
73-
val[i] = AnyDifferentiable(p)
74-
}
70+
val.move(along: dx)
7571
}
7672

7773
let dumpjson = { (p: Pose2) -> String in
@@ -84,14 +80,14 @@ final class NonlinearFactorGraphTests: XCTestCase {
8480
}
8581
print("]")
8682

87-
let map_final = (0..<5).map { val[$0].baseAs(Pose2.self) }
83+
let map_final = (0..<5).map { val[$0, as: Pose2.self] }
8884
print("map = [")
8985
for v in map_final.indices {
9086
print("\(dumpjson(map_final[v]))\({ () -> String in if v == map_final.indices.endIndex - 1 { return "" } else { return "," } }())")
9187
}
9288
print("]")
9389

94-
let p5T1 = between(val[4].baseAs(Pose2.self), val[0].baseAs(Pose2.self))
90+
let p5T1 = between(val[4, as: Pose2.self], val[0, as: Pose2.self])
9591

9692
// Test condition: P_5 should be identical to P_1 (close loop)
9793
XCTAssertEqual(p5T1.t.norm, 0.0, accuracy: 1e-2)

0 commit comments

Comments
 (0)