PYTHON-5257 - Turn on mypy disallow_any_generics #2456

Open · wants to merge 2 commits into master
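For context: mypy's disallow_any_generics option (set as disallow_any_generics = true in the mypy configuration) rejects every generic type used without explicit type parameters, reporting the type-arg error code. The diff resolves each hit in one of two ways: spell the parameters out (usually [Any]) or add a targeted per-line ignore. A minimal illustration of both, not taken from the PR itself and assuming Python 3.9+ built-in generics:

from typing import Any


# Without the ignore comment, mypy reports:
#   error: Missing type parameters for generic type "dict"  [type-arg]
# A targeted per-line ignore is what this PR uses for private helpers.
def keys_bare(doc: dict) -> list:  # type: ignore[type-arg]
    return list(doc)


# The alternative: spell the parameters out, using Any where the value
# type is genuinely open-ended.
def keys_typed(doc: dict[str, Any]) -> list[str]:
    return list(doc)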
30 changes: 15 additions & 15 deletions bson/json_util.py
@@ -844,7 +844,7 @@ def _encode_binary(data: bytes, subtype: int, json_options: JSONOptions) -> Any:
return {"$binary": {"base64": base64.b64encode(data).decode(), "subType": "%02x" % subtype}}


-def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
+def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
if (
json_options.datetime_representation == DatetimeRepresentation.ISO8601
and 0 <= int(obj) <= _MAX_UTC_MS
@@ -855,7 +855,7 @@ def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
return {"$date": {"$numberLong": str(int(obj))}}


-def _encode_code(obj: Code, json_options: JSONOptions) -> dict:
+def _encode_code(obj: Code, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
if obj.scope is None:
return {"$code": str(obj)}
else:
@@ -873,7 +873,7 @@ def _encode_noop(obj: Any, dummy0: Any) -> Any:
return obj


-def _encode_regex(obj: Any, json_options: JSONOptions) -> dict:
+def _encode_regex(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
flags = ""
if obj.flags & re.IGNORECASE:
flags += "i"
@@ -918,7 +918,7 @@ def _encode_float(obj: float, json_options: JSONOptions) -> Any:
return obj


-def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
+def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
if json_options.datetime_representation == DatetimeRepresentation.ISO8601:
if not obj.tzinfo:
obj = obj.replace(tzinfo=utc)
@@ -941,51 +941,51 @@ def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
return {"$date": {"$numberLong": str(millis)}}


-def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict:
+def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
return _encode_binary(obj, 0, json_options)


-def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict:
+def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
return _encode_binary(obj, obj.subtype, json_options)


-def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict:
+def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
if json_options.strict_uuid:
binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation)
return _encode_binary(binval, binval.subtype, json_options)
else:
return {"$uuid": obj.hex}


-def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict:
+def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict: # type: ignore[type-arg]
return {"$oid": str(obj)}


-def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict:
+def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg]
return {"$timestamp": {"t": obj.time, "i": obj.inc}}


-def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict:
+def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg]
return {"$numberDecimal": str(obj)}


-def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict:
+def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
return _json_convert(obj.as_doc(), json_options=json_options)


-def _encode_minkey(dummy0: Any, dummy1: Any) -> dict:
+def _encode_minkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg]
return {"$minKey": 1}


-def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict:
+def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg]
return {"$maxKey": 1}


# Encoders for BSON types
# Each encoder function's signature is:
# - obj: a Python data type, e.g. a Python int for _encode_int
# - json_options: a JSONOptions
-_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = {
+_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = { # type: ignore[type-arg]
bool: _encode_noop,
bytes: _encode_bytes,
datetime.datetime: _encode_datetime,
@@ -1056,7 +1056,7 @@ def _get_datetime_size(obj: datetime.datetime) -> int:
return 5 + len(str(obj.time()))


-def _get_regex_size(obj: Regex) -> int:
+def _get_regex_size(obj: Regex) -> int: # type: ignore[type-arg]
return 18 + len(obj.pattern)


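The helpers above keep their bare dict return annotations and add per-line ignores rather than changing signatures. A hypothetical alternative, not what the PR does, would be to parameterize the return types instead:

from typing import Any

from bson.objectid import ObjectId


# Hypothetical variant of _encode_objectid: an explicit dict[str, Any]
# return type satisfies disallow_any_generics with no ignore comment.
def _encode_objectid_typed(obj: ObjectId, dummy0: Any) -> dict[str, Any]:
    return {"$oid": str(obj)}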
2 changes: 1 addition & 1 deletion bson/typings.py
@@ -28,4 +28,4 @@
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])
_DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any])
-_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"]
+_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] # type: ignore[type-arg]
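The ignore here targets the bare array, which is generic in typeshed; mmap takes no type parameters. An explicit spelling would look like the sketch below (the _ReadableBufferExplicit name is hypothetical), mirroring the module's string forward references:

from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    from array import array
    from mmap import mmap

# Sketch only: the PR keeps the bare "array" plus an ignore rather than
# committing to an explicit item-type parameter.
_ReadableBufferExplicit = Union[bytes, memoryview, "mmap", "array[Any]"]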
28 changes: 14 additions & 14 deletions gridfs/asynchronous/grid_file.py
@@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[AsyncClientSession]) -> None:
class AsyncGridFS:
"""An instance of GridFS on top of a single Database."""

-def __init__(self, database: AsyncDatabase, collection: str = "fs"):
+def __init__(self, database: AsyncDatabase[Any], collection: str = "fs"):
"""Create a new instance of :class:`GridFS`.
Raises :class:`TypeError` if `database` is not an instance of
@@ -463,7 +463,7 @@ class AsyncGridFSBucket:

def __init__(
self,
-db: AsyncDatabase,
+db: AsyncDatabase[Any],
bucket_name: str = "fs",
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
write_concern: Optional[WriteConcern] = None,
@@ -513,11 +513,11 @@ def __init__(

self._bucket_name = bucket_name
self._collection = db[bucket_name]
-self._chunks: AsyncCollection = self._collection.chunks.with_options(
+self._chunks: AsyncCollection[Any] = self._collection.chunks.with_options(
write_concern=write_concern, read_preference=read_preference
)

-self._files: AsyncCollection = self._collection.files.with_options(
+self._files: AsyncCollection[Any] = self._collection.files.with_options(
write_concern=write_concern, read_preference=read_preference
)

@@ -1085,7 +1085,7 @@ class AsyncGridIn:

def __init__(
self,
-root_collection: AsyncCollection,
+root_collection: AsyncCollection[Any],
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> None:
@@ -1141,7 +1141,7 @@ def __init__(
"""
if not isinstance(root_collection, AsyncCollection):
raise TypeError(
f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}"
f"root_collection must be an instance of AsyncCollection[Any], not {type(root_collection)}"
)

if not root_collection.write_concern.acknowledged:
@@ -1172,7 +1172,7 @@ def __init__(
object.__setattr__(self, "_buffered_docs_size", 0)

async def _create_index(
-self, collection: AsyncCollection, index_key: Any, unique: bool
+self, collection: AsyncCollection[Any], index_key: Any, unique: bool
) -> None:
doc = await collection.find_one(projection={"_id": 1}, session=self._session)
if doc is None:
@@ -1456,7 +1456,7 @@ class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore

def __init__(
self,
-root_collection: AsyncCollection,
+root_collection: AsyncCollection[Any],
file_id: Optional[int] = None,
file_document: Optional[Any] = None,
session: Optional[AsyncClientSession] = None,
@@ -1494,7 +1494,7 @@ def __init__(
"""
if not isinstance(root_collection, AsyncCollection):
raise TypeError(
f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}"
f"root_collection must be an instance of AsyncCollection[Any], not {type(root_collection)}"
)
_disallow_transactions(session)

@@ -1829,7 +1829,7 @@ class _AsyncGridOutChunkIterator:
def __init__(
self,
grid_out: AsyncGridOut,
-chunks: AsyncCollection,
+chunks: AsyncCollection[Any],
session: Optional[AsyncClientSession],
next_chunk: Any,
) -> None:
@@ -1842,7 +1842,7 @@ def __init__(
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
self._cursor = None

-_cursor: Optional[AsyncCursor]
+_cursor: Optional[AsyncCursor[Any]]

def expected_chunk_length(self, chunk_n: int) -> int:
if chunk_n < self._num_chunks - 1:
@@ -1921,7 +1921,7 @@ async def close(self) -> None:

class AsyncGridOutIterator:
def __init__(
-self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession
+self, grid_out: AsyncGridOut, chunks: AsyncCollection[Any], session: AsyncClientSession
):
self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0)

@@ -1935,14 +1935,14 @@ async def next(self) -> bytes:
__anext__ = next


-class AsyncGridOutCursor(AsyncCursor):
+class AsyncGridOutCursor(AsyncCursor): # type: ignore[type-arg]
"""A cursor / iterator for returning GridOut objects as the result
of an arbitrary query against the GridFS files collection.
"""

def __init__(
self,
-collection: AsyncCollection,
+collection: AsyncCollection[Any],
filter: Optional[Mapping[str, Any]] = None,
skip: int = 0,
limit: int = 0,
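None of the annotations above change runtime behavior: GridFS still operates on raw documents, so the database, collection, and cursor types are simply pinned to [Any]. A usage sketch (import paths and the connection string are assumptions, not taken from this diff):

from typing import Any

from gridfs import AsyncGridFSBucket
from pymongo import AsyncMongoClient


async def open_bucket() -> AsyncGridFSBucket:
    # The client and database carry an explicit [Any] document type,
    # matching the AsyncDatabase[Any] annotation on AsyncGridFSBucket.
    client: AsyncMongoClient[Any] = AsyncMongoClient("mongodb://localhost:27017")
    return AsyncGridFSBucket(client["test"])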
28 changes: 14 additions & 14 deletions gridfs/synchronous/grid_file.py
@@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[ClientSession]) -> None:
class GridFS:
"""An instance of GridFS on top of a single Database."""

-def __init__(self, database: Database, collection: str = "fs"):
+def __init__(self, database: Database[Any], collection: str = "fs"):
"""Create a new instance of :class:`GridFS`.
Raises :class:`TypeError` if `database` is not an instance of
@@ -461,7 +461,7 @@ class GridFSBucket:

def __init__(
self,
-db: Database,
+db: Database[Any],
bucket_name: str = "fs",
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
write_concern: Optional[WriteConcern] = None,
@@ -511,11 +511,11 @@ def __init__(

self._bucket_name = bucket_name
self._collection = db[bucket_name]
-self._chunks: Collection = self._collection.chunks.with_options(
+self._chunks: Collection[Any] = self._collection.chunks.with_options(
write_concern=write_concern, read_preference=read_preference
)

-self._files: Collection = self._collection.files.with_options(
+self._files: Collection[Any] = self._collection.files.with_options(
write_concern=write_concern, read_preference=read_preference
)

@@ -1077,7 +1077,7 @@ class GridIn:

def __init__(
self,
-root_collection: Collection,
+root_collection: Collection[Any],
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> None:
@@ -1133,7 +1133,7 @@ def __init__(
"""
if not isinstance(root_collection, Collection):
raise TypeError(
f"root_collection must be an instance of Collection, not {type(root_collection)}"
f"root_collection must be an instance of Collection[Any], not {type(root_collection)}"
)

if not root_collection.write_concern.acknowledged:
@@ -1163,7 +1163,7 @@ def __init__(
object.__setattr__(self, "_buffered_docs", [])
object.__setattr__(self, "_buffered_docs_size", 0)

-def _create_index(self, collection: Collection, index_key: Any, unique: bool) -> None:
+def _create_index(self, collection: Collection[Any], index_key: Any, unique: bool) -> None:
doc = collection.find_one(projection={"_id": 1}, session=self._session)
if doc is None:
try:
@@ -1444,7 +1444,7 @@ class GridOut(GRIDOUT_BASE_CLASS): # type: ignore

def __init__(
self,
-root_collection: Collection,
+root_collection: Collection[Any],
file_id: Optional[int] = None,
file_document: Optional[Any] = None,
session: Optional[ClientSession] = None,
@@ -1482,7 +1482,7 @@ def __init__(
"""
if not isinstance(root_collection, Collection):
raise TypeError(
f"root_collection must be an instance of Collection, not {type(root_collection)}"
f"root_collection must be an instance of Collection[Any], not {type(root_collection)}"
)
_disallow_transactions(session)

@@ -1817,7 +1817,7 @@ class GridOutChunkIterator:
def __init__(
self,
grid_out: GridOut,
-chunks: Collection,
+chunks: Collection[Any],
session: Optional[ClientSession],
next_chunk: Any,
) -> None:
@@ -1830,7 +1830,7 @@ def __init__(
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
self._cursor = None

-_cursor: Optional[Cursor]
+_cursor: Optional[Cursor[Any]]

def expected_chunk_length(self, chunk_n: int) -> int:
if chunk_n < self._num_chunks - 1:
@@ -1908,7 +1908,7 @@ def close(self) -> None:


class GridOutIterator:
-def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession):
+def __init__(self, grid_out: GridOut, chunks: Collection[Any], session: ClientSession):
self._chunk_iter = GridOutChunkIterator(grid_out, chunks, session, 0)

def __iter__(self) -> GridOutIterator:
@@ -1921,14 +1921,14 @@ def next(self) -> bytes:
__next__ = next


-class GridOutCursor(Cursor):
+class GridOutCursor(Cursor): # type: ignore[type-arg]
"""A cursor / iterator for returning GridOut objects as the result
of an arbitrary query against the GridFS files collection.
"""

def __init__(
self,
-collection: Collection,
+collection: Collection[Any],
filter: Optional[Mapping[str, Any]] = None,
skip: int = 0,
limit: int = 0,
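GridOutCursor and AsyncGridOutCursor take a class-level ignore instead of a subscripted base, plausibly because the cursor's type variable is bound to Mapping[str, Any] (see bson/typings.py above) while these cursors actually yield GridOut objects, which are not Mappings. A spelling that would typecheck, shown only as a sketch of the trade-off (the class name is hypothetical):

from typing import Any, Mapping

from pymongo.cursor import Cursor


# Sketch: parameterizing the base with the Mapping bound itself would
# avoid the ignore, but would misstate what the cursor yields.
class _GridOutCursorSketch(Cursor[Mapping[str, Any]]):
    pass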
6 changes: 3 additions & 3 deletions pymongo/_asyncio_lock.py
@@ -93,7 +93,7 @@ class Lock(_ContextManagerMixin, _LoopBoundMixin):
"""

def __init__(self) -> None:
-self._waiters: Optional[collections.deque] = None
+self._waiters: Optional[collections.deque[Any]] = None
self._locked = False

def __repr__(self) -> str:
@@ -196,7 +196,7 @@ def __init__(self, lock: Optional[Lock] = None) -> None:
self.acquire = lock.acquire
self.release = lock.release

-self._waiters: collections.deque = collections.deque()
+self._waiters: collections.deque[Any] = collections.deque()

def __repr__(self) -> str:
res = super().__repr__()
@@ -260,7 +260,7 @@ async def wait(self) -> bool:
self._notify(1)
raise

-async def wait_for(self, predicate: Any) -> Coroutine:
+async def wait_for(self, predicate: Any) -> Coroutine[Any, Any, Any]:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
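Both fixes in this file are annotation-only: collections.deque (subscriptable at runtime since Python 3.9) gains an element type, and Coroutine gains its full three-parameter form, Coroutine[YieldType, SendType, ReturnType]. A standalone sketch:

import collections
from typing import Any, Coroutine, Optional

# deque with an explicit element type, as in Lock._waiters above.
waiters: Optional[collections.deque[Any]] = None


async def _work() -> int:
    return 1


# Coroutine's three parameters are the yield, send, and return types.
def schedule() -> Coroutine[Any, Any, int]:
    return _work()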
4 changes: 2 additions & 2 deletions pymongo/_asyncio_task.py
@@ -24,7 +24,7 @@


# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
-class _Task(asyncio.Task):
+class _Task(asyncio.Task[Any]):
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
super().__init__(coro, name=name)
self._cancel_requests = 0
@@ -43,7 +43,7 @@ def cancelling(self) -> int:
return self._cancel_requests


-def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
+def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task[Any]:
if sys.version_info >= (3, 11):
return asyncio.create_task(coro, name=name)
return _Task(coro, name=name)
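asyncio.Task has supported subscripting at runtime since at least Python 3.9, so both the parameterized base class and the parameterized return type above are safe across PyMongo's supported versions. A small usage sketch with a concrete result type:

import asyncio


async def fetch() -> int:
    return 42


async def main() -> None:
    # The annotation pins the task's result type; mypy checks the await.
    task: asyncio.Task[int] = asyncio.create_task(fetch())
    assert await task == 42


asyncio.run(main())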