4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Implement `insert_many` and `insert_many_tx`. [PR #22](https://github.com/riverqueue/river/pull/22).

## [0.2.0] - 2024-07-04

### Changed
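The changelog entry above also names `insert_many_tx`, but the example files added below only exercise `insert_many`. For completeness, here is a minimal usage sketch of the transactional variant, assuming `insert_many_tx` mirrors the existing `insert_tx` calling convention of taking the open transaction as its first argument; that signature and the `engine.begin()` wiring are assumptions, not taken from this diff.

# Hypothetical usage sketch: insert_many_tx's exact signature is assumed,
# not shown in this diff.
from dataclasses import dataclass
import json
import riverqueue
import sqlalchemy

from examples.helpers import dev_database_url
from riverqueue.driver import riversqlalchemy


@dataclass
class CountArgs:
    count: int

    kind: str = "sort"

    def to_json(self) -> str:
        return json.dumps({"count": self.count})


def example():
    engine = sqlalchemy.create_engine(dev_database_url())
    client = riverqueue.Client(riversqlalchemy.Driver(engine))

    with engine.begin() as tx:
        # Jobs inserted here become visible only when the surrounding
        # transaction commits.
        num_inserted = client.insert_many_tx(
            tx,
            [
                CountArgs(count=1),
                CountArgs(count=2),
            ],
        )
    print(num_inserted)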
6 changes: 6 additions & 0 deletions examples/all.py
@@ -11,15 +11,21 @@
import asyncio

from examples import async_client_insert_example
from examples import async_client_insert_many_example
from examples import async_client_insert_tx_example
from examples import client_insert_example
from examples import client_insert_many_example
from examples import client_insert_many_insert_opts_example
from examples import client_insert_tx_example

if __name__ == "__main__":
asyncio.set_event_loop(asyncio.new_event_loop())

asyncio.run(async_client_insert_example.example())
asyncio.run(async_client_insert_many_example.example())
asyncio.run(async_client_insert_tx_example.example())

client_insert_example.example()
client_insert_many_example.example()
client_insert_many_insert_opts_example.example()
client_insert_tx_example.example()
42 changes: 42 additions & 0 deletions examples/async_client_insert_many_example.py
@@ -0,0 +1,42 @@
#
# Run with:
#
# rye run python3 -m examples.async_client_insert_many_example
#

import asyncio
from dataclasses import dataclass
import json
import riverqueue
import sqlalchemy

from examples.helpers import dev_database_url
from riverqueue.driver import riversqlalchemy


@dataclass
class CountArgs:
count: int

kind: str = "sort"

def to_json(self) -> str:
return json.dumps({"count": self.count})


async def example():
engine = sqlalchemy.ext.asyncio.create_async_engine(dev_database_url(is_async=True))
client = riverqueue.AsyncClient(riversqlalchemy.AsyncDriver(engine))

num_inserted = await client.insert_many(
[
CountArgs(count=1),
CountArgs(count=2),
]
)
print(num_inserted)


if __name__ == "__main__":
asyncio.set_event_loop(asyncio.new_event_loop())
asyncio.run(example())
40 changes: 40 additions & 0 deletions examples/client_insert_many_example.py
@@ -0,0 +1,40 @@
#
# Run with:
#
# rye run python3 -m examples.client_insert_many_example
#

from dataclasses import dataclass
import json
import riverqueue
import sqlalchemy

from examples.helpers import dev_database_url
from riverqueue.driver import riversqlalchemy


@dataclass
class CountArgs:
count: int

kind: str = "sort"

def to_json(self) -> str:
return json.dumps({"count": self.count})


def example():
engine = sqlalchemy.create_engine(dev_database_url())
client = riverqueue.Client(riversqlalchemy.Driver(engine))

num_inserted = client.insert_many(
[
CountArgs(count=1),
CountArgs(count=2),
]
)
print(num_inserted)


if __name__ == "__main__":
example()
46 changes: 46 additions & 0 deletions examples/client_insert_many_insert_opts_example.py
@@ -0,0 +1,46 @@
#
# Run with:
#
# rye run python3 -m examples.client_insert_many_insert_opts_example
#

from dataclasses import dataclass
import json
import riverqueue
import sqlalchemy

from examples.helpers import dev_database_url
from riverqueue.driver import riversqlalchemy


@dataclass
class CountArgs:
count: int

kind: str = "sort"

def to_json(self) -> str:
return json.dumps({"count": self.count})


def example():
engine = sqlalchemy.create_engine(dev_database_url())
client = riverqueue.Client(riversqlalchemy.Driver(engine))

num_inserted = client.insert_many(
[
riverqueue.InsertManyParams(
CountArgs(count=1),
insert_opts=riverqueue.InsertOpts(max_attempts=5),
),
riverqueue.InsertManyParams(
CountArgs(count=2),
insert_opts=riverqueue.InsertOpts(queue="alternate_queue"),
),
]
)
print(num_inserted)


if __name__ == "__main__":
example()
2 changes: 1 addition & 1 deletion src/riverqueue/client.py
@@ -320,7 +320,7 @@ def _make_insert_params(
queue=insert_opts.queue or args_insert_opts.queue or QUEUE_DEFAULT,
scheduled_at=scheduled_at and scheduled_at.astimezone(timezone.utc),
state="scheduled" if scheduled_at else "available",
tags=insert_opts.tags or args_insert_opts.tags,
tags=insert_opts.tags or args_insert_opts.tags or [],
)

return insert_params, unique_opts
12 changes: 6 additions & 6 deletions src/riverqueue/driver/driver_protocol.py
@@ -27,16 +27,16 @@ class GetParams:
@dataclass
class JobInsertParams:
kind: str
args: Optional[Any] = None
args: Any = None
created_at: Optional[datetime] = None
finalized_at: Optional[datetime] = None
metadata: Optional[Any] = None
max_attempts: Optional[int] = field(default=25)
priority: Optional[int] = field(default=1)
queue: Optional[str] = field(default="default")
max_attempts: int = field(default=25)
priority: int = field(default=1)
queue: str = field(default="default")
scheduled_at: Optional[datetime] = None
state: Optional[str] = field(default="available")
tags: Optional[List[str]] = field(default_factory=list)
state: str = field(default="available")
tags: list[str] = field(default_factory=list)


class AsyncExecutorProtocol(Protocol):
68 changes: 68 additions & 0 deletions src/riverqueue/driver/riversqlalchemy/dbsqlc/river_job.py
@@ -81,6 +81,46 @@ class JobInsertFastParams:
tags: List[str]


JOB_INSERT_FAST_MANY = """-- name: job_insert_fast_many \\:execrows
INSERT INTO river_job(
args,
kind,
max_attempts,
metadata,
priority,
queue,
scheduled_at,
state,
tags
) SELECT
unnest(:p1\\:\\:jsonb[]),
unnest(:p2\\:\\:text[]),
unnest(:p3\\:\\:smallint[]),
unnest(:p4\\:\\:jsonb[]),
unnest(:p5\\:\\:smallint[]),
unnest(:p6\\:\\:text[]),
unnest(:p7\\:\\:timestamptz[]),
unnest(:p8\\:\\:river_job_state[]),
-- Had trouble getting multi-dimensional arrays to play nicely with sqlc,
-- but it might be possible. For now, join tags into a single string.
string_to_array(unnest(:p9\\:\\:text[]), ',')
"""


@dataclasses.dataclass()
class JobInsertFastManyParams:
args: List[Any]
kind: List[str]
max_attempts: List[int]
metadata: List[Any]
priority: List[int]
queue: List[str]
scheduled_at: List[datetime.datetime]
state: List[models.RiverJobState]
tags: List[str]


class Querier:
def __init__(self, conn: sqlalchemy.engine.Connection):
self._conn = conn
@@ -154,6 +194,20 @@ def job_insert_fast(self, arg: JobInsertFastParams) -> Optional[models.RiverJob]
tags=row[15],
)

def job_insert_fast_many(self, arg: JobInsertFastManyParams) -> int:
result = self._conn.execute(sqlalchemy.text(JOB_INSERT_FAST_MANY), {
"p1": arg.args,
"p2": arg.kind,
"p3": arg.max_attempts,
"p4": arg.metadata,
"p5": arg.priority,
"p6": arg.queue,
"p7": arg.scheduled_at,
"p8": arg.state,
"p9": arg.tags,
})
return result.rowcount


class AsyncQuerier:
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
@@ -227,3 +281,17 @@ async def job_insert_fast(self, arg: JobInsertFastParams) -> Optional[models.Riv
scheduled_at=row[14],
tags=row[15],
)

async def job_insert_fast_many(self, arg: JobInsertFastManyParams) -> int:
result = await self._conn.execute(sqlalchemy.text(JOB_INSERT_FAST_MANY), {
"p1": arg.args,
"p2": arg.kind,
"p3": arg.max_attempts,
"p4": arg.metadata,
"p5": arg.priority,
"p6": arg.queue,
"p7": arg.scheduled_at,
"p8": arg.state,
"p9": arg.tags,
})
return result.rowcount
27 changes: 26 additions & 1 deletion src/riverqueue/driver/riversqlalchemy/dbsqlc/river_job.sql
@@ -69,4 +69,29 @@ INSERT INTO river_job(
coalesce(sqlc.narg('scheduled_at')::timestamptz, now()),
@state::river_job_state,
coalesce(@tags::varchar(255)[], '{}')
) RETURNING *;
) RETURNING *;

-- name: JobInsertFastMany :execrows
INSERT INTO river_job(
args,
kind,
max_attempts,
metadata,
priority,
queue,
scheduled_at,
state,
tags
) SELECT
unnest(@args::jsonb[]),
unnest(@kind::text[]),
unnest(@max_attempts::smallint[]),
unnest(@metadata::jsonb[]),
unnest(@priority::smallint[]),
unnest(@queue::text[]),
unnest(@scheduled_at::timestamptz[]),
unnest(@state::river_job_state[]),

-- Had trouble getting multi-dimensional arrays to play nicely with sqlc,
-- but it might be possible. For now, join tags into a single string.
string_to_array(unnest(@tags::text[]), ',');
46 changes: 41 additions & 5 deletions src/riverqueue/driver/riversqlalchemy/sql_alchemy_driver.py
@@ -2,6 +2,7 @@
asynccontextmanager,
contextmanager,
)
from datetime import datetime, timezone
from riverqueue.driver.driver_protocol import AsyncDriverProtocol, AsyncExecutorProtocol
from sqlalchemy import Engine
from sqlalchemy.engine import Connection
@@ -16,7 +17,7 @@

from ...driver import DriverProtocol, ExecutorProtocol, GetParams, JobInsertParams
from ...model import Job
from .dbsqlc import river_job, pg_misc
from .dbsqlc import models, river_job, pg_misc


class AsyncExecutor(AsyncExecutorProtocol):
Expand All @@ -36,8 +37,11 @@ async def job_insert(self, insert_params: JobInsertParams) -> Job:
),
)

async def job_insert_many(self, all_params) -> int:
raise NotImplementedError("sqlc doesn't implement copy in python yet")
async def job_insert_many(self, all_params: list[JobInsertParams]) -> int:
await self.job_querier.job_insert_fast_many(
_build_insert_many_params(all_params)
)
return len(all_params)

async def job_get_by_kind_and_unique_properties(
self, get_params: GetParams
@@ -94,8 +98,9 @@ def job_insert(self, insert_params: JobInsertParams) -> Job:
),
)

def job_insert_many(self, all_params) -> int:
raise NotImplementedError("sqlc doesn't implement copy in python yet")
def job_insert_many(self, all_params: list[JobInsertParams]) -> int:
self.job_querier.job_insert_fast_many(_build_insert_many_params(all_params))
return len(all_params)

def job_get_by_kind_and_unique_properties(
self, get_params: GetParams
@@ -133,3 +138,34 @@ def executor(self) -> Iterator[ExecutorProtocol]:

def unwrap_executor(self, tx) -> ExecutorProtocol:
return Executor(tx)


def _build_insert_many_params(
all_params: list[JobInsertParams],
) -> river_job.JobInsertFastManyParams:
insert_many_params = river_job.JobInsertFastManyParams(
args=[],
kind=[],
max_attempts=[],
metadata=[],
priority=[],
queue=[],
scheduled_at=[],
state=[],
tags=[],
)

for insert_params in all_params:
insert_many_params.args.append(insert_params.args)
insert_many_params.kind.append(insert_params.kind)
insert_many_params.max_attempts.append(insert_params.max_attempts)
insert_many_params.metadata.append(insert_params.metadata or "{}")
insert_many_params.priority.append(insert_params.priority)
insert_many_params.queue.append(insert_params.queue)
insert_many_params.scheduled_at.append(
insert_params.scheduled_at or datetime.now(timezone.utc)
)
insert_many_params.state.append(cast(models.RiverJobState, insert_params.state))
insert_many_params.tags.append(",".join(insert_params.tags))

return insert_many_params