Skip to content

Commit 894ae76

Browse files
committed
WIP
1 parent 19fd8a5 commit 894ae76

File tree

3 files changed

+100
-103
lines changed

3 files changed

+100
-103
lines changed

Sources/SQLite/Core/Connection+Aggregation.swift

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import SQLite3
1010
#endif
1111

1212
extension Connection {
13-
typealias Aggregate = @convention(block) (Int, OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
13+
private typealias Aggregate = @convention(block) (Int, OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
1414

1515
/// Creates or redefines a custom SQL aggregate.
1616
///
@@ -41,7 +41,7 @@ extension Connection {
4141
/// each aggregation group. The block should return an
4242
/// UnsafeMutablePointer to the fresh state variable.
4343
public func createAggregation<T>(
44-
_ aggregate: String,
44+
_ functionName: String,
4545
argumentCount: UInt? = nil,
4646
deterministic: Bool = false,
4747
step: @escaping ([Binding?], UnsafeMutablePointer<T>) -> Void,
@@ -50,11 +50,14 @@ extension Connection {
5050

5151
let argc = argumentCount.map { Int($0) } ?? -1
5252
let box: Aggregate = { (stepFlag: Int, context: OpaquePointer?, argc: Int32, argv: UnsafeMutablePointer<OpaquePointer?>?) in
53-
let ptr = sqlite3_aggregate_context(context, 64)! // needs to be at least as large as uintptr_t; better way to do this?
54-
let mutablePointer = ptr.assumingMemoryBound(to: UnsafeMutableRawPointer.self)
53+
guard let aggregateContext = sqlite3_aggregate_context(context, Int32(MemoryLayout<UnsafeMutablePointer<Int64>>.size)) else {
54+
fatalError("Could not get aggregate context")
55+
}
56+
57+
let mutablePointer = aggregateContext.assumingMemoryBound(to: UnsafeMutableRawPointer.self)
5558
if stepFlag > 0 {
5659
let arguments = getArguments(argc: argc, argv: argv)
57-
if ptr.assumingMemoryBound(to: Int64.self).pointee == 0 {
60+
if aggregateContext.assumingMemoryBound(to: Int64.self).pointee == 0 {
5861
let value = state()
5962
mutablePointer.pointee = UnsafeMutableRawPointer(mutating: value)
6063
}
@@ -65,29 +68,89 @@ extension Connection {
6568
}
6669
}
6770

68-
var flags = SQLITE_UTF8
69-
if deterministic {
70-
flags |= SQLITE_DETERMINISTIC
71+
func xStep(context: OpaquePointer?, argc: Int32, value: UnsafeMutablePointer<OpaquePointer?>?) {
72+
unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)(1, context, argc, value)
73+
}
74+
75+
func xFinal(context: OpaquePointer?) {
76+
unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)(0, context, 0, nil)
7177
}
7278

79+
let flags = SQLITE_UTF8 | (deterministic ? SQLITE_DETERMINISTIC : 0)
7380
sqlite3_create_function_v2(
74-
handle,
75-
aggregate,
76-
Int32(argc),
77-
flags,
78-
unsafeBitCast(box, to: UnsafeMutableRawPointer.self),
79-
nil, { context, argc, value in
80-
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
81-
function(1, context, argc, value)
82-
}, { context in
83-
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
84-
function(0, context, 0, nil)
85-
},
86-
nil
81+
handle,
82+
functionName,
83+
Int32(argc),
84+
flags,
85+
/* pApp */ unsafeBitCast(box, to: UnsafeMutableRawPointer.self),
86+
/* xFunc */ nil, xStep, xFinal, /* xDestroy */ nil
8787
)
88-
if aggregations[aggregate] == nil {
89-
aggregations[aggregate] = [:]
88+
if functions[functionName] == nil {
89+
functions[functionName] = [:]
9090
}
91-
aggregations[aggregate]?[argc] = box
91+
functions[functionName]?[argc] = box
9292
}
93+
94+
func createAggregation<T: AnyObject>(
95+
_ aggregate: String,
96+
argumentCount: UInt? = nil,
97+
deterministic: Bool = false,
98+
initialValue: T,
99+
reduce: @escaping (T, [Binding?]) -> T,
100+
result: @escaping (T) -> Binding?
101+
) {
102+
let step: ([Binding?], UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Void = { (bindings, ptr) in
103+
let pointer = ptr.pointee.assumingMemoryBound(to: T.self)
104+
let current = Unmanaged<T>.fromOpaque(pointer).takeRetainedValue()
105+
let next = reduce(current, bindings)
106+
ptr.pointee = Unmanaged.passRetained(next).toOpaque()
107+
}
108+
109+
let final: (UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Binding? = { (ptr) in
110+
let pointer = ptr.pointee.assumingMemoryBound(to: T.self)
111+
let obj = Unmanaged<T>.fromOpaque(pointer).takeRetainedValue()
112+
let value = result(obj)
113+
ptr.deallocate()
114+
return value
115+
}
116+
117+
let state: () -> UnsafeMutablePointer<UnsafeMutableRawPointer> = {
118+
let pointer = UnsafeMutablePointer<UnsafeMutableRawPointer>.allocate(capacity: 1)
119+
pointer.pointee = Unmanaged.passRetained(initialValue).toOpaque()
120+
return pointer
121+
}
122+
123+
createAggregation(aggregate, step: step, final: final, state: state)
124+
}
125+
126+
func createAggregation<T>(
127+
_ aggregate: String,
128+
argumentCount: UInt? = nil,
129+
deterministic: Bool = false,
130+
initialValue: T,
131+
reduce: @escaping (T, [Binding?]) -> T,
132+
result: @escaping (T) -> Binding?
133+
) {
134+
135+
let step: ([Binding?], UnsafeMutablePointer<T>) -> Void = { (bindings, pointer) in
136+
let current = pointer.pointee
137+
let next = reduce(current, bindings)
138+
pointer.pointee = next
139+
}
140+
141+
let final: (UnsafeMutablePointer<T>) -> Binding? = { pointer in
142+
let value = result(pointer.pointee)
143+
pointer.deallocate()
144+
return value
145+
}
146+
147+
let state: () -> UnsafeMutablePointer<T> = {
148+
let pointer = UnsafeMutablePointer<T>.allocate(capacity: 1)
149+
pointer.initialize(to: initialValue)
150+
return pointer
151+
}
152+
153+
createAggregation(aggregate, step: step, final: final, state: state)
154+
}
155+
93156
}

Sources/SQLite/Core/Connection.swift

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -585,38 +585,37 @@ public final class Connection {
585585
/// - block: A block of code to run when the function is called. The block
586586
/// is called with an array of raw SQL values mapped to the function’s
587587
/// parameters and should return a raw SQL value (or nil).
588-
public func createFunction(_ function: String, argumentCount: UInt? = nil, deterministic: Bool = false,
588+
public func createFunction(_ functionName: String,
589+
argumentCount: UInt? = nil,
590+
deterministic: Bool = false,
589591
_ block: @escaping (_ args: [Binding?]) -> Binding?) {
590592
let argc = argumentCount.map { Int($0) } ?? -1
591593
let box: Function = { context, argc, argv in
592594
set(result: block(getArguments(argc: argc, argv: argv)), on: context)
593595
}
594-
var flags = SQLITE_UTF8
595-
if deterministic {
596-
flags |= SQLITE_DETERMINISTIC
596+
func xFunc(context: OpaquePointer?, argc: Int32, value: UnsafeMutablePointer<OpaquePointer?>?) {
597+
unsafeBitCast(sqlite3_user_data(context), to: Function.self)(context, argc, value)
597598
}
599+
let flags = SQLITE_UTF8 | (deterministic ? SQLITE_DETERMINISTIC : 0)
598600
let resultCode = sqlite3_create_function_v2(handle,
599-
function,
601+
functionName,
600602
Int32(argc),
601603
flags,
602-
unsafeBitCast(box, to: UnsafeMutableRawPointer.self), { context, argc, value in
603-
let function = unsafeBitCast(sqlite3_user_data(context), to: Function.self)
604-
function(context, argc, value)
605-
}, nil, nil, nil)
604+
/* pApp */ unsafeBitCast(box, to: UnsafeMutableRawPointer.self),
605+
xFunc, /*xStep*/ nil, /*xFinal*/ nil, /*xDestroy*/ nil)
606606

607607
if let result = Result(errorCode: resultCode, connection: self, statement: nil) {
608608
fatalError("Error creating function: \(result)")
609609
}
610610

611-
if functions[function] == nil {
612-
functions[function] = [:] // fails on Linux, https://github.com/stephencelis/SQLite.swift/issues/1071
611+
if functions[functionName] == nil {
612+
functions[functionName] = [:] // fails on Linux, https://github.com/stephencelis/SQLite.swift/issues/1071
613613
}
614-
functions[function]?[argc] = box
614+
functions[functionName]?[argc] = box
615615
}
616616

617617
fileprivate typealias Function = @convention(block) (OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
618-
fileprivate var functions = [String: [Int: Function]]()
619-
var aggregations = [String: [Int: Aggregate]]()
618+
var functions = [String: [Int: Any]]()
620619

621620
/// Defines a new collating sequence.
622621
///

Sources/SQLite/Typed/CustomFunctions.swift

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -159,69 +159,4 @@ public extension Connection {
159159
}
160160
}
161161

162-
// MARK: -
163-
164-
func createAggregation<T: AnyObject>(
165-
_ aggregate: String,
166-
argumentCount: UInt? = nil,
167-
deterministic: Bool = false,
168-
initialValue: T,
169-
reduce: @escaping (T, [Binding?]) -> T,
170-
result: @escaping (T) -> Binding?
171-
) {
172-
173-
let step: ([Binding?], UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Void = { (bindings, ptr) in
174-
let pointer = ptr.pointee.assumingMemoryBound(to: T.self)
175-
let current = Unmanaged<T>.fromOpaque(pointer).takeRetainedValue()
176-
let next = reduce(current, bindings)
177-
ptr.pointee = Unmanaged.passRetained(next).toOpaque()
178-
}
179-
180-
let final: (UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Binding? = { (ptr) in
181-
let pointer = ptr.pointee.assumingMemoryBound(to: T.self)
182-
let obj = Unmanaged<T>.fromOpaque(pointer).takeRetainedValue()
183-
let value = result(obj)
184-
ptr.deallocate()
185-
return value
186-
}
187-
188-
let state: () -> UnsafeMutablePointer<UnsafeMutableRawPointer> = {
189-
let pointer = UnsafeMutablePointer<UnsafeMutableRawPointer>.allocate(capacity: 1)
190-
pointer.pointee = Unmanaged.passRetained(initialValue).toOpaque()
191-
return pointer
192-
}
193-
194-
createAggregation(aggregate, step: step, final: final, state: state)
195-
}
196-
197-
func createAggregation<T>(
198-
_ aggregate: String,
199-
argumentCount: UInt? = nil,
200-
deterministic: Bool = false,
201-
initialValue: T,
202-
reduce: @escaping (T, [Binding?]) -> T,
203-
result: @escaping (T) -> Binding?
204-
) {
205-
206-
let step: ([Binding?], UnsafeMutablePointer<T>) -> Void = { (bindings, pointer) in
207-
let current = pointer.pointee
208-
let next = reduce(current, bindings)
209-
pointer.pointee = next
210-
}
211-
212-
let final: (UnsafeMutablePointer<T>) -> Binding? = { pointer in
213-
let value = result(pointer.pointee)
214-
pointer.deallocate()
215-
return value
216-
}
217-
218-
let state: () -> UnsafeMutablePointer<T> = {
219-
let pointer = UnsafeMutablePointer<T>.allocate(capacity: 1)
220-
pointer.initialize(to: initialValue)
221-
return pointer
222-
}
223-
224-
createAggregation(aggregate, step: step, final: final, state: state)
225-
}
226-
227162
}

0 commit comments

Comments
 (0)