diff --git a/mypy.ini b/mypy.ini index 9c05db2..5bb8894 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,8 @@ python_version = 3.8 color_output = True error_summary = True -; disallow_untyped_calls = True +check_untyped_defs = true +disallow_untyped_calls = True ; disallow_untyped_defs = True ; disallow_any_generics = True ; disallow_any_unimported = True diff --git a/src/vagrant/__init__.py b/src/vagrant/__init__.py index 6b8282e..299a156 100644 --- a/src/vagrant/__init__.py +++ b/src/vagrant/__init__.py @@ -19,6 +19,8 @@ import subprocess import sys import logging +import typing +from typing import Dict, Iterator, List, Optional, Union # local from . import compat @@ -42,7 +44,7 @@ ) -def which(program): # noqa C901 +def which(program) -> Optional[str]: # noqa C901 """ Emulate unix 'which' command. If program is a path to an executable file (i.e. it contains any directory components, like './myscript'), return @@ -62,7 +64,7 @@ def which(program): # noqa C901 https://hg.python.org/cpython/file/default/Lib/shutil.py """ - def is_exe(fpath): + def is_exe(fpath) -> bool: return os.path.isfile(fpath) and os.access(fpath, os.X_OK) # Shortcut: If program contains any dir components, do not search the path @@ -128,7 +130,7 @@ def is_exe(fpath): # The full path to the vagrant executable, e.g. '/usr/bin/vagrant' -def get_vagrant_executable(): +def get_vagrant_executable() -> Optional[str]: return which("vagrant") @@ -174,7 +176,7 @@ def none_cm(): yield None -def make_file_cm(filename, mode="a"): +def make_file_cm(filename, mode="a") -> typing.Callable[[], typing.ContextManager]: """ Open a file for appending and yield the open filehandle. Close the filehandle after yielding it. This is useful for creating a context @@ -230,7 +232,7 @@ def __init__( env=None, out_cm=None, err_cm=None, - ): + ) -> None: """ root: a directory containing a file named Vagrantfile. Defaults to os.getcwd(). This is the directory and Vagrantfile that the Vagrant @@ -258,8 +260,8 @@ def __init__( will be sent to devnull. """ self.root = os.path.abspath(root) if root is not None else os.getcwd() - self._cached_conf = {} - self._vagrant_exe = None # cache vagrant executable path + self._cached_conf: Dict[str, Optional[Dict[str, str]]] = {} + self._vagrant_exe: Optional[str] = None # cache vagrant executable path self.env = env if out_cm is not None: self.out_cm = out_cm @@ -279,7 +281,7 @@ def __init__( else: self.err_cm = none_cm - def version(self): + def version(self) -> str: """ Return the installed vagrant version, as a string, e.g. '1.5.0' """ @@ -291,7 +293,7 @@ def version(self): ) return m.group("version") - def init(self, box_name=None, box_url=None): + def init(self, box_name=None, box_url=None) -> None: """ From the Vagrant docs: @@ -315,7 +317,7 @@ def up( provision=None, provision_with=None, stream_output=False, - ): + ) -> Optional[Iterator[str]]: """ Invoke `vagrant up` to start a box or boxes, possibly streaming the command output. @@ -368,7 +370,7 @@ def up( self._cached_conf[vm_name] = None # remove cached configuration return generator if stream_output else None - def provision(self, vm_name=None, provision_with=None): + def provision(self, vm_name=None, provision_with=None) -> None: """ Runs the provisioners defined in the Vagrantfile. vm_name: optional VM name string. @@ -381,7 +383,7 @@ def provision(self, vm_name=None, provision_with=None): def reload( self, vm_name=None, provision=None, provision_with=None, stream_output=False - ): + ) -> Optional[Iterator[str]]: """ Quoting from Vagrant docs: > The equivalent of running a halt followed by an up. @@ -421,21 +423,21 @@ def reload( self._cached_conf[vm_name] = None # remove cached configuration return generator if stream_output else None - def suspend(self, vm_name=None): + def suspend(self, vm_name=None) -> None: """ Suspend/save the machine. """ self._call_vagrant_command(["suspend", vm_name]) self._cached_conf[vm_name] = None # remove cached configuration - def resume(self, vm_name=None): + def resume(self, vm_name=None) -> None: """ Resume suspended machine. """ self._call_vagrant_command(["resume", vm_name]) self._cached_conf[vm_name] = None # remove cached configuration - def halt(self, vm_name=None, force=False): + def halt(self, vm_name=None, force=False) -> None: """ Halt the Vagrant box. @@ -445,14 +447,14 @@ def halt(self, vm_name=None, force=False): self._call_vagrant_command(["halt", vm_name, force_opt]) self._cached_conf[vm_name] = None # remove cached configuration - def destroy(self, vm_name=None): + def destroy(self, vm_name=None) -> None: """ Terminate the running Vagrant box. """ self._call_vagrant_command(["destroy", vm_name, "--force"]) self._cached_conf[vm_name] = None # remove cached configuration - def status(self, vm_name=None): + def status(self, vm_name=None) -> List[Status]: r""" Return the results of a `vagrant status` call as a list of one or more Status objects. A Status contains the following attributes: @@ -546,7 +548,7 @@ def global_status(self, prune=False): output = self._run_vagrant_command(cmd) return self._parse_global_status(output) - def _normalize_status(self, status, provider): + def _normalize_status(self, status, provider) -> str: """ Normalise VM status to cope with state name being different between providers @@ -562,7 +564,7 @@ def _normalize_status(self, status, provider): return status - def _parse_status(self, output): + def _parse_status(self, output) -> List[Status]: """ Unit testing is so much easier when Vagrant is removed from the equation. @@ -583,7 +585,7 @@ def _parse_status(self, output): return statuses - def _parse_global_status(self, output): + def _parse_global_status(self, output: str) -> List[GlobalStatus]: """ Unit testing is so much easier when Vagrant is removed from the equation. @@ -609,7 +611,7 @@ def _parse_global_status(self, output): vm_id = state = provider = home = None return statuses - def conf(self, ssh_config=None, vm_name=None): + def conf(self, ssh_config=None, vm_name=None) -> Dict[str, str]: """ Parse ssh_config into a dict containing the keys defined in ssh_config, which should include these keys (listed with example values): 'User' @@ -631,15 +633,15 @@ def conf(self, ssh_config=None, vm_name=None): the value returned from ssh_config(). For speed, the configuration parsed from ssh_config is cached for subsequent calls. """ - if self._cached_conf.get(vm_name) is None or ssh_config is not None: + conf = self._cached_conf.get(vm_name) + if conf is None or ssh_config is not None: if ssh_config is None: ssh_config = self.ssh_config(vm_name=vm_name) conf = self._parse_config(ssh_config) self._cached_conf[vm_name] = conf + return conf - return self._cached_conf[vm_name] - - def ssh_config(self, vm_name=None): + def ssh_config(self, vm_name=None) -> str: """ Return the output of 'vagrant ssh-config' which appears to be a valid Host section suitable for use in an ssh config file. @@ -662,7 +664,7 @@ def ssh_config(self, vm_name=None): # capture ssh configuration from vagrant return self._run_vagrant_command(["ssh-config", vm_name]) - def user(self, vm_name=None): + def user(self, vm_name=None) -> Optional[str]: """ Return the ssh user of the vagrant box, e.g. 'vagrant' or None if there is no user in the ssh_config. @@ -672,7 +674,7 @@ def user(self, vm_name=None): """ return self.conf(vm_name=vm_name).get("User") - def hostname(self, vm_name=None): + def hostname(self, vm_name=None) -> Optional[str]: """ Return the vagrant box hostname, e.g. '127.0.0.1' or None if there is no hostname in the ssh_config. @@ -682,7 +684,7 @@ def hostname(self, vm_name=None): """ return self.conf(vm_name=vm_name).get("HostName") - def port(self, vm_name=None): + def port(self, vm_name=None) -> Optional[str]: """ Return the vagrant box ssh port, e.g. '2222' or None if there is no port in the ssh_config. @@ -692,7 +694,7 @@ def port(self, vm_name=None): """ return self.conf(vm_name=vm_name).get("Port") - def keyfile(self, vm_name=None): + def keyfile(self, vm_name=None) -> Optional[str]: """ Return the path to the private key used to log in to the vagrant box or None if there is no keyfile (IdentityFile) in the ssh_config. @@ -705,7 +707,7 @@ def keyfile(self, vm_name=None): """ return self.conf(vm_name=vm_name).get("IdentityFile") - def user_hostname(self, vm_name=None): + def user_hostname(self, vm_name=None) -> str: """ Return a string combining user and hostname, e.g. 'vagrant@127.0.0.1'. This string is suitable for use in an ssh command. If user is None @@ -716,10 +718,13 @@ def user_hostname(self, vm_name=None): has been destroyed. """ user = self.user(vm_name=vm_name) + hostname = self.hostname(vm_name=vm_name) + if hostname is None: + raise ValueError("Missing hostname for vm_name={vm_name!r}") user_prefix = user + "@" if user else "" - return user_prefix + self.hostname(vm_name=vm_name) + return user_prefix + hostname - def user_hostname_port(self, vm_name=None): + def user_hostname_port(self, vm_name=None) -> str: """ Return a string combining user, hostname and port, e.g. 'vagrant@127.0.0.1:2222'. This string is suitable for use with Fabric, @@ -733,11 +738,14 @@ def user_hostname_port(self, vm_name=None): """ user = self.user(vm_name=vm_name) port = self.port(vm_name=vm_name) + hostname = self.hostname(vm_name=vm_name) + if hostname is None: + raise ValueError("Missing hostname for vm_name={vm_name!r}") user_prefix = user + "@" if user else "" port_suffix = ":" + port if port else "" - return user_prefix + self.hostname(vm_name=vm_name) + port_suffix + return user_prefix + hostname + port_suffix - def box_add(self, name, url, provider=None, force=False): + def box_add(self, name, url, provider=None, force=False) -> None: """ Adds a box with given name, from given url. @@ -750,7 +758,7 @@ def box_add(self, name, url, provider=None, force=False): self._call_vagrant_command(cmd) - def box_list(self): + def box_list(self) -> List[Box]: """ Run `vagrant box list --machine-readable` and return a list of Box objects containing the results. A Box object has the following @@ -847,7 +855,7 @@ def snapshot_delete(self, name): """ self._call_vagrant_command(["snapshot", "delete", name]) - def ssh(self, vm_name=None, command=None, extra_ssh_args=None): + def ssh(self, vm_name=None, command=None, extra_ssh_args=None) -> str: """ Execute a command via ssh on the vm specified. command: The command to execute via ssh. @@ -860,7 +868,7 @@ def ssh(self, vm_name=None, command=None, extra_ssh_args=None): return self._run_vagrant_command(cmd) - def _parse_box_list(self, output): + def _parse_box_list(self, output) -> List[Box]: """ Remove Vagrant usage for unit testing """ @@ -891,14 +899,14 @@ def _parse_box_list(self, output): return boxes - def box_update(self, name, provider): + def box_update(self, name, provider) -> None: """ Updates the box matching name and provider. It is an error if no box matches name and provider. """ self._call_vagrant_command(["box", "update", name, provider]) - def box_remove(self, name, provider): + def box_remove(self, name, provider) -> None: """ Removes the box matching name and provider. It is an error if no box matches name and provider. @@ -941,7 +949,7 @@ def plugin_list(self): output = self._run_vagrant_command(["plugin", "list", "--machine-readable"]) return self._parse_plugin_list(output) - def validate(self, directory): + def validate(self, directory) -> subprocess.CompletedProcess: """ This command validates present Vagrantfile. """ @@ -955,7 +963,7 @@ def validate(self, directory): return validate - def _parse_plugin_list(self, output): + def _parse_plugin_list(self, output) -> List[Plugin]: """ Remove Vagrant from the equation for unit testing. """ @@ -992,7 +1000,7 @@ def _parse_plugin_list(self, output): return plugins - def _parse_machine_readable_output(self, output): + def _parse_machine_readable_output(self, output: str) -> List[List[str]]: """Parse machine readable output from vagrant commands. param output: a string containing the output of a vagrant command with the `--machine-readable` option. @@ -1023,7 +1031,7 @@ def _parse_machine_readable_output(self, output): parsed_lines = list(filter(lambda x: x[2] not in unneeded_kind, parsed_lines)) return parsed_lines - def _parse_config(self, ssh_config): + def _parse_config(self, ssh_config: str) -> Dict[str, str]: r""" This lame parser does not parse the full grammar of an ssh config file. It makes assumptions that are (hopefully) correct for the output @@ -1058,7 +1066,7 @@ def _parse_config(self, ssh_config): conf[key] = value.strip('"') return conf - def _make_vagrant_command(self, args): + def _make_vagrant_command(self, args: List[Union[str, None]]) -> List[str]: if self._vagrant_exe is None: self._vagrant_exe = get_vagrant_executable() @@ -1070,7 +1078,7 @@ def _make_vagrant_command(self, args): # when it is not specified. return [self._vagrant_exe] + [arg for arg in args if arg is not None] - def _call_vagrant_command(self, args): + def _call_vagrant_command(self, args) -> None: """ Run a vagrant command. Return None. args: A sequence of arguments to a vagrant command line. @@ -1083,7 +1091,7 @@ def _call_vagrant_command(self, args): command, cwd=self.root, stdout=out_fh, stderr=err_fh, env=self.env ) - def _run_vagrant_command(self, args): + def _run_vagrant_command(self, args) -> str: """ Run a vagrant command and return its stdout. args: A sequence of arguments to a vagrant command line. @@ -1099,7 +1107,7 @@ def _run_vagrant_command(self, args): ) ) - def _stream_vagrant_command(self, args): + def _stream_vagrant_command(self, args) -> Iterator[str]: """ Execute a vagrant command, returning a generator of the output lines. Caller should consume the entire generator to avoid the hanging the @@ -1123,8 +1131,9 @@ def _stream_vagrant_command(self, args): # Iterate over output lines. # See http://stackoverflow.com/questions/2715847/python-read-streaming-input-from-subprocess-communicate#17698359 with subprocess.Popen(**sp_args) as p: - with p.stdout: - for line in iter(p.stdout.readline, b""): + stdout = typing.cast(typing.IO, p.stdout) + with stdout: + for line in iter(stdout.readline, b""): yield compat.decode(line) # if PY3 decode bytestrings p.wait() # Raise CalledProcessError for consistency with _call_vagrant_command @@ -1137,22 +1146,22 @@ class SandboxVagrant(Vagrant): Support for sandbox mode using the Sahara gem (https://github.com/jedi4ever/sahara). """ - def _run_sandbox_command(self, args): + def _run_sandbox_command(self, args) -> str: return self._run_vagrant_command(["sandbox"] + list(args)) - def sandbox_commit(self, vm_name=None): + def sandbox_commit(self, vm_name=None) -> None: """ Permanently writes all the changes made to the VM. """ self._run_sandbox_command(["commit", vm_name]) - def sandbox_off(self, vm_name=None): + def sandbox_off(self, vm_name=None) -> None: """ Disables the sandbox mode. """ self._run_sandbox_command(["off", vm_name]) - def sandbox_on(self, vm_name=None): + def sandbox_on(self, vm_name=None) -> None: """ Enables the sandbox mode. @@ -1161,13 +1170,13 @@ def sandbox_on(self, vm_name=None): """ self._run_sandbox_command(["on", vm_name]) - def sandbox_rollback(self, vm_name=None): + def sandbox_rollback(self, vm_name=None) -> None: """ Reverts all the changes made to the VM since the last commit. """ self._run_sandbox_command(["rollback", vm_name]) - def sandbox_status(self, vm_name=None): + def sandbox_status(self, vm_name=None) -> str: """ Returns the status of the sandbox mode. @@ -1180,7 +1189,7 @@ def sandbox_status(self, vm_name=None): vagrant_sandbox_output = self._run_sandbox_command(["status", vm_name]) return self._parse_vagrant_sandbox_status(vagrant_sandbox_output) - def _parse_vagrant_sandbox_status(self, vagrant_output): + def _parse_vagrant_sandbox_status(self, vagrant_output) -> str: """ Returns the status of the sandbox mode given output from 'vagrant sandbox status'. diff --git a/src/vagrant/compat.py b/src/vagrant/compat.py index 22f6f17..f520acf 100644 --- a/src/vagrant/compat.py +++ b/src/vagrant/compat.py @@ -13,7 +13,7 @@ PY2 = sys.version_info[0] == 2 -def decode(value): +def decode(value: bytes) -> str: """Decode binary data to text if needed (for Python 3). Use with the functions that return in Python 2 value of `str` type and for Python 3 encoded bytes. @@ -21,4 +21,4 @@ def decode(value): :param value: Encoded bytes for Python 3 and `str` for Python 2. :return: Value as a text. """ - return value.decode(locale.getpreferredencoding()) if not PY2 else value + return value.decode(locale.getpreferredencoding()) if not PY2 else value # type: ignore diff --git a/src/vagrant/test.py b/src/vagrant/test.py index 2287144..457319d 100644 --- a/src/vagrant/test.py +++ b/src/vagrant/test.py @@ -4,7 +4,7 @@ It also removes some of the boilerplate involved in writing tests that leverage vagrant boxes. """ -from typing import Dict, List +from typing import Dict, List, Optional from unittest import TestCase from vagrant import Vagrant, stderr_cm @@ -22,7 +22,7 @@ class VagrantTestCase(TestCase): """ vagrant_boxes: List[str] = [] - vagrant_root = None + vagrant_root: Optional[str] = None restart_boxes = False __initial_box_statuses: Dict[str, str] = {} @@ -43,7 +43,7 @@ def __init__(self, *args, **kwargs): self.vagrant_boxes = boxes super().__init__(*args, **kwargs) - def assertBoxStatus(self, box, status): + def assertBoxStatus(self, box: str, status: str) -> None: """Assertion for a box status""" box_status = [s.state for s in self.vagrant.status() if s.name == box][0] if box_status != status: @@ -51,19 +51,19 @@ def assertBoxStatus(self, box, status): "{} has status {}, not {}".format(box, box_status, status) ) - def assertBoxUp(self, box): + def assertBoxUp(self, box: str) -> None: """Assertion for a box being up""" self.assertBoxStatus(box, Vagrant.RUNNING) - def assertBoxSuspended(self, box): + def assertBoxSuspended(self, box: str) -> None: """Assertion for a box being up""" self.assertBoxStatus(box, Vagrant.SAVED) - def assertBoxHalted(self, box): + def assertBoxHalted(self, box: str) -> None: """Assertion for a box being up""" self.assertBoxStatus(box, Vagrant.POWEROFF) - def assertBoxNotCreated(self, box): + def assertBoxNotCreated(self, box: str) -> None: """Assertion for a box being up""" self.assertBoxStatus(box, Vagrant.NOT_CREATED) @@ -74,18 +74,18 @@ def run(self, result=None): self.tearDownOnce() return run - def setUpOnce(self): + def setUpOnce(self) -> None: """Collect the box states before starting""" for box_name in self.vagrant_boxes: s = self.vagrant.status(vm_name=box_name)[0] self.__initial_box_statuses[box_name] = s.state - def tearDownOnce(self): + def tearDownOnce(self) -> None: """Restore all boxes to their initial states after running all tests, unless tearDown handled it already""" if not self.restart_boxes: self.restore_box_states() - def restore_box_states(self): + def restore_box_states(self) -> None: """Restores all boxes to their original states""" for box_name in self.vagrant_boxes: action = self.__cleanup_actions.get(self.__initial_box_statuses[box_name]) diff --git a/tests/test_vagrant.py b/tests/test_vagrant.py index d35089d..f35d2b6 100644 --- a/tests/test_vagrant.py +++ b/tests/test_vagrant.py @@ -22,7 +22,8 @@ import sys import tempfile import time -from typing import Generator +import typing +from typing import Generator, List, Optional import pytest @@ -46,7 +47,7 @@ def get_provider() -> str: # location of Vagrant executable -VAGRANT_EXE = vagrant.get_vagrant_executable() +VAGRANT_EXE = typing.cast(str, vagrant.get_vagrant_executable()) # location of a test file on the created box by provisioning in vm_Vagrantfile TEST_FILE_PATH = "/home/vagrant/python_vagrant_test_file" @@ -106,7 +107,7 @@ def fixture_test_dir() -> Generator[str, None, None]: shutil.rmtree(my_dir) -def list_box_names(): +def list_box_names() -> List[str]: """ Return a list of the currently installed vagrant box names. This is implemented outside of `vagrant.Vagrant`, so that it will still work @@ -714,7 +715,7 @@ def test_streaming_output(vm_dir): v.up(vm_name="incorrect-name") streaming_up = False - for line in v.up(stream_output=True): + for line in v.up(stream_output=True): # type: ignore print("output line:", line) if test_string in line: streaming_up = True @@ -722,7 +723,7 @@ def test_streaming_output(vm_dir): assert streaming_up streaming_reload = False - for line in v.reload(stream_output=True): + for line in v.reload(stream_output=True): # type: ignore print("output line:", line) if test_string in line: streaming_reload = True @@ -759,7 +760,7 @@ def test_vagrant_version(): assert version_result is True -def _execute_command_in_vm(v, command): +def _execute_command_in_vm(v, command) -> str: """ Run command via ssh on the test vagrant box. Returns a tuple of the return code and output of the command. @@ -769,7 +770,7 @@ def _execute_command_in_vm(v, command): return compat.decode(subprocess.check_output(ssh_command, cwd=v.root)) -def _write_test_file(v, file_contents): +def _write_test_file(v, file_contents) -> None: """ Writes given contents to the test file. """ @@ -777,7 +778,7 @@ def _write_test_file(v, file_contents): _execute_command_in_vm(v, command) -def _read_test_file(v): +def _read_test_file(v) -> Optional[str]: """ Returns the contents of the test file stored in the VM or None if there is no file. @@ -790,6 +791,6 @@ def _read_test_file(v): return None -def _plugin_installed(v, plugin_name): +def _plugin_installed(v, plugin_name) -> bool: plugins = v.plugin_list() return plugin_name in [plugin.name for plugin in plugins] diff --git a/tests/test_vagrant_test_case.py b/tests/test_vagrant_test_case.py index 7519db6..43945ee 100644 --- a/tests/test_vagrant_test_case.py +++ b/tests/test_vagrant_test_case.py @@ -9,7 +9,7 @@ from .test_vagrant import TEST_BOX_NAME -def get_vagrant_root(test_vagrant_root_path): +def get_vagrant_root(test_vagrant_root_path) -> str: return ( os.path.dirname(os.path.realpath(__file__)) + "/vagrantfiles/"