Skip to content

Commit 451b922

Browse files
authored
Use anonymous function to register UDF and avoid name clash (#333)
1 parent 4b22cc4 commit 451b922

File tree

4 files changed

+158
-159
lines changed

4 files changed

+158
-159
lines changed

src/SQLite.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,13 @@ mutable struct DB <: DBInterface.Connection
5858
file::String
5959
handle::DBHandle
6060
stmt_wrappers::WeakKeyDict{StmtWrapper,Nothing} # opened prepared statements
61+
registered_UDFs::Vector{Any} # keep registered UDFs alive and not garbage collected
6162

6263
function DB(f::AbstractString)
6364
handle_ptr = Ref{DBHandle}()
6465
f = String(isempty(f) ? f : expanduser(f))
6566
if @OK C.sqlite3_open(f, handle_ptr)
66-
db = new(f, handle_ptr[], WeakKeyDict{StmtWrapper,Nothing}())
67+
db = new(f, handle_ptr[], WeakKeyDict{StmtWrapper,Nothing}(), Any[])
6768
finalizer(_close_db!, db)
6869
return db
6970
else # error

src/UDF.jl

Lines changed: 130 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -44,29 +44,16 @@ end
4444
sqlreturn(context, val::Bool) = sqlreturn(context, Int(val))
4545
sqlreturn(context, val) = sqlreturn(context, sqlserialize(val))
4646

47-
# Internal method for generating an SQLite scalar function from
48-
# a Julia function name
49-
function scalarfunc(func, fsym = Symbol(string(func)))
50-
# check if name defined in Base so we don't clobber Base methods
51-
nm = isdefined(Base, fsym) ? :(Base.$fsym) : fsym
52-
return quote
53-
#nm needs to be a symbol or expr, i.e. :sin or :(Base.sin)
54-
function $(nm)(
55-
context::Ptr{Cvoid},
56-
nargs::Cint,
57-
values::Ptr{Ptr{Cvoid}},
58-
)
59-
args = [sqlvalue(values, i) for i in 1:nargs]
60-
ret = $(func)(args...)
61-
sqlreturn(context, ret)
62-
nothing
63-
end
64-
return $(nm)
65-
end
66-
end
67-
function scalarfunc(expr::Expr)
68-
f = eval(expr)
69-
return scalarfunc(f)
47+
function wrap_scalarfunc(
48+
func,
49+
context::Ptr{Cvoid},
50+
nargs::Cint,
51+
values::Ptr{Ptr{Cvoid}},
52+
)
53+
args = [sqlvalue(values, i) for i in 1:nargs]
54+
ret = func(args...)
55+
sqlreturn(context, ret)
56+
nothing
7057
end
7158

7259
# convert a byteptr to an int, assumes little-endian
@@ -82,135 +69,116 @@ function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int)
8269
return htol(s)
8370
end
8471

85-
function stepfunc(init, func, fsym = Symbol(string(func) * "_step"))
86-
nm = isdefined(Base, fsym) ? :(Base.$fsym) : fsym
87-
return quote
88-
function $(nm)(
89-
context::Ptr{Cvoid},
90-
nargs::Cint,
91-
values::Ptr{Ptr{Cvoid}},
92-
)
93-
args = [sqlvalue(values, i) for i in 1:nargs]
94-
95-
intsize = sizeof(Int)
96-
ptrsize = sizeof(Ptr)
97-
acsize = intsize + ptrsize
98-
acptr = convert(
99-
Ptr{UInt8},
100-
C.sqlite3_aggregate_context(context, acsize),
101-
)
102-
103-
# acptr will be zeroed-out if this is the first iteration
104-
ret = ccall(
105-
:memcmp,
106-
Cint,
107-
(Ptr{UInt8}, Ptr{UInt8}, Cuint),
108-
zeros(UInt8, acsize),
109-
acptr,
110-
acsize,
111-
)
112-
if ret == 0
113-
acval = $(init)
114-
valsize = 256
115-
# avoid the garbage collector using malloc
116-
valptr = convert(Ptr{UInt8}, Libc.malloc(valsize))
117-
valptr == C_NULL && throw(SQLiteException("memory error"))
118-
else
119-
# size of serialized value is first sizeof(Int) bytes
120-
valsize = bytestoint(acptr, 1, intsize)
121-
# ptr to serialized value is last sizeof(Ptr) bytes
122-
valptr = reinterpret(
123-
Ptr{UInt8},
124-
bytestoint(acptr, intsize + 1, ptrsize),
125-
)
126-
# deserialize the value pointed to by valptr
127-
acvalbuf = zeros(UInt8, valsize)
128-
unsafe_copyto!(pointer(acvalbuf), valptr, valsize)
129-
acval = sqldeserialize(acvalbuf)
130-
end
131-
132-
local funcret
133-
try
134-
funcret = sqlserialize($(func)(acval, args...))
135-
catch
136-
Libc.free(valptr)
137-
rethrow()
138-
end
139-
140-
newsize = sizeof(funcret)
141-
if newsize > valsize
142-
# TODO: increase this in a cleverer way?
143-
tmp = convert(Ptr{UInt8}, Libc.realloc(valptr, newsize))
144-
if tmp == C_NULL
145-
Libc.free(valptr)
146-
throw(SQLiteException("memory error"))
147-
else
148-
valptr = tmp
149-
end
150-
end
151-
# copy serialized return value
152-
unsafe_copyto!(valptr, pointer(funcret), newsize)
153-
154-
# copy the size of the serialized value
155-
unsafe_copyto!(
156-
acptr,
157-
pointer(reinterpret(UInt8, [newsize])),
158-
intsize,
159-
)
160-
# copy the address of the pointer to the serialized value
161-
valarr = reinterpret(UInt8, [valptr])
162-
for i in 1:length(valarr)
163-
unsafe_store!(acptr, valarr[i], intsize + i)
164-
end
165-
nothing
72+
function wrap_stepfunc(
73+
init,
74+
func,
75+
context::Ptr{Cvoid},
76+
nargs::Cint,
77+
values::Ptr{Ptr{Cvoid}},
78+
)
79+
args = [sqlvalue(values, i) for i in 1:nargs]
80+
81+
intsize = sizeof(Int)
82+
ptrsize = sizeof(Ptr)
83+
acsize = intsize + ptrsize
84+
acptr = convert(Ptr{UInt8}, C.sqlite3_aggregate_context(context, acsize))
85+
86+
# acptr will be zeroed-out if this is the first iteration
87+
ret = ccall(
88+
:memcmp,
89+
Cint,
90+
(Ptr{UInt8}, Ptr{UInt8}, Cuint),
91+
zeros(UInt8, acsize),
92+
acptr,
93+
acsize,
94+
)
95+
if ret == 0
96+
acval = init
97+
valsize = 256
98+
# avoid the garbage collector using malloc
99+
valptr = convert(Ptr{UInt8}, Libc.malloc(valsize))
100+
valptr == C_NULL && throw(SQLiteException("memory error"))
101+
else
102+
# size of serialized value is first sizeof(Int) bytes
103+
valsize = bytestoint(acptr, 1, intsize)
104+
# ptr to serialized value is last sizeof(Ptr) bytes
105+
valptr =
106+
reinterpret(Ptr{UInt8}, bytestoint(acptr, intsize + 1, ptrsize))
107+
# deserialize the value pointed to by valptr
108+
acvalbuf = zeros(UInt8, valsize)
109+
unsafe_copyto!(pointer(acvalbuf), valptr, valsize)
110+
acval = sqldeserialize(acvalbuf)
111+
end
112+
113+
local funcret
114+
try
115+
funcret = sqlserialize(func(acval, args...))
116+
catch
117+
Libc.free(valptr)
118+
rethrow()
119+
end
120+
121+
newsize = sizeof(funcret)
122+
if newsize > valsize
123+
# TODO: increase this in a cleverer way?
124+
tmp = convert(Ptr{UInt8}, Libc.realloc(valptr, newsize))
125+
if tmp == C_NULL
126+
Libc.free(valptr)
127+
throw(SQLiteException("memory error"))
128+
else
129+
valptr = tmp
166130
end
167-
return $(nm)
168131
end
132+
# copy serialized return value
133+
unsafe_copyto!(valptr, pointer(funcret), newsize)
134+
135+
# copy the size of the serialized value
136+
unsafe_copyto!(acptr, pointer(reinterpret(UInt8, [newsize])), intsize)
137+
# copy the address of the pointer to the serialized value
138+
valarr = reinterpret(UInt8, [valptr])
139+
for i in 1:length(valarr)
140+
unsafe_store!(acptr, valarr[i], intsize + i)
141+
end
142+
nothing
169143
end
170144

171-
function finalfunc(init, func, fsym = Symbol(string(func) * "_final"))
172-
nm = isdefined(Base, fsym) ? :(Base.$fsym) : fsym
173-
return quote
174-
function $(nm)(
175-
context::Ptr{Cvoid},
176-
nargs::Cint,
177-
values::Ptr{Ptr{Cvoid}},
178-
)
179-
acptr = convert(Ptr{UInt8}, C.sqlite3_aggregate_context(context, 0))
180-
181-
# step function wasn't run
182-
if acptr == C_NULL
183-
sqlreturn(context, $(init))
184-
else
185-
intsize = sizeof(Int)
186-
ptrsize = sizeof(Ptr)
187-
acsize = intsize + ptrsize
188-
189-
# load size
190-
valsize = bytestoint(acptr, 1, intsize)
191-
# load ptr
192-
valptr = reinterpret(
193-
Ptr{UInt8},
194-
bytestoint(acptr, intsize + 1, ptrsize),
195-
)
196-
197-
# load value
198-
acvalbuf = zeros(UInt8, valsize)
199-
unsafe_copyto!(pointer(acvalbuf), valptr, valsize)
200-
acval = sqldeserialize(acvalbuf)
201-
202-
local ret
203-
try
204-
ret = $(func)(acval)
205-
finally
206-
Libc.free(valptr)
207-
end
208-
sqlreturn(context, ret)
209-
end
210-
nothing
145+
function wrap_finalfunc(
146+
init,
147+
func,
148+
context::Ptr{Cvoid},
149+
nargs::Cint,
150+
values::Ptr{Ptr{Cvoid}},
151+
)
152+
acptr = convert(Ptr{UInt8}, C.sqlite3_aggregate_context(context, 0))
153+
154+
# step function wasn't run
155+
if acptr == C_NULL
156+
sqlreturn(context, init)
157+
else
158+
intsize = sizeof(Int)
159+
ptrsize = sizeof(Ptr)
160+
acsize = intsize + ptrsize
161+
162+
# load size
163+
valsize = bytestoint(acptr, 1, intsize)
164+
# load ptr
165+
valptr =
166+
reinterpret(Ptr{UInt8}, bytestoint(acptr, intsize + 1, ptrsize))
167+
168+
# load value
169+
acvalbuf = zeros(UInt8, valsize)
170+
unsafe_copyto!(pointer(acvalbuf), valptr, valsize)
171+
acval = sqldeserialize(acvalbuf)
172+
173+
local ret
174+
try
175+
ret = func(acval)
176+
finally
177+
Libc.free(valptr)
211178
end
212-
return $(nm)
179+
sqlreturn(context, ret)
213180
end
181+
nothing
214182
end
215183

216184
"""
@@ -223,6 +191,8 @@ macro register(db, func)
223191
:(register($(esc(db)), $(esc(func))))
224192
end
225193

194+
UDF_keep_alive_list = []
195+
226196
"""
227197
SQLite.register(db, func)
228198
SQLite.register(db, init, step_func, final_func; nargs=-1, name=string(step), isdeterm=true)
@@ -242,9 +212,12 @@ function register(
242212
nargs < -1 && (nargs = -1)
243213
@assert sizeof(name) <= 255 "size of function name must be <= 255"
244214

245-
f = eval(scalarfunc(func, Symbol(name)))
246-
215+
f =
216+
(context, nargs, values) ->
217+
wrap_scalarfunc(func, context, nargs, values)
247218
cfunc = @cfunction($f, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
219+
push!(db.registered_UDFs, cfunc)
220+
248221
# TODO: allow the other encodings
249222
enc = C.SQLITE_UTF8
250223
enc = isdeterm ? enc | C.SQLITE_DETERMINISTIC : enc
@@ -263,12 +236,11 @@ function register(
263236
end
264237

265238
# as above but for aggregate functions
266-
newidentity() = @eval x -> x
267239
function register(
268240
db,
269241
init,
270242
step::Function,
271-
final::Function = newidentity();
243+
final::Function = identity;
272244
nargs::Int = -1,
273245
name::AbstractString = string(step),
274246
isdeterm::Bool = true,
@@ -277,10 +249,16 @@ function register(
277249
nargs < -1 && (nargs = -1)
278250
@assert sizeof(name) <= 255 "size of function name must be <= 255 chars"
279251

280-
s = eval(stepfunc(init, step, Base.nameof(step)))
252+
s =
253+
(context, nargs, values) ->
254+
wrap_stepfunc(init, step, context, nargs, values)
281255
cs = @cfunction($s, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
282-
f = eval(finalfunc(init, final, Base.nameof(final)))
256+
f =
257+
(context, nargs, values) ->
258+
wrap_finalfunc(init, final, context, nargs, values)
283259
cf = @cfunction($f, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
260+
push!(db.registered_UDFs, cs)
261+
push!(db.registered_UDFs, cf)
284262

285263
enc = C.SQLITE_UTF8
286264
enc = isdeterm ? enc | C.SQLITE_DETERMINISTIC : enc

src/tables.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ end
7777
)
7878
end
7979

80-
function getvalue(q::Query{strict}, col::Int, rownumber::Int, ::Type{T}) where {strict, T}
80+
function getvalue(
81+
q::Query{strict},
82+
col::Int,
83+
rownumber::Int,
84+
::Type{T},
85+
) where {strict,T}
8186
rownumber == q.current_rownumber[] || wrongrow(rownumber)
8287
handle = _get_stmt_handle(q.stmt)
8388
t = C.sqlite3_column_type(handle, col - 1)
@@ -298,7 +303,7 @@ function load!(
298303
st = nothing;
299304
temp::Bool = false,
300305
ifnotexists::Bool = false,
301-
on_conflict::Union{String, Nothing} = nothing,
306+
on_conflict::Union{String,Nothing} = nothing,
302307
replace::Bool = false,
303308
analyze::Bool = false,
304309
)
@@ -313,7 +318,9 @@ function load!(
313318
# build insert statement
314319
columns = join(esc_id.(string.(sch.names)), ",")
315320
params = chop(repeat("?,", length(sch.names)))
316-
kind = isnothing(on_conflict) ? (replace ? "REPLACE" : "INSERT") : "INSERT OR $on_conflict"
321+
kind =
322+
isnothing(on_conflict) ? (replace ? "REPLACE" : "INSERT") :
323+
"INSERT OR $on_conflict"
317324
stmt = Stmt(
318325
db,
319326
"$kind INTO $(esc_id(string(name))) ($columns) VALUES ($params)";

0 commit comments

Comments
 (0)