From 16728cc304ce5cc8f2f40235935cca561795e8cf Mon Sep 17 00:00:00 2001 From: Christopher Rink Date: Tue, 9 Jan 2024 08:37:22 -0500 Subject: [PATCH] Improve the state of Python type hints in `basilisp.lang.*` --- CHANGELOG.md | 4 + pyproject.toml | 2 + src/basilisp/cli.py | 33 ++++--- src/basilisp/importer.py | 25 +++-- src/basilisp/lang/atom.py | 6 +- src/basilisp/lang/compiler/analyzer.py | 48 +++++----- src/basilisp/lang/compiler/exception.py | 4 +- src/basilisp/lang/compiler/generator.py | 8 +- src/basilisp/lang/compiler/nodes.py | 118 +++++++++++------------- src/basilisp/lang/delay.py | 10 +- src/basilisp/lang/exception.py | 8 +- src/basilisp/lang/futures.py | 18 ++-- src/basilisp/lang/interfaces.py | 10 +- src/basilisp/lang/multifn.py | 27 +++--- src/basilisp/lang/obj.py | 7 +- src/basilisp/lang/promise.py | 2 +- src/basilisp/lang/reader.py | 20 ++-- src/basilisp/lang/reduced.py | 2 +- src/basilisp/lang/reference.py | 18 +--- src/basilisp/lang/runtime.py | 45 ++++----- src/basilisp/lang/seq.py | 17 +++- src/basilisp/lang/symbol.py | 6 +- src/basilisp/lang/vector.py | 7 +- src/basilisp/lang/volatile.py | 8 +- src/basilisp/util.py | 15 +-- tests/basilisp/reader_test.py | 3 + 26 files changed, 244 insertions(+), 227 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b4a495754..5d9fb138a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed * Removed support for PyPy 3.8 (#785) +### Other + * Improve the state of the Python type hints in `basilisp.lang.*` (#???) + + ## [v0.1.0b0] ### Added * Added rudimentary support for `clojure.stacktrace` with `print-cause-trace` (part of #721) diff --git a/pyproject.toml b/pyproject.toml index 28a08df54..db270a5f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ prompt-toolkit = "^3.0.0" pyrsistent = "^0.18.0" python-dateutil = "^2.8.1" readerwriterlock = "^1.0.8" +typing_extensions = "^4.9.0" astor = { version = "^0.8.1", python = "<3.9", optional = true } pytest = { version = "^7.0.0", optional = true } @@ -219,6 +220,7 @@ disable = [ check_untyped_defs = true mypy_path = "src/" show_error_codes = true +warn_redundant_casts = true warn_unused_configs = true warn_unused_ignores = true diff --git a/src/basilisp/cli.py b/src/basilisp/cli.py index b0fd4b299..3163b4aff 100644 --- a/src/basilisp/cli.py +++ b/src/basilisp/cli.py @@ -237,8 +237,13 @@ def _subcommand( help: Optional[str] = None, # pylint: disable=redefined-builtin description: Optional[str] = None, handler: Handler, -): - def _wrap_add_subcommand(f: Callable[[argparse.ArgumentParser], None]): +) -> Callable[ + [Callable[[argparse.ArgumentParser], None]], + Callable[["argparse._SubParsersAction"], None], +]: + def _wrap_add_subcommand( + f: Callable[[argparse.ArgumentParser], None] + ) -> Callable[["argparse._SubParsersAction"], None]: def _wrapped_subcommand(subparsers: "argparse._SubParsersAction"): parser = subparsers.add_parser( subcommand, help=help, description=description @@ -279,14 +284,14 @@ def bootstrap_basilisp_installation(_, args: argparse.Namespace) -> None: description=textwrap.dedent( """Bootstrap the Python installation to allow importing Basilisp namespaces" without requiring an additional bootstrapping step. - + Python installations are bootstrapped by installing a `basilispbootstrap.pth` file in your `site-packages` directory. Python installations execute `*.pth` files found at startup. - + Bootstrapping your Python installation in this way can help avoid needing to perform manual bootstrapping from Python code within your application. - + On the first startup, Basilisp will compile `basilisp.core` to byte code which could take up to 30 seconds in some cases depending on your system and which version of Python you are using. Subsequent startups should be @@ -319,7 +324,7 @@ def _add_bootstrap_subcommand(parser: argparse.ArgumentParser) -> None: def nrepl_server( _, args: argparse.Namespace, -): +) -> None: opts = compiler.compiler_opts() basilisp.init(opts) @@ -369,7 +374,7 @@ def _add_nrepl_server_subcommand(parser: argparse.ArgumentParser) -> None: def repl( _, args: argparse.Namespace, -): +) -> None: opts = compiler.compiler_opts( warn_on_shadowed_name=args.warn_on_shadowed_name, warn_on_shadowed_var=args.warn_on_shadowed_var, @@ -465,7 +470,7 @@ def _add_repl_subcommand(parser: argparse.ArgumentParser) -> None: def run( parser: argparse.ArgumentParser, args: argparse.Namespace, -): +) -> None: target = args.file_or_ns_or_code if args.load_namespace: if args.in_ns is not None: @@ -523,18 +528,18 @@ def run( help="run a Basilisp script or code or namespace", description=textwrap.dedent( """Run a Basilisp script or a line of code or load a Basilisp namespace. - + If `-c` is provided, execute the line of code as given. If `-n` is given, interpret `file_or_ns_or_code` as a fully qualified Basilisp namespace relative to `sys.path`. Otherwise, execute the file as a script relative to the current working directory. - + `*main-ns*` will be set to the value provided for `-n`. In all other cases, it will be `nil`.""" ), handler=run, ) -def _add_run_subcommand(parser: argparse.ArgumentParser): +def _add_run_subcommand(parser: argparse.ArgumentParser) -> None: parser.add_argument( "file_or_ns_or_code", help=( @@ -570,7 +575,9 @@ def _add_run_subcommand(parser: argparse.ArgumentParser): _add_debug_arg_group(parser) -def test(parser: argparse.ArgumentParser, args: argparse.Namespace): # pragma: no cover +def test( + parser: argparse.ArgumentParser, args: argparse.Namespace +) -> None: # pragma: no cover try: import pytest except (ImportError, ModuleNotFoundError): @@ -591,7 +598,7 @@ def _add_test_subcommand(parser: argparse.ArgumentParser) -> None: parser.add_argument("args", nargs=-1) -def version(_, __): +def version(_, __) -> None: v = importlib.metadata.version("basilisp") print(f"Basilisp {v}") diff --git a/src/basilisp/importer.py b/src/basilisp/importer.py index 526bff212..de307eb4f 100644 --- a/src/basilisp/importer.py +++ b/src/basilisp/importer.py @@ -8,7 +8,16 @@ from functools import lru_cache from importlib.abc import MetaPathFinder, SourceLoader from importlib.machinery import ModuleSpec -from typing import Iterable, List, Mapping, MutableMapping, Optional, Sequence, cast +from typing import ( + Any, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + cast, +) from basilisp.lang import compiler as compiler from basilisp.lang import reader as reader @@ -191,22 +200,22 @@ def find_spec( return ModuleSpec(fullname, None, is_package=True) return None - def invalidate_caches(self): + def invalidate_caches(self) -> None: super().invalidate_caches() self._cache = {} - def _cache_bytecode(self, source_path, cache_path, data): + def _cache_bytecode(self, source_path: str, cache_path: str, data: bytes) -> None: self.set_data(cache_path, data) - def path_stats(self, path): + def path_stats(self, path: str) -> Mapping[str, Any]: stat = os.stat(path) return {"mtime": int(stat.st_mtime), "size": stat.st_size} - def get_data(self, path): + def get_data(self, path: str) -> bytes: with open(path, mode="r+b") as f: return f.read() - def set_data(self, path, data): + def set_data(self, path: str, data: bytes) -> None: os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, mode="w+b") as f: f.write(data) @@ -279,7 +288,7 @@ def get_code(self, fullname: str) -> Optional[types.CodeType]: assert len(code) == 1 return code[0] - def create_module(self, spec: ModuleSpec): + def create_module(self, spec: ModuleSpec) -> BasilispModule: logger.debug(f"Creating Basilisp module '{spec.name}'") mod = BasilispModule(spec.name) mod.__file__ = spec.loader_state["filename"] @@ -400,7 +409,7 @@ def exec_module(self, module: types.ModuleType) -> None: self._exec_module(fullname, spec.loader_state, path_stats, ns) -def hook_imports(): +def hook_imports() -> None: """Hook into Python's import machinery with a custom Basilisp code importer. diff --git a/src/basilisp/lang/atom.py b/src/basilisp/lang/atom.py index 9cd2bb525..93f61830e 100644 --- a/src/basilisp/lang/atom.py +++ b/src/basilisp/lang/atom.py @@ -1,12 +1,14 @@ from typing import Callable, Generic, Optional, TypeVar from readerwriterlock.rwlock import RWLockFair +from typing_extensions import Concatenate, ParamSpec from basilisp.lang.interfaces import IPersistentMap, RefValidator from basilisp.lang.map import PersistentMap from basilisp.lang.reference import RefBase T = TypeVar("T") +P = ParamSpec("P") class Atom(RefBase[T], Generic[T]): @@ -58,7 +60,9 @@ def reset(self, v: T) -> T: self._notify_watches(oldval, v) return v - def swap(self, f: Callable[..., T], *args, **kwargs) -> T: + def swap( + self, f: Callable[Concatenate[T, P], T], *args: P.args, **kwargs: P.kwargs + ) -> T: """Atomically swap the state of the Atom to the return value of `f(old, *args, **kwargs)`, returning the new value.""" while True: diff --git a/src/basilisp/lang/compiler/analyzer.py b/src/basilisp/lang/compiler/analyzer.py index 65cc9e2dd..4d1629b00 100644 --- a/src/basilisp/lang/compiler/analyzer.py +++ b/src/basilisp/lang/compiler/analyzer.py @@ -28,6 +28,7 @@ Pattern, Set, Tuple, + TypeVar, Union, cast, ) @@ -175,13 +176,13 @@ AnalyzerException = partial(CompilerException, phase=CompilerPhase.ANALYZING) -@attr.s(auto_attribs=True, slots=True) +@attr.define class RecurPoint: loop_id: str args: Collection[Binding] = () -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class SymbolTableEntry: binding: Binding used: bool = False @@ -196,7 +197,7 @@ def context(self) -> LocalType: return self.binding.local -@attr.s(auto_attribs=True, slots=True) +@attr.define class SymbolTable: name: str _is_context_boundary: bool = False @@ -647,7 +648,12 @@ def get_meta_prop(o: Union[IMeta, Var]) -> Any: _tag_meta = _meta_getter(SYM_TAG_META_KEY) -def _loc(form: Union[LispForm, ISeq]) -> Optional[Tuple[int, int, int, int]]: +T_form = TypeVar("T_form", bound=ReaderForm) +T_node = TypeVar("T_node", bound=Node) +LispAnalyzer = Callable[[T_form, AnalyzerContext], T_node] + + +def _loc(form: T_form) -> Optional[Tuple[int, int, int, int]]: """Fetch the location of the form in the original filename from the input form, if it has metadata.""" # Technically, IMeta is sufficient for fetching `form.meta` but the @@ -669,17 +675,17 @@ def _loc(form: Union[LispForm, ISeq]) -> Optional[Tuple[int, int, int, int]]: return None -def _with_loc(f): +def _with_loc(f: LispAnalyzer[T_form, T_node]) -> LispAnalyzer[T_form, T_node]: """Attach any available location information from the input form to the node environment returned from the parsing function.""" @wraps(f) - def _analyze_form(form: Union[LispForm, ISeq], ctx: AnalyzerContext) -> Node: + def _analyze_form(form: T_form, ctx: AnalyzerContext) -> T_node: form_loc = _loc(form) if form_loc is None: return f(form, ctx) else: - return f(form, ctx).fix_missing_locations(form_loc) + return cast(T_node, f(form, ctx).fix_missing_locations(form_loc)) return _analyze_form @@ -795,7 +801,7 @@ def _tag_ast(form: Optional[LispForm], ctx: AnalyzerContext) -> Optional[Node]: return _analyze_form(form, ctx) -def _with_meta(gen_node): +def _with_meta(gen_node: LispAnalyzer[T_form, T_node]) -> LispAnalyzer[T_form, T_node]: """Wraps the node generated by gen_node in a :with-meta AST node if the original form has meta. @@ -803,16 +809,7 @@ def _with_meta(gen_node): function expressions.""" @wraps(gen_node) - def with_meta( - form: Union[ - llist.PersistentList, - lmap.PersistentMap, - ISeq, - lset.PersistentSet, - vec.PersistentVector, - ], - ctx: AnalyzerContext, - ) -> Node: + def with_meta(form: T_form, ctx: AnalyzerContext) -> T_node: assert not ctx.is_quoted, "with-meta nodes are not used in quoted expressions" descriptor = gen_node(form, ctx) @@ -825,11 +822,14 @@ def with_meta( assert isinstance(meta_ast, MapNode) or ( isinstance(meta_ast, Const) and meta_ast.type == ConstType.MAP ) - return WithMeta( - form=form, - meta=meta_ast, - expr=descriptor, - env=ctx.get_node_env(pos=ctx.syntax_position), + return cast( + T_node, + WithMeta( + form=cast(LispForm, form), + meta=meta_ast, + expr=descriptor, + env=ctx.get_node_env(pos=ctx.syntax_position), + ), ) return descriptor @@ -3113,7 +3113,7 @@ def _yield_ast(form: ISeq, ctx: AnalyzerContext) -> Yield: return Yield.expressionless(form, ctx.get_node_env(pos=ctx.syntax_position)) -SpecialFormHandler = Callable[[ISeq, AnalyzerContext], SpecialFormNode] +SpecialFormHandler = Callable[[T_form, AnalyzerContext], SpecialFormNode] _SPECIAL_FORM_HANDLERS: Mapping[sym.Symbol, SpecialFormHandler] = { SpecialForm.AWAIT: _await_ast, SpecialForm.DEF: _def_ast, diff --git a/src/basilisp/lang/compiler/exception.py b/src/basilisp/lang/compiler/exception.py index b6f5bc7af..697ddec45 100644 --- a/src/basilisp/lang/compiler/exception.py +++ b/src/basilisp/lang/compiler/exception.py @@ -30,7 +30,7 @@ class CompilerPhase(Enum): COMPILING_PYTHON = kw.keyword("compiling-python") -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class _loc: line: Optional[int] = None col: Optional[int] = None @@ -46,7 +46,7 @@ def __bool__(self): ) -@attr.s(auto_attribs=True, slots=True, str=False) +@attr.define(str=False) class CompilerException(IExceptionInfo): msg: str phase: CompilerPhase diff --git a/src/basilisp/lang/compiler/generator.py b/src/basilisp/lang/compiler/generator.py index bd7e2a2fa..21c6b0ec8 100644 --- a/src/basilisp/lang/compiler/generator.py +++ b/src/basilisp/lang/compiler/generator.py @@ -135,14 +135,14 @@ GeneratorException = partial(CompilerException, phase=CompilerPhase.CODE_GENERATION) -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class SymbolTableEntry: context: LocalType munged: str symbol: sym.Symbol -@attr.s(auto_attribs=True, slots=True) +@attr.define class SymbolTable: name: str _is_context_boundary: bool = False @@ -203,7 +203,7 @@ class RecurType(Enum): LOOP = kw.keyword("loop") -@attr.s(auto_attribs=True, slots=True) +@attr.define class RecurPoint: loop_id: str type: RecurType @@ -313,7 +313,7 @@ def new_this(self, this: sym.Symbol): self._this.pop() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class GeneratedPyAST: node: ast.AST dependencies: Iterable[ast.AST] = () diff --git a/src/basilisp/lang/compiler/nodes.py b/src/basilisp/lang/compiler/nodes.py index 82fdf7a0d..fb53c161b 100644 --- a/src/basilisp/lang/compiler/nodes.py +++ b/src/basilisp/lang/compiler/nodes.py @@ -192,7 +192,7 @@ def visit(self, f: Callable[..., None], *args, **kwargs): def fix_missing_locations( self, form_loc: Optional[Tuple[int, int, int, int]] = None - ) -> "Node": + ) -> "Node[T]": """Return a transformed copy of this node with location in this node's environment updated to match the `form_loc` if given, or using its existing location otherwise. All child nodes will be recursively @@ -325,7 +325,7 @@ class LocalType(Enum): THIS = kw.keyword("this") -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class NodeEnv: ns: Namespace file: str @@ -337,7 +337,7 @@ class NodeEnv: func_ctx: Optional[FunctionContext] = None -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Await(Node[ReaderLispForm]): form: ReaderLispForm expr: Node @@ -348,7 +348,7 @@ class Await(Node[ReaderLispForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Binding(Node[sym.Symbol], Assignable): form: sym.Symbol name: str @@ -366,7 +366,7 @@ class Binding(Node[sym.Symbol], Assignable): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Catch(Node[SpecialForm]): form: SpecialForm class_: Union["MaybeClass", "MaybeHostForm"] @@ -379,7 +379,7 @@ class Catch(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Const(Node[ReaderLispForm]): form: ReaderLispForm type: ConstType @@ -393,7 +393,7 @@ class Const(Node[ReaderLispForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Def(Node[SpecialForm]): form: SpecialForm name: sym.Symbol @@ -412,7 +412,7 @@ class Def(Node[SpecialForm]): DefTypeBase = Union["MaybeClass", "MaybeHostForm", "VarRef"] -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class DefType(Node[SpecialForm]): form: SpecialForm name: str @@ -434,7 +434,7 @@ def python_member_names(self) -> Iterable[str]: yield from deftype_or_reify_python_member_names(self.members) -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class DefTypeMember(Node[SpecialForm]): form: SpecialForm name: str @@ -445,7 +445,7 @@ def python_name(self) -> str: return munge(self.name) -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class DefTypeClassMethod(DefTypeMember): class_local: Binding params: Iterable[Binding] @@ -459,7 +459,7 @@ class DefTypeClassMethod(DefTypeMember): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class DefTypeMethod(DefTypeMember): max_fixed_arity: int arities: IPersistentVector["DefTypeMethodArity"] @@ -470,7 +470,7 @@ class DefTypeMethod(DefTypeMember): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class DefTypeMethodArity(Node[SpecialForm]): form: SpecialForm name: str @@ -492,7 +492,7 @@ def python_name(self) -> str: return f"_{munge(self.name)}_arity{'_rest' if self.is_variadic else self.fixed_arity}" -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class DefTypeProperty(DefTypeMember): this_local: Binding params: Iterable[Binding] @@ -503,7 +503,7 @@ class DefTypeProperty(DefTypeMember): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class DefTypeStaticMethod(DefTypeMember): params: Iterable[Binding] fixed_arity: int @@ -519,7 +519,7 @@ class DefTypeStaticMethod(DefTypeMember): DefTypePythonMember = Union[DefTypeClassMethod, DefTypeProperty, DefTypeStaticMethod] -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Do(Node[SpecialForm]): form: SpecialForm statements: Iterable[Node] @@ -532,7 +532,7 @@ class Do(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Fn(Node[SpecialForm]): form: SpecialForm max_fixed_arity: int @@ -549,7 +549,7 @@ class Fn(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class FnArity(Node[SpecialForm]): form: SpecialForm loop_id: LoopID @@ -565,7 +565,7 @@ class FnArity(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class HostCall(Node[SpecialForm]): form: SpecialForm method: str @@ -579,7 +579,7 @@ class HostCall(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class HostField(Node[Union[SpecialForm, sym.Symbol]], Assignable): form: Union[SpecialForm, sym.Symbol] field: str @@ -592,7 +592,7 @@ class HostField(Node[Union[SpecialForm, sym.Symbol]], Assignable): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class If(Node[SpecialForm]): form: SpecialForm test: Node @@ -605,7 +605,7 @@ class If(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Import(Node[SpecialForm]): form: SpecialForm aliases: Iterable["ImportAlias"] @@ -616,7 +616,7 @@ class Import(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class ImportAlias(Node[Union[sym.Symbol, vec.PersistentVector]]): form: Union[sym.Symbol, vec.PersistentVector] name: str @@ -628,7 +628,7 @@ class ImportAlias(Node[Union[sym.Symbol, vec.PersistentVector]]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Invoke(Node[SpecialForm]): form: SpecialForm fn: Node @@ -641,7 +641,7 @@ class Invoke(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Let(Node[SpecialForm]): form: SpecialForm bindings: Iterable[Binding] @@ -653,7 +653,7 @@ class Let(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class LetFn(Node[SpecialForm]): form: SpecialForm bindings: Iterable[Binding] @@ -665,7 +665,7 @@ class LetFn(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Local(Node[sym.Symbol], Assignable): form: sym.Symbol name: str @@ -680,7 +680,7 @@ class Local(Node[sym.Symbol], Assignable): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Loop(Node[SpecialForm]): form: SpecialForm bindings: Iterable[Binding] @@ -693,7 +693,7 @@ class Loop(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Map(Node[IPersistentMap]): form: IPersistentMap keys: Iterable[Node] @@ -705,7 +705,7 @@ class Map(Node[IPersistentMap]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class MaybeClass(Node[sym.Symbol]): form: sym.Symbol class_: str @@ -717,7 +717,7 @@ class MaybeClass(Node[sym.Symbol]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class MaybeHostForm(Node[sym.Symbol]): form: sym.Symbol class_: str @@ -730,12 +730,7 @@ class MaybeHostForm(Node[sym.Symbol]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s( - auto_attribs=True, - eq=True, - frozen=True, - slots=True, -) +@attr.frozen(eq=True) class PyDict(Node[dict]): form: dict keys: Iterable[Node] @@ -747,12 +742,7 @@ class PyDict(Node[dict]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s( - auto_attribs=True, - eq=True, - frozen=True, - slots=True, -) +@attr.frozen(eq=True) class PyList(Node[list]): form: list items: Iterable[Node] @@ -763,12 +753,7 @@ class PyList(Node[list]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s( - auto_attribs=True, - eq=True, - frozen=True, - slots=True, -) +@attr.frozen(eq=True) class PySet(Node[Union[frozenset, set]]): form: Union[frozenset, set] items: Iterable[Node] @@ -779,7 +764,7 @@ class PySet(Node[Union[frozenset, set]]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class PyTuple(Node[tuple]): form: tuple items: Iterable[Node] @@ -790,7 +775,7 @@ class PyTuple(Node[tuple]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Queue(Node[lqueue.PersistentQueue]): form: lqueue.PersistentQueue items: Iterable[Node] @@ -801,7 +786,7 @@ class Queue(Node[lqueue.PersistentQueue]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Quote(Node[SpecialForm]): form: SpecialForm expr: Const @@ -813,7 +798,7 @@ class Quote(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Recur(Node[SpecialForm]): form: SpecialForm exprs: Iterable[Node] @@ -825,7 +810,7 @@ class Recur(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Reify(Node[SpecialForm]): form: SpecialForm interfaces: Iterable[DefTypeBase] @@ -844,7 +829,7 @@ def python_member_names(self) -> Iterable[str]: yield from deftype_or_reify_python_member_names(self.members) -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class RequireAlias(Node[Union[sym.Symbol, vec.PersistentVector]]): form: Union[sym.Symbol, vec.PersistentVector] name: str @@ -856,7 +841,7 @@ class RequireAlias(Node[Union[sym.Symbol, vec.PersistentVector]]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Require(Node[SpecialForm]): form: SpecialForm aliases: Iterable[RequireAlias] @@ -867,7 +852,7 @@ class Require(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Set(Node[IPersistentSet]): form: IPersistentSet items: Iterable[Node] @@ -878,7 +863,7 @@ class Set(Node[IPersistentSet]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class SetBang(Node[SpecialForm]): form: SpecialForm target: Union[Assignable, Node] @@ -890,7 +875,7 @@ class SetBang(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Throw(Node[SpecialForm]): form: SpecialForm exception: Node @@ -901,7 +886,7 @@ class Throw(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Try(Node[SpecialForm]): form: SpecialForm body: Do @@ -914,7 +899,7 @@ class Try(Node[SpecialForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class VarRef(Node[sym.Symbol], Assignable): form: sym.Symbol var: Var @@ -927,7 +912,7 @@ class VarRef(Node[sym.Symbol], Assignable): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Vector(Node[IPersistentVector]): form: IPersistentVector items: Iterable[Node] @@ -938,11 +923,14 @@ class Vector(Node[IPersistentVector]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) -class WithMeta(Node[LispForm]): +T_withmeta = TypeVar("T_withmeta", Fn, Map, Queue, Reify, Set, Vector) + + +@attr.frozen +class WithMeta(Node[LispForm], Generic[T_withmeta]): form: LispForm meta: Union[Const, Map] - expr: Union[Fn, Map, Queue, Set, Vector] + expr: T_withmeta env: NodeEnv children: Sequence[kw.Keyword] = vec.v(META, EXPR) op: NodeOp = NodeOp.WITH_META @@ -950,7 +938,7 @@ class WithMeta(Node[LispForm]): raw_forms: IPersistentVector[LispForm] = vec.PersistentVector.empty() -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Yield(Node[SpecialForm]): form: SpecialForm expr: Optional[Node] diff --git a/src/basilisp/lang/delay.py b/src/basilisp/lang/delay.py index f769fab78..54e920bb2 100644 --- a/src/basilisp/lang/delay.py +++ b/src/basilisp/lang/delay.py @@ -8,17 +8,11 @@ T = TypeVar("T") -# Use attrs `these` for now as there is an open bug around slotted -# generic classes: https://github.com/python-attrs/attrs/issues/313 -@attr.s( - auto_attribs=True, - frozen=True, - these={"f": attr.ib(), "value": attr.ib(), "computed": attr.ib(default=False)}, -) +@attr.frozen class _DelayState(Generic[T]): f: Callable[[], T] value: Optional[T] - computed: bool + computed: bool = False class Delay(IDeref[T]): diff --git a/src/basilisp/lang/exception.py b/src/basilisp/lang/exception.py index 046fb2cc5..294b5d6d7 100644 --- a/src/basilisp/lang/exception.py +++ b/src/basilisp/lang/exception.py @@ -4,13 +4,7 @@ from basilisp.lang.obj import lrepr -@attr.s( - auto_attribs=True, - eq=True, - repr=False, - slots=True, - str=False, -) +@attr.define(repr=False, str=False) class ExceptionInfo(IExceptionInfo): message: str data: IPersistentMap diff --git a/src/basilisp/lang/futures.py b/src/basilisp/lang/futures.py index 0009835b5..4417ffabb 100644 --- a/src/basilisp/lang/futures.py +++ b/src/basilisp/lang/futures.py @@ -5,19 +5,15 @@ from typing import Callable, Optional, TypeVar import attr +from typing_extensions import ParamSpec from basilisp.lang.interfaces import IBlockingDeref T = TypeVar("T") +P = ParamSpec("P") -@attr.s( - auto_attribs=True, - eq=True, - frozen=True, - repr=False, - slots=True, -) +@attr.frozen(eq=True, repr=False) class Future(IBlockingDeref[T]): _future: "_Future[T]" @@ -61,8 +57,8 @@ def __init__(self, max_workers: Optional[int] = None): super().__init__(max_workers=max_workers) # pylint: disable=arguments-differ - def submit( # type: ignore - self, fn: Callable[..., T], *args, **kwargs + def submit( # type: ignore[override] + self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs ) -> "Future[T]": return Future(super().submit(fn, *args, **kwargs)) @@ -76,7 +72,7 @@ def __init__( super().__init__(max_workers=max_workers, thread_name_prefix=thread_name_prefix) # pylint: disable=arguments-differ - def submit( # type: ignore - self, fn: Callable[..., T], *args, **kwargs + def submit( # type: ignore[override] + self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs ) -> "Future[T]": return Future(super().submit(fn, *args, **kwargs)) diff --git a/src/basilisp/lang/interfaces.py b/src/basilisp/lang/interfaces.py index 99e147827..ffb51433f 100644 --- a/src/basilisp/lang/interfaces.py +++ b/src/basilisp/lang/interfaces.py @@ -168,8 +168,8 @@ def rseq(self) -> "ISeq[T]": class ISeqable(Iterable[T]): """ISeqable types can produce sequences of their elements, but are not ISeqs. - All of the builtin collections are ISeqable, except Lists which directly - implement ISeq. Values of type ISeqable respond True to the `seqable?` predicate.""" + All the builtin collections are ISeqable, except Lists which directly implement + ISeq. Values of type ISeqable respond True to the `seqable?` predicate.""" __slots__ = () @@ -421,9 +421,9 @@ def _record_lrepr(self, kwargs: Mapping) -> str: raise NotImplementedError() -def seq_equals(s1, s2) -> bool: - """Return True if two sequences contain the exactly the same elements in the - same order. Return False if one sequence is shorter than the other.""" +def seq_equals(s1: Union["ISeq", ISequential], s2: Any) -> bool: + """Return True if two sequences contain exactly the same elements in the same + order. Return False if one sequence is shorter than the other.""" assert isinstance(s1, (ISeq, ISequential)) if not isinstance(s2, (ISeq, ISequential)): diff --git a/src/basilisp/lang/multifn.py b/src/basilisp/lang/multifn.py index 2815a6abd..e0bdee61c 100644 --- a/src/basilisp/lang/multifn.py +++ b/src/basilisp/lang/multifn.py @@ -1,6 +1,8 @@ import threading from typing import Any, Callable, Generic, Optional, TypeVar +from typing_extensions import Concatenate, ParamSpec + from basilisp.lang import map as lmap from basilisp.lang import runtime from basilisp.lang import symbol as sym @@ -8,15 +10,16 @@ from basilisp.lang.set import PersistentSet T = TypeVar("T") -DispatchFunction = Callable[..., T] -Method = Callable[..., Any] +P = ParamSpec("P") +DispatchFunction = Callable[Concatenate[T, P], T] +Method = Callable[Concatenate[T, P], Any] _GLOBAL_HIERARCHY_SYM = sym.symbol("global-hierarchy", ns=runtime.CORE_NS) _ISA_SYM = sym.symbol("isa?", ns=runtime.CORE_NS) -class MultiFunction(Generic[T]): +class MultiFunction(Generic[T, P]): __slots__ = ( "_name", "_default", @@ -33,7 +36,7 @@ class MultiFunction(Generic[T]): def __init__( self, name: sym.Symbol, - dispatch: DispatchFunction, + dispatch: DispatchFunction[T, P], default: T, hierarchy: Optional[IRef] = None, ) -> None: @@ -63,11 +66,11 @@ def __init__( # caches. self._cached_hierarchy = self._hierarchy.deref() - def __call__(self, *args, **kwargs): - key = self._dispatch(*args, **kwargs) + def __call__(self, v: T, *args: P.args, **kwargs: P.kwargs) -> Any: + key = self._dispatch(v, *args, **kwargs) method = self.get_method(key) if method is not None: - return method(*args, **kwargs) + return method(v, *args, **kwargs) raise NotImplementedError def _reset_cache(self): @@ -94,14 +97,14 @@ def _precedes(self, tag: T, parent: T) -> bool: selection.""" return self._has_preference(tag, parent) or self._is_a(tag, parent) - def add_method(self, key: T, method: Method) -> None: + def add_method(self, key: T, method: Method[T, P]) -> None: """Add a new method to this function which will respond for key returned from the dispatch function.""" with self._lock: self._methods = self._methods.assoc(key, method) self._reset_cache() - def _find_and_cache_method(self, key: T) -> Optional[Method]: + def _find_and_cache_method(self, key: T) -> Optional[Method[T, P]]: """Find and cache the best method for dispatch value `key`.""" with self._lock: best_key: Optional[T] = None @@ -125,7 +128,7 @@ def _find_and_cache_method(self, key: T) -> Optional[Method]: return best_method - def get_method(self, key: T) -> Optional[Method]: + def get_method(self, key: T) -> Optional[Method[T, P]]: """Return the method which would handle this dispatch key or None if no method defined for this key and no default.""" if self._cached_hierarchy != self._hierarchy.deref(): @@ -159,7 +162,7 @@ def prefers(self): """Return a mapping of preferred values to the set of other values.""" return self._prefers - def remove_method(self, key: T) -> Optional[Method]: + def remove_method(self, key: T) -> Optional[Method[T, P]]: """Remove the method defined for this key and return it.""" with self._lock: method = self._methods.val_at(key, None) @@ -179,5 +182,5 @@ def default(self) -> T: return self._default @property - def methods(self) -> IPersistentMap[T, Method]: + def methods(self) -> IPersistentMap[T, Method[T, P]]: return self._methods diff --git a/src/basilisp/lang/obj.py b/src/basilisp/lang/obj.py index cdb69ed6c..98d4a2261 100644 --- a/src/basilisp/lang/obj.py +++ b/src/basilisp/lang/obj.py @@ -6,11 +6,10 @@ from decimal import Decimal from fractions import Fraction from functools import singledispatch +from itertools import islice from pathlib import Path from typing import Any, Callable, Iterable, Pattern, Tuple, Union -from basilisp.util import take - PrintCountSetting = Union[bool, int, None] SURPASSED_PRINT_LENGTH = "..." @@ -93,7 +92,7 @@ def entry_reprs(): print_dup = kwargs["print_dup"] print_length = kwargs["print_length"] if not print_dup and isinstance(print_length, int): - items = list(take(entry_reprs(), print_length + 1)) + items = list(islice(entry_reprs(), print_length + 1)) if len(items) > print_length: items.pop() trailer.append(SURPASSED_PRINT_LENGTH) @@ -125,7 +124,7 @@ def seq_lrepr( print_dup = kwargs["print_dup"] print_length = kwargs["print_length"] if not print_dup and isinstance(print_length, int): - items = list(take(iterable, print_length + 1)) + items = list(islice(iterable, print_length + 1)) if len(items) > print_length: items.pop() trailer.append(SURPASSED_PRINT_LENGTH) diff --git a/src/basilisp/lang/promise.py b/src/basilisp/lang/promise.py index 95a8f6d66..c2ee978f8 100644 --- a/src/basilisp/lang/promise.py +++ b/src/basilisp/lang/promise.py @@ -14,7 +14,7 @@ def __init__(self) -> None: self._is_delivered = False self._value: Optional[T] = None - def deliver(self, value: T): + def deliver(self, value: T) -> None: with self._condition: if not self._is_delivered: self._is_delivered = True diff --git a/src/basilisp/lang/reader.py b/src/basilisp/lang/reader.py index c039d39e9..79434d657 100644 --- a/src/basilisp/lang/reader.py +++ b/src/basilisp/lang/reader.py @@ -7,7 +7,6 @@ import io import re import uuid -from collections.abc import Hashable from datetime import datetime from fractions import Fraction from itertools import chain @@ -22,6 +21,7 @@ MutableMapping, Optional, Pattern, + Sequence, Set, Tuple, TypeVar, @@ -117,7 +117,7 @@ class Comment: # pylint:disable=redefined-builtin -@attr.s(auto_attribs=True, repr=False, slots=True, str=False) +@attr.define(repr=False, str=False) class SyntaxError(Exception): message: str line: Optional[int] = None @@ -418,7 +418,9 @@ def _compile_feature_vec(form: IPersistentList[Tuple[kw.Keyword, ReaderForm]]): feature_list: List[Tuple[kw.Keyword, ReaderForm]] = [] try: - for k, v in partition(form, 2): + for k, v in partition( + cast(Sequence[Tuple[kw.Keyword, ReaderForm]], form), 2 + ): if not isinstance(k, kw.Keyword): raise SyntaxError( f"Reader conditional features must be keywords, not {type(k)}" @@ -661,7 +663,7 @@ def __read_map_elems(ctx: ReaderContext) -> Iterable[RawReaderForm]: def _map_key_processor( namespace: Optional[str], -) -> Callable[[Hashable], Hashable]: +) -> Callable[[Any], Any]: """Return a map key processor. If no `namespace` is provided, return an identity function. If a `namespace` @@ -700,9 +702,13 @@ def _read_map( # pylint: disable=redefined-loop-name for k, v in partition(list(__read_map_elems(ctx)), 2): k = process_key(k) - if k in d: - raise ctx.syntax_error(f"Duplicate key '{k}' in map literal") - d[k] = v + try: + if k in d: + raise ctx.syntax_error(f"Duplicate key '{k}' in map literal") # type: ignore[str-bytes-safe] + except TypeError as e: + raise ctx.syntax_error("Map keys must be hashable") from e + else: + d[k] = v except ValueError as e: raise ctx.syntax_error("Unexpected token '}'; expected map value") from e else: diff --git a/src/basilisp/lang/reduced.py b/src/basilisp/lang/reduced.py index b9bfed905..bb546830f 100644 --- a/src/basilisp/lang/reduced.py +++ b/src/basilisp/lang/reduced.py @@ -7,7 +7,7 @@ T = TypeVar("T") -@attr.s(auto_attribs=True, frozen=True, slots=True) +@attr.frozen class Reduced(IDeref[T]): value: T diff --git a/src/basilisp/lang/reference.py b/src/basilisp/lang/reference.py index 1d4b38ac0..21cf7408e 100644 --- a/src/basilisp/lang/reference.py +++ b/src/basilisp/lang/reference.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Optional, TypeVar from readerwriterlock.rwlock import RWLockable +from typing_extensions import Concatenate, ParamSpec from basilisp.lang import keyword as kw from basilisp.lang import map as lmap @@ -14,17 +15,8 @@ RefWatchKey, ) -try: - from typing import Protocol -except ImportError: - AlterMeta = Callable[..., Optional[IPersistentMap]] -else: - - class AlterMeta(Protocol): # type: ignore [no-redef] - def __call__( - self, meta: Optional[IPersistentMap], *args - ) -> Optional[IPersistentMap]: - ... +P = ParamSpec("P") +AlterMeta = Callable[Concatenate[Optional[IPersistentMap], P], Optional[IPersistentMap]] class ReferenceBase(IReference): @@ -75,7 +67,7 @@ def add_watch(self, k: RefWatchKey, wf: RefWatcher) -> "RefBase[T]": self._watches = self._watches.assoc(k, wf) return self - def _notify_watches(self, old: Any, new: Any): + def _notify_watches(self, old: Any, new: Any) -> None: for k, wf in self._watches.items(): wf(k, self, old, new) @@ -101,7 +93,7 @@ def set_validator(self, vf: Optional[RefValidator] = None) -> None: self._validate(self.deref(), vf=vf) self._validator = vf - def _validate(self, val: Any, vf: Optional[RefValidator] = None): + def _validate(self, val: Any, vf: Optional[RefValidator] = None) -> None: vf = vf or self._validator if vf is not None: try: diff --git a/src/basilisp/lang/runtime.py b/src/basilisp/lang/runtime.py index cfc985551..37b2812fc 100644 --- a/src/basilisp/lang/runtime.py +++ b/src/basilisp/lang/runtime.py @@ -14,7 +14,7 @@ import sys import threading import types -from collections.abc import Sequence +from collections.abc import Sequence, Sized from fractions import Fraction from typing import ( AbstractSet, @@ -973,6 +973,7 @@ def pop_thread_bindings() -> None: ################### T = TypeVar("T") +X = TypeVar("X") @functools.singledispatch @@ -996,7 +997,7 @@ def _first_iseq(o: ISeq[T]) -> Optional[T]: @functools.singledispatch -def rest(o) -> ISeq: +def rest(o: Any) -> ISeq: """If o is a ISeq, return the elements after the first in o. If o is None, returns an empty seq. Otherwise, coerces o to a seq and returns the rest.""" n = to_seq(o) @@ -1052,7 +1053,7 @@ def _cons(seq, o) -> ISeq: @_cons.register(type(None)) -def _cons_none(_: None, o) -> ISeq: +def _cons_none(_: None, o: T) -> ISeq[T]: return llist.l(o) @@ -1071,12 +1072,9 @@ def cons(o, seq) -> ISeq: to_seq = lseq.to_seq -def concat(*seqs) -> ISeq: +def concat(*seqs: Any) -> ISeq: """Concatenate the sequences given by seqs into a single ISeq.""" - allseqs = lseq.sequence(itertools.chain.from_iterable(filter(None, seqs))) - if allseqs is None: - return lseq.EMPTY - return allseqs + return lseq.sequence(itertools.chain.from_iterable(filter(None, seqs))) def apply(f, args): @@ -1121,19 +1119,22 @@ def apply_kw(f, args): return f(*final, **kwargs) +@functools.singledispatch def count(coll) -> int: - if coll is None: - return 0 - else: - try: - return len(coll) - except (AttributeError, TypeError): - try: - return sum(1 for _ in coll) - except TypeError as e: - raise TypeError( - f"count not supported on object of type {type(coll)}" - ) from e + try: + return sum(1 for _ in coll) + except TypeError as e: + raise TypeError(f"count not supported on object of type {type(coll)}") from e + + +@count.register(type(None)) +def _count_none(_: None) -> int: + return 0 + + +@count.register(Sized) +def _count_sized(coll: Sized): + return len(coll) __nth_sentinel = object() @@ -1342,7 +1343,7 @@ def _divide_ints(x: int, y: LispNumber) -> LispNumber: return x / y -def quotient(num, div) -> LispNumber: +def quotient(num: LispNumber, div: LispNumber) -> LispNumber: """Return the integral quotient resulting from the division of num by div.""" return math.trunc(num / div) @@ -1404,7 +1405,7 @@ def _fn_to_comparator(f): def cmp(x, y): r = f(x, y) - if isinstance(r, numbers.Number) and not isinstance(r, bool): + if not isinstance(r, bool) and isinstance(r, numbers.Number): return r elif r: return -1 diff --git a/src/basilisp/lang/seq.py b/src/basilisp/lang/seq.py index 6a16612aa..6942ae72c 100644 --- a/src/basilisp/lang/seq.py +++ b/src/basilisp/lang/seq.py @@ -1,5 +1,6 @@ import functools -from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar +import typing +from typing import Callable, Iterable, Iterator, Optional, TypeVar from basilisp.lang.interfaces import ( IPersistentMap, @@ -249,7 +250,7 @@ def is_realized(self): return self._gen is None -def sequence(s: Iterable) -> ISeq[Any]: +def sequence(s: Iterable[T]) -> ISeq[T]: """Create a Sequence from Iterable s.""" try: i = iter(s) @@ -258,7 +259,17 @@ def sequence(s: Iterable) -> ISeq[Any]: return EMPTY -def _seq_or_nil(s: Optional[ISeq]) -> Optional[ISeq]: +@typing.overload +def _seq_or_nil(s: None) -> None: + ... + + +@typing.overload +def _seq_or_nil(s: ISeq) -> Optional[ISeq]: + ... + + +def _seq_or_nil(s): """Return None if a ISeq is empty, the ISeq otherwise.""" if s is None or s.is_empty: return None diff --git a/src/basilisp/lang/symbol.py b/src/basilisp/lang/symbol.py index a32882ceb..52c550f2b 100644 --- a/src/basilisp/lang/symbol.py +++ b/src/basilisp/lang/symbol.py @@ -54,7 +54,7 @@ def with_meta(self, meta: Optional[IPersistentMap]) -> "Symbol": def as_python_sym(self) -> str: if self.ns is not None: return f"{munge(self.ns)}.{munge(self.name)}" - return f"{munge(self.name)}" + return munge(self.name) def __eq__(self, other): if not isinstance(other, Symbol): @@ -86,6 +86,8 @@ def __call__(self, m: Union[IAssociative, IPersistentSet], default=None): return None -def symbol(name: str, ns: Optional[str] = None, meta=None) -> Symbol: +def symbol( + name: str, ns: Optional[str] = None, meta: Optional[IPersistentMap] = None +) -> Symbol: """Create a new symbol.""" return Symbol(name, ns=ns, meta=meta) diff --git a/src/basilisp/lang/vector.py b/src/basilisp/lang/vector.py index 23262405c..809a10b2d 100644 --- a/src/basilisp/lang/vector.py +++ b/src/basilisp/lang/vector.py @@ -1,5 +1,5 @@ from functools import total_ordering -from typing import Iterable, Optional, Sequence, TypeVar, Union +from typing import TYPE_CHECKING, Iterable, Optional, Sequence, TypeVar, Union, cast from pyrsistent import PVector, pvector # noqa # pylint: disable=unused-import from pyrsistent.typing import PVectorEvolver @@ -19,6 +19,9 @@ from basilisp.lang.seq import sequence from basilisp.util import partition +if TYPE_CHECKING: + from typing import Tuple + T = TypeVar("T") @@ -46,7 +49,7 @@ def cons_transient(self, *elems: T) -> "TransientVector[T]": # type: ignore[ove return self def assoc_transient(self, *kvs: T) -> "TransientVector[T]": - for i, v in partition(kvs, 2): + for i, v in cast("Sequence[Tuple[int, T]]", partition(kvs, 2)): self._inner.set(i, v) return self diff --git a/src/basilisp/lang/volatile.py b/src/basilisp/lang/volatile.py index 7c78bef65..858abfba6 100644 --- a/src/basilisp/lang/volatile.py +++ b/src/basilisp/lang/volatile.py @@ -1,13 +1,15 @@ from typing import Callable, Optional, TypeVar import attr +from typing_extensions import Concatenate, ParamSpec from basilisp.lang.interfaces import IDeref T = TypeVar("T") +P = ParamSpec("P") -@attr.s(auto_attribs=True, slots=True, these={"value": attr.ib()}) +@attr.define class Volatile(IDeref[T]): """A volatile reference container. Volatile references do not provide atomic semantics, but they may be useful as a mutable reference container in a @@ -22,6 +24,8 @@ def reset(self, v: T) -> T: self.value = v return self.value - def swap(self, f: Callable[..., T], *args, **kwargs) -> T: + def swap( + self, f: Callable[Concatenate[T, P], T], *args: P.args, **kwargs: P.kwargs + ) -> T: self.value = f(self.value, *args, **kwargs) return self.value diff --git a/src/basilisp/util.py b/src/basilisp/util.py index 035ad4b1d..eaaf8baa5 100644 --- a/src/basilisp/util.py +++ b/src/basilisp/util.py @@ -1,7 +1,8 @@ import contextlib import time -from itertools import islice -from typing import Callable, Generic, Optional, TypeVar +from typing import Callable, Generic, Optional, Tuple, TypeVar + +from typing_extensions import Iterable, Sequence @contextlib.contextmanager @@ -66,8 +67,8 @@ def is_present(self) -> bool: return self._inner is not None -def partition(coll, n: int): - """Partition coll into groups of size n.""" +def partition(coll: Sequence[T], n: int) -> Iterable[Tuple[T, ...]]: + """Partition `coll` into groups of size `n`.""" assert n > 0 start = 0 stop = n @@ -78,9 +79,3 @@ def partition(coll, n: int): if start < len(coll) < stop: stop = len(coll) yield tuple(e for e in coll[start:stop]) - - -def take(coll, n: int): - """Yield the first n elements of coll.""" - assert n >= 0 - yield from islice(coll, n) diff --git a/tests/basilisp/reader_test.py b/tests/basilisp/reader_test.py index 0b8ffbf06..9f1bcf250 100644 --- a/tests/basilisp/reader_test.py +++ b/tests/basilisp/reader_test.py @@ -573,6 +573,9 @@ def test_map(): with pytest.raises(reader.SyntaxError): read_str_first("{") + with pytest.raises(reader.SyntaxError): + read_str_first("{#py [] :some-keyword}") + assert read_str_first("{}") == lmap.map({}) assert read_str_first("{:a 1}") == lmap.map({kw.keyword("a"): 1}) assert read_str_first('{:a 1 :b "string"}') == lmap.map(