diff --git a/package/version b/package/version index 8509341ee..785120566 100644 --- a/package/version +++ b/package/version @@ -1 +1 @@ -0.1.359 +0.1.360 diff --git a/pyproject.toml b/pyproject.toml index 096598aae..0bcff45a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "pyk" -version = "0.1.359" +version = "0.1.360" description = "" authors = [ "Runtime Verification, Inc. ", diff --git a/src/pyk/kcfg/explore.py b/src/pyk/kcfg/explore.py index 7ff742763..abcc225ba 100644 --- a/src/pyk/kcfg/explore.py +++ b/src/pyk/kcfg/explore.py @@ -18,7 +18,7 @@ ) from ..kast.outer import KRule from ..konvert import krule_to_kore -from ..kore.rpc import KoreClient, KoreServer, SatResult, StopReason, UnknownResult, UnsatResult +from ..kore.rpc import SatResult, StopReason, UnknownResult, UnsatResult from ..kore.syntax import Import, Module from ..ktool.kprove import KoreExecLogFormat from ..prelude import k @@ -35,10 +35,9 @@ from ..kast import KInner from ..kast.outer import KClaim - from ..kore.rpc import LogEntry + from ..kore.rpc import KoreServerBase, LogEntry from ..kore.syntax import Sentence from ..ktool.kprint import KPrint - from ..utils import BugReport from .kcfg import NodeIdLike @@ -48,45 +47,28 @@ class KCFGExplore(ContextManager['KCFGExplore']): kprint: KPrint id: str - _port: int | None - _kore_rpc_command: str | Iterable[str] - _smt_timeout: int | None - _smt_retry_limit: int | None - _bug_report: BugReport | None - - _kore_server: KoreServer | None - _kore_client: KoreClient | None - _rpc_closed: bool + _trace_rewrites: bool + _kore_server: KoreServerBase + def __init__( self, + kore_server: KoreServerBase, kprint: KPrint, *, id: str | None = None, - port: int | None = None, - kore_rpc_command: str | Iterable[str] = 'kore-rpc', - smt_timeout: int | None = None, - smt_retry_limit: int | None = None, - bug_report: BugReport | None = None, haskell_log_format: KoreExecLogFormat = KoreExecLogFormat.ONELINE, haskell_log_entries: Iterable[str] = (), log_axioms_file: Path | None = None, trace_rewrites: bool = False, ): + self._kore_server = kore_server self.kprint = kprint self.id = id if id is not None else 'NO ID' - self._port = port - self._kore_rpc_command = kore_rpc_command - self._smt_timeout = smt_timeout - self._smt_retry_limit = smt_retry_limit - self._bug_report = bug_report self._haskell_log_format = haskell_log_format self._haskell_log_entries = haskell_log_entries self._log_axioms_file = log_axioms_file - self._kore_server = None - self._kore_client = None - self._rpc_closed = False self._trace_rewrites = trace_rewrites def __enter__(self) -> KCFGExplore: @@ -95,35 +77,8 @@ def __enter__(self) -> KCFGExplore: def __exit__(self, *args: Any) -> None: self.close() - @property - def _kore_rpc(self) -> tuple[KoreServer, KoreClient]: - if self._rpc_closed: - raise ValueError('RPC server already closed!') - if not self._kore_server: - self._kore_server = KoreServer( - self.kprint.definition_dir, - self.kprint.main_module, - port=self._port, - bug_report=self._bug_report, - command=self._kore_rpc_command, - smt_timeout=self._smt_timeout, - smt_retry_limit=self._smt_retry_limit, - haskell_log_format=self._haskell_log_format, - haskell_log_entries=self._haskell_log_entries, - log_axioms_file=self._log_axioms_file, - ) - if not self._kore_client: - self._kore_client = KoreClient('localhost', self._kore_server._port, bug_report=self._bug_report) - return (self._kore_server, self._kore_client) - def close(self) -> None: - self._rpc_closed = True - if self._kore_server is not None: - self._kore_server.close() - self._kore_server = None - if self._kore_client is not None: - self._kore_client.close() - self._kore_client = None + self.close() def cterm_execute( self, @@ -135,7 +90,7 @@ def cterm_execute( ) -> tuple[int, CTerm, list[CTerm], tuple[LogEntry, ...]]: _LOGGER.debug(f'Executing: {cterm}') kore = self.kprint.kast_to_kore(cterm.kast, GENERATED_TOP_CELL) - _, kore_client = self._kore_rpc + _, kore_client = self._kore_server._kore_rpc er = kore_client.execute( kore, max_depth=depth, @@ -163,7 +118,7 @@ def cterm_execute( def cterm_simplify(self, cterm: CTerm) -> tuple[KInner, tuple[LogEntry, ...]]: _LOGGER.debug(f'Simplifying: {cterm}') kore = self.kprint.kast_to_kore(cterm.kast, GENERATED_TOP_CELL) - _, kore_client = self._kore_rpc + _, kore_client = self._kore_server._kore_rpc kore_simplified, logs = kore_client.simplify(kore) kast_simplified = self.kprint.kore_to_kast(kore_simplified) return kast_simplified, logs @@ -171,7 +126,7 @@ def cterm_simplify(self, cterm: CTerm) -> tuple[KInner, tuple[LogEntry, ...]]: def kast_simplify(self, kast: KInner) -> tuple[KInner, tuple[LogEntry, ...]]: _LOGGER.debug(f'Simplifying: {kast}') kore = self.kprint.kast_to_kore(kast, GENERATED_TOP_CELL) - _, kore_client = self._kore_rpc + _, kore_client = self._kore_server._kore_rpc kore_simplified, logs = kore_client.simplify(kore) kast_simplified = self.kprint.kore_to_kast(kore_simplified) return kast_simplified, logs @@ -179,7 +134,7 @@ def kast_simplify(self, kast: KInner) -> tuple[KInner, tuple[LogEntry, ...]]: def cterm_get_model(self, cterm: CTerm, module_name: str | None = None) -> Subst | None: _LOGGER.info(f'Getting model: {cterm}') kore = self.kprint.kast_to_kore(cterm.kast, GENERATED_TOP_CELL) - _, kore_client = self._kore_rpc + _, kore_client = self._kore_server._kore_rpc result = kore_client.get_model(kore, module_name=module_name) if type(result) is UnknownResult: _LOGGER.debug('Result is Unknown') @@ -217,7 +172,7 @@ def cterm_implies( _consequent = KApply(KLabel('#Exists', [GENERATED_TOP_CELL]), [KVariable(uc), _consequent]) antecedent_kore = self.kprint.kast_to_kore(antecedent.kast, GENERATED_TOP_CELL) consequent_kore = self.kprint.kast_to_kore(_consequent, GENERATED_TOP_CELL) - _, kore_client = self._kore_rpc + _, kore_client = self._kore_server._kore_rpc result = kore_client.implies(antecedent_kore, consequent_kore) if not result.satisfiable: if result.substitution is not None: @@ -326,7 +281,7 @@ def cterm_assume_defined(self, cterm: CTerm) -> CTerm: _LOGGER.debug(f'Computing definedness condition for: {cterm}') kast = KApply(KLabel('#Ceil', [GENERATED_TOP_CELL, GENERATED_TOP_CELL]), [cterm.config]) kore = self.kprint.kast_to_kore(kast, GENERATED_TOP_CELL) - _, kore_client = self._kore_rpc + _, kore_client = self._kore_server._kore_rpc kore_simplified, _logs = kore_client.simplify(kore) kast_simplified = self.kprint.kore_to_kast(kore_simplified) _LOGGER.debug(f'Definedness condition computed: {kast_simplified}') @@ -499,9 +454,9 @@ def add_dependencies_module( for c in dependencies ] kore_axioms: list[Sentence] = [krule_to_kore(self.kprint.kompiled_kore, r) for r in kast_rules] - _, kore_client = self._kore_rpc - sentences: list[Sentence] = [Import(module_name=old_module_name, attrs=())] - sentences = sentences + kore_axioms - m = Module(name=new_module_name, sentences=sentences) - _LOGGER.info(f'Adding dependencies module {self.id}: {new_module_name}') - kore_client.add_module(m) + for _, kore_client in self._kore_server._kore_rpcs: + sentences: list[Sentence] = [Import(module_name=old_module_name, attrs=())] + sentences = sentences + kore_axioms + m = Module(name=new_module_name, sentences=sentences) + _LOGGER.info(f'Adding dependencies module {self.id}: {new_module_name}') + kore_client.add_module(m) diff --git a/src/pyk/kore/rpc.py b/src/pyk/kore/rpc.py index 55b3e29af..2d26927ee 100644 --- a/src/pyk/kore/rpc.py +++ b/src/pyk/kore/rpc.py @@ -3,6 +3,7 @@ import json import logging import socket +import threading from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta @@ -23,6 +24,7 @@ from collections.abc import Iterable, Mapping from typing import Any, ClassVar, Final, TextIO, TypeVar + from ..ktool.kprint import KPrint from ..utils import BugReport from .syntax import Module @@ -94,6 +96,7 @@ def __exit__(self, *args: Any) -> None: def close(self) -> None: self._file.close() + self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() def request(self, method: str, **params: Any) -> dict[str, Any]: @@ -529,6 +532,7 @@ class KoreClient(ContextManager['KoreClient']): def __init__(self, host: str, port: int, *, timeout: int | None = None, bug_report: BugReport | None = None): self._client = JsonRpcClient(host, port, timeout=timeout, bug_report=bug_report) + self._lock = threading.Lock() def __enter__(self) -> KoreClient: return self @@ -539,6 +543,9 @@ def __exit__(self, *args: Any) -> None: def close(self) -> None: self._client.close() + def _try_release(self) -> None: + ... + def _request(self, method: str, **params: Any) -> dict[str, Any]: try: return self._client.request(method, **params) @@ -637,6 +644,30 @@ def add_module(self, module: Module) -> None: assert result == [] +class ParKoreClient(KoreClient): + _lock: threading.Lock + + def __init__(self, host: str, port: int, *, timeout: int | None = None, bug_report: BugReport | None = None): + super().__init__(host, port, timeout=timeout, bug_report=bug_report) + self._lock = threading.Lock() + + def _try_release(self) -> None: + try: + self._lock.release() + except RuntimeError: + pass + + def _request(self, method: str, **params: Any) -> dict[str, Any]: + try: + res = self._client.request(method, **params) + self._try_release() + return res + except JsonRpcError as err: + self._try_release() + assert err.code not in {-32601, -32602}, 'Malformed Kore-RPC request' + raise KoreClientError(message=err.message, code=err.code, data=err.data) from err + + class KoreServer(ContextManager['KoreServer']): _proc: Popen _pid: int @@ -743,3 +774,208 @@ def close(self) -> None: self._proc.send_signal(SIGINT) self._proc.wait() _LOGGER.info(f'KoreServer stopped: {self.host}:{self.port}, pid={self.pid}') + + +class KoreServerBase(ABC): + kprint: KPrint + _port: int | None + _kore_rpc_command: str | Iterable[str] + _smt_timeout: int | None + _smt_retry_limit: int | None + _bug_report: BugReport | None + + _rpc_closed: bool + + def __init__( + self, + kprint: KPrint, + port: int | None = None, + kore_rpc_command: str | Iterable[str] = 'kore-rpc', + smt_timeout: int | None = None, + smt_retry_limit: int | None = None, + bug_report: BugReport | None = None, + haskell_log_format: KoreExecLogFormat = KoreExecLogFormat.ONELINE, + haskell_log_entries: Iterable[str] = (), + log_axioms_file: Path | None = None, + ): + self.kprint = kprint + self._port = port + self._kore_rpc_command = kore_rpc_command + self._smt_timeout = smt_timeout + self._smt_retry_limit = smt_retry_limit + self._bug_report = bug_report + self._haskell_log_format = haskell_log_format + self._haskell_log_entries = haskell_log_entries + self._log_axioms_file = log_axioms_file + self._rpc_closed = False + + @property + @abstractmethod + def _kore_rpc(self) -> tuple[KoreServer, KoreClient]: + ... + + @property + @abstractmethod + def _kore_rpcs(self) -> list[tuple[KoreServer, KoreClient]]: + ... + + def __exit__(self, *args: Any) -> None: + self.close() + + @abstractmethod + def close(self) -> None: + ... + + +class SingleKoreServer(KoreServerBase): + _kore_server: KoreServer | None = None + _kore_client: KoreClient | None = None + + @property + def _kore_rpc(self) -> tuple[KoreServer, KoreClient]: + if self._rpc_closed: + raise ValueError('RPC server already closed!') + if not self._kore_server: + self._kore_server = KoreServer( + self.kprint.definition_dir, + self.kprint.main_module, + port=self._port, + bug_report=self._bug_report, + command=self._kore_rpc_command, + smt_timeout=self._smt_timeout, + smt_retry_limit=self._smt_retry_limit, + haskell_log_format=self._haskell_log_format, + haskell_log_entries=self._haskell_log_entries, + log_axioms_file=self._log_axioms_file, + ) + if not self._kore_client: + self._kore_client = KoreClient('localhost', self._kore_server._port, bug_report=self._bug_report) + return (self._kore_server, self._kore_client) + + @property + def _kore_rpcs(self) -> list[tuple[KoreServer, KoreClient]]: + return [self._kore_rpc] + + def close(self) -> None: + self._rpc_closed = True + if self._kore_server is not None: + self._kore_server.close() + self._kore_server = None + if self._kore_client is not None: + self._kore_client.close() + self._kore_client = None + + +class KoreServerPool(KoreServerBase): + _kore_server: list[KoreServer] = [] + _kore_client: list[KoreClient] = [] + + _max_clients: int + + def __init__( + self, + kprint: KPrint, + port: int | None = None, + kore_rpc_command: str | Iterable[str] = 'kore-rpc', + smt_timeout: int | None = None, + smt_retry_limit: int | None = None, + bug_report: BugReport | None = None, + haskell_log_format: KoreExecLogFormat = KoreExecLogFormat.ONELINE, + haskell_log_entries: Iterable[str] = (), + log_axioms_file: Path | None = None, + max_clients: int = 1, + ): + super().__init__( + kprint, + port, + kore_rpc_command, + smt_timeout, + smt_retry_limit, + bug_report, + haskell_log_format, + haskell_log_entries, + log_axioms_file, + ) + self._max_clients = max_clients + + @property + def _kore_rpc(self) -> tuple[KoreServer, KoreClient]: + if self._rpc_closed: + raise ValueError('RPC server already closed!') + curr_server = self._curr_server() # need interim because of lock + if curr_server is not None: + (server, client) = curr_server + elif len(self._kore_client) < self._max_clients: + (server, client) = self._new_server() + client._lock.acquire() + self._kore_server.append(server) + self._kore_client.append(client) + else: + (server, client) = self._next_available_server() + return (server, client) + + @property + def _kore_rpcs(self) -> list[tuple[KoreServer, KoreClient]]: + if self._rpc_closed: + raise ValueError('RPC server already closed!') + res = self._all_servers() + new_servers = [] + while len(self._kore_server) < self._max_clients: + (server, client) = self._new_server() + client._lock.acquire() + self._kore_server.append(server) + self._kore_client.append(client) + new_servers.append((server, client)) + return res + new_servers + + def _curr_server(self) -> tuple[KoreServer, KoreClient] | None: + i, client = next( + ( + (i, client) + for i, client in enumerate(self._kore_client) + if self._kore_client[i]._lock.acquire(blocking=False) + ), + (0, None), + ) + if client is not None: + return (self._kore_server[i], self._kore_client[i]) + return None + + def _next_available_server(self) -> tuple[KoreServer, KoreClient]: + i = 0 + while True: + if self._kore_client[i]._lock.acquire(timeout=5): + return (self._kore_server[i], self._kore_client[i]) + i = (i + 1) % len(self._kore_client) + + def _all_servers(self) -> list[tuple[KoreServer, KoreClient]]: + acquired: list[tuple[KoreServer, KoreClient]] = [] + while len(acquired) < len(self._kore_server): + acquired.append(self._next_available_server()) + return acquired + + # TODO: don't use the same defined port for the KoreServer but a new one + def _new_server(self) -> tuple[KoreServer, KoreClient]: + server = KoreServer( + self.kprint.definition_dir, + self.kprint.main_module, + port=self._port, + bug_report=self._bug_report, + command=self._kore_rpc_command, + smt_timeout=self._smt_timeout, + smt_retry_limit=self._smt_retry_limit, + haskell_log_format=self._haskell_log_format, + haskell_log_entries=self._haskell_log_entries, + log_axioms_file=self._log_axioms_file, + ) + client = ParKoreClient('localhost', server._port, bug_report=self._bug_report) + return (server, client) + + def close(self) -> None: + self._rpc_closed = True + while self._kore_server: + server = self._kore_server.pop() + server.close() + while self._kore_client: + client = self._kore_client.pop() + client.close() diff --git a/src/pyk/proof/reachability.py b/src/pyk/proof/reachability.py index 3b7a78317..dbf3f4f81 100644 --- a/src/pyk/proof/reachability.py +++ b/src/pyk/proof/reachability.py @@ -4,6 +4,7 @@ import logging from dataclasses import dataclass from itertools import chain +from multiprocessing.pool import ThreadPool as Pool from typing import TYPE_CHECKING from pyk.kore.rpc import LogEntry @@ -477,26 +478,20 @@ def advance_proof( cut_point_rules: Iterable[str] = (), terminal_rules: Iterable[str] = (), implication_every_block: bool = True, + max_workers: int = 1, ) -> KCFG: iterations = 0 - while self.proof.pending: - self.proof.write_proof() - - if max_iterations is not None and max_iterations <= iterations: - _LOGGER.warning(f'Reached iteration bound {self.proof.id}: {max_iterations}') - break - iterations += 1 - curr_node = self.proof.pending[0] - + def _advance_from_node(node: NodeIdLike) -> None: + curr_node = self.proof.kcfg.node(node) if self._check_subsume(curr_node): - continue + return if self._check_terminal(curr_node): - continue + return if self._check_abstract(curr_node): - continue + return if self._extract_branches is not None and len(self.proof.kcfg.splits(target_id=curr_node.id)) == 0: branches = list(self._extract_branches(curr_node.cterm)) @@ -505,7 +500,7 @@ def advance_proof( _LOGGER.info( f'Found {len(branches)} branches using heuristic for node {self.proof.id}: {shorten_hashes(curr_node.id)}: {[self.kcfg_explore.kprint.pretty_print(bc) for bc in branches]}' ) - continue + return module_name = ( self.circularities_module_name if self.nonzero_depth(curr_node) else self.dependencies_module_name @@ -520,6 +515,19 @@ def advance_proof( module_name=module_name, ) + while self.proof.pending: + self.proof.write_proof() + + if max_iterations is not None and max_iterations <= iterations: + _LOGGER.warning(f'Reached iteration bound {self.proof.id}: {max_iterations}') + break + iterations += 1 + + curr_nodes = [node.id for ni, node in enumerate(self.proof.pending) if ni < max_workers] + pool = Pool(processes=len(curr_nodes)) + res = pool.map_async(_advance_from_node, curr_nodes) + res.wait() + self.proof.write_proof() return self.proof.kcfg @@ -681,6 +689,7 @@ def advance_proof( cut_point_rules: Iterable[str] = (), terminal_rules: Iterable[str] = (), implication_every_block: bool = True, + max_workers: int = 1, ) -> KCFG: iterations = 0 @@ -718,6 +727,7 @@ def advance_proof( cut_point_rules=cut_point_rules, terminal_rules=terminal_rules, implication_every_block=implication_every_block, + max_workers=max_workers, ) self.proof.write_proof() diff --git a/src/pyk/testing/_kompiler.py b/src/pyk/testing/_kompiler.py index e6da48281..d0859f65e 100644 --- a/src/pyk/testing/_kompiler.py +++ b/src/pyk/testing/_kompiler.py @@ -8,7 +8,7 @@ from ..kcfg import KCFGExplore from ..kllvm.compiler import compile_runtime from ..kllvm.importer import import_runtime -from ..kore.rpc import KoreClient, KoreServer +from ..kore.rpc import KoreClient, KoreServer, SingleKoreServer from ..ktool.kompile import Kompile from ..ktool.kprint import KPrint from ..ktool.kprove import KProve @@ -128,9 +128,10 @@ def _update_symbol_table(symbol_table: SymbolTable) -> None: class KCFGExploreTest(KProveTest): @pytest.fixture def kcfg_explore(self, kprove: KProve) -> Iterator[KCFGExplore]: + server = SingleKoreServer(kprove, bug_report=kprove._bug_report) with KCFGExplore( + server, kprove, - bug_report=kprove._bug_report, ) as kcfg_explore: yield kcfg_explore