@@ -10,7 +10,7 @@ import SQLite3
10
10
#endif
11
11
12
12
extension Connection {
13
- typealias Aggregate = @convention ( block) ( Int , OpaquePointer ? , Int32 , UnsafeMutablePointer < OpaquePointer ? > ? ) -> Void
13
+ private typealias Aggregate = @convention ( block) ( Int , OpaquePointer ? , Int32 , UnsafeMutablePointer < OpaquePointer ? > ? ) -> Void
14
14
15
15
/// Creates or redefines a custom SQL aggregate.
16
16
///
@@ -41,7 +41,7 @@ extension Connection {
41
41
/// each aggregation group. The block should return an
42
42
/// UnsafeMutablePointer to the fresh state variable.
43
43
public func createAggregation< T> (
44
- _ aggregate : String ,
44
+ _ functionName : String ,
45
45
argumentCount: UInt ? = nil ,
46
46
deterministic: Bool = false ,
47
47
step: @escaping ( [ Binding ? ] , UnsafeMutablePointer < T > ) -> Void ,
@@ -50,11 +50,14 @@ extension Connection {
50
50
51
51
let argc = argumentCount. map { Int ( $0) } ?? - 1
52
52
let box : Aggregate = { ( stepFlag: Int , context: OpaquePointer ? , argc: Int32 , argv: UnsafeMutablePointer < OpaquePointer ? > ? ) in
53
- let ptr = sqlite3_aggregate_context ( context, 64 ) ! // needs to be at least as large as uintptr_t; better way to do this?
54
- let mutablePointer = ptr. assumingMemoryBound ( to: UnsafeMutableRawPointer . self)
53
+ guard let aggregateContext = sqlite3_aggregate_context ( context, Int32 ( MemoryLayout< UnsafeMutablePointer< Int64>>. size) ) else {
54
+ fatalError ( " Could not get aggregate context " )
55
+ }
56
+
57
+ let mutablePointer = aggregateContext. assumingMemoryBound ( to: UnsafeMutableRawPointer . self)
55
58
if stepFlag > 0 {
56
59
let arguments = getArguments ( argc: argc, argv: argv)
57
- if ptr . assumingMemoryBound ( to: Int64 . self) . pointee == 0 {
60
+ if aggregateContext . assumingMemoryBound ( to: Int64 . self) . pointee == 0 {
58
61
let value = state ( )
59
62
mutablePointer. pointee = UnsafeMutableRawPointer ( mutating: value)
60
63
}
@@ -65,29 +68,89 @@ extension Connection {
65
68
}
66
69
}
67
70
68
- var flags = SQLITE_UTF8
69
- if deterministic {
70
- flags |= SQLITE_DETERMINISTIC
71
+ func xStep( context: OpaquePointer ? , argc: Int32 , value: UnsafeMutablePointer < OpaquePointer ? > ? ) {
72
+ unsafeBitCast ( sqlite3_user_data ( context) , to: Aggregate . self) ( 1 , context, argc, value)
73
+ }
74
+
75
+ func xFinal( context: OpaquePointer ? ) {
76
+ unsafeBitCast ( sqlite3_user_data ( context) , to: Aggregate . self) ( 0 , context, 0 , nil )
71
77
}
72
78
79
+ let flags = SQLITE_UTF8 | ( deterministic ? SQLITE_DETERMINISTIC : 0 )
73
80
sqlite3_create_function_v2 (
74
- handle,
75
- aggregate,
76
- Int32 ( argc) ,
77
- flags,
78
- unsafeBitCast ( box, to: UnsafeMutableRawPointer . self) ,
79
- nil , { context, argc, value in
80
- let function = unsafeBitCast ( sqlite3_user_data ( context) , to: Aggregate . self)
81
- function ( 1 , context, argc, value)
82
- } , { context in
83
- let function = unsafeBitCast ( sqlite3_user_data ( context) , to: Aggregate . self)
84
- function ( 0 , context, 0 , nil )
85
- } ,
86
- nil
81
+ handle,
82
+ functionName,
83
+ Int32 ( argc) ,
84
+ flags,
85
+ /* pApp */ unsafeBitCast ( box, to: UnsafeMutableRawPointer . self) ,
86
+ /* xFunc */ nil , xStep, xFinal, /* xDestroy */ nil
87
87
)
88
- if aggregations [ aggregate ] == nil {
89
- aggregations [ aggregate ] = [ : ]
88
+ if functions [ functionName ] == nil {
89
+ functions [ functionName ] = [ : ]
90
90
}
91
- aggregations [ aggregate ] ? [ argc] = box
91
+ functions [ functionName ] ? [ argc] = box
92
92
}
93
+
94
+ func createAggregation< T: AnyObject > (
95
+ _ aggregate: String ,
96
+ argumentCount: UInt ? = nil ,
97
+ deterministic: Bool = false ,
98
+ initialValue: T ,
99
+ reduce: @escaping ( T , [ Binding ? ] ) -> T ,
100
+ result: @escaping ( T ) -> Binding ?
101
+ ) {
102
+ let step : ( [ Binding ? ] , UnsafeMutablePointer < UnsafeMutableRawPointer > ) -> Void = { ( bindings, ptr) in
103
+ let pointer = ptr. pointee. assumingMemoryBound ( to: T . self)
104
+ let current = Unmanaged < T > . fromOpaque ( pointer) . takeRetainedValue ( )
105
+ let next = reduce ( current, bindings)
106
+ ptr. pointee = Unmanaged . passRetained ( next) . toOpaque ( )
107
+ }
108
+
109
+ let final : ( UnsafeMutablePointer < UnsafeMutableRawPointer > ) -> Binding ? = { ( ptr) in
110
+ let pointer = ptr. pointee. assumingMemoryBound ( to: T . self)
111
+ let obj = Unmanaged < T > . fromOpaque ( pointer) . takeRetainedValue ( )
112
+ let value = result ( obj)
113
+ ptr. deallocate ( )
114
+ return value
115
+ }
116
+
117
+ let state : ( ) -> UnsafeMutablePointer < UnsafeMutableRawPointer > = {
118
+ let pointer = UnsafeMutablePointer< UnsafeMutableRawPointer> . allocate( capacity: 1 )
119
+ pointer. pointee = Unmanaged . passRetained ( initialValue) . toOpaque ( )
120
+ return pointer
121
+ }
122
+
123
+ createAggregation ( aggregate, step: step, final: final, state: state)
124
+ }
125
+
126
+ func createAggregation< T> (
127
+ _ aggregate: String ,
128
+ argumentCount: UInt ? = nil ,
129
+ deterministic: Bool = false ,
130
+ initialValue: T ,
131
+ reduce: @escaping ( T , [ Binding ? ] ) -> T ,
132
+ result: @escaping ( T ) -> Binding ?
133
+ ) {
134
+
135
+ let step : ( [ Binding ? ] , UnsafeMutablePointer < T > ) -> Void = { ( bindings, pointer) in
136
+ let current = pointer. pointee
137
+ let next = reduce ( current, bindings)
138
+ pointer. pointee = next
139
+ }
140
+
141
+ let final : ( UnsafeMutablePointer < T > ) -> Binding ? = { pointer in
142
+ let value = result ( pointer. pointee)
143
+ pointer. deallocate ( )
144
+ return value
145
+ }
146
+
147
+ let state : ( ) -> UnsafeMutablePointer < T > = {
148
+ let pointer = UnsafeMutablePointer< T> . allocate( capacity: 1 )
149
+ pointer. initialize ( to: initialValue)
150
+ return pointer
151
+ }
152
+
153
+ createAggregation ( aggregate, step: step, final: final, state: state)
154
+ }
155
+
93
156
}
0 commit comments