diff --git a/src/tables.jl b/src/tables.jl index 0bd56cb..8e2ae27 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -8,6 +8,7 @@ struct Query names::Vector{Symbol} types::Vector{Type} lookup::Dict{Symbol, Int} + current_rownumber::Base.RefValue{Int} end # check if the query has no (more) rows @@ -15,6 +16,7 @@ Base.isempty(q::Query) = q.status[] == SQLITE_DONE struct Row <: Tables.AbstractRow q::Query + rownumber::Int end getquery(r::Row) = getfield(r, :q) @@ -56,7 +58,10 @@ function done(q::Query) return false end -function getvalue(q::Query, col::Int, ::Type{T}) where {T} +@noinline wrongrow(i) = throw(ArgumentError("row $i is no longer valid; sqlite query results are forward-only iterators where each row is only valid when iterated; re-execute the query, convert rows to NamedTuples, or stream the results to a sink to save results")) + +function getvalue(q::Query, col::Int, rownumber::Int, ::Type{T}) where {T} + rownumber == q.current_rownumber[] || wrongrow(rownumber) handle = _stmt(q.stmt).handle t = sqlite3_column_type(handle, col) if t == SQLITE_NULL @@ -67,7 +72,7 @@ function getvalue(q::Query, col::Int, ::Type{T}) where {T} end end -Tables.getcolumn(r::Row, ::Type{T}, i::Int, nm::Symbol) where {T} = getvalue(getquery(r), i, T) +Tables.getcolumn(r::Row, ::Type{T}, i::Int, nm::Symbol) where {T} = getvalue(getquery(r), i, getfield(r, :rownumber), T) Tables.getcolumn(r::Row, i::Int) = Tables.getcolumn(r, getquery(r).types[i], i, getquery(r).names[i]) Tables.getcolumn(r::Row, nm::Symbol) = Tables.getcolumn(r, getquery(r).lookup[nm]) @@ -75,13 +80,15 @@ Tables.columnnames(r::Row) = Tables.columnnames(getquery(r)) function Base.iterate(q::Query) done(q) && return nothing - return Row(q), nothing + q.current_rownumber[] = 1 + return Row(q, 1), 2 end -function Base.iterate(q::Query, ::Nothing) +function Base.iterate(q::Query, rownumber) q.status[] = sqlite3_step(_stmt(q.stmt).handle) done(q) && return nothing - return Row(q), nothing + q.current_rownumber[] = rownumber + return Row(q, rownumber), rownumber + 1 end "Return the last row insert id from the executed statement" @@ -129,7 +136,7 @@ function DBInterface.execute(stmt::Stmt, params::DBInterface.StatementParams; al header[i] = nm types[i] = Union{juliatype(_st.handle, i), Missing} end - return Query(stmt, Ref(status), header, types, Dict(x=>i for (i, x) in enumerate(header))) + return Query(stmt, Ref(status), header, types, Dict(x=>i for (i, x) in enumerate(header)), Ref(0)) end """ diff --git a/test/runtests.jl b/test/runtests.jl index 05d9aed..72c753b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -540,4 +540,11 @@ tbl = DBInterface.execute(db, "select * from tbl") |> columntable c = [3, 3, 3] ) +# https://github.com/JuliaDatabases/SQLite.jl/issues/251 +q = DBInterface.execute(db, "select * from tbl") +row, st = iterate(q) +@test row.a == 1 && row.b == 2 && row.c == 3 +row2, st = iterate(q, st) +@test_throws ArgumentError row.a + end \ No newline at end of file