From 2e29be99f0f9355bcf864d1f48e81759b248c8af Mon Sep 17 00:00:00 2001 From: Daniel Hahler Date: Tue, 4 Feb 2020 09:49:11 +0100 Subject: [PATCH] Refactor Capture classes --- src/_pytest/capture.py | 66 +++++++++++++++++++++++++++-------------- src/_pytest/compat.py | 2 +- testing/test_capture.py | 6 ++-- 3 files changed, 47 insertions(+), 27 deletions(-) diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index fbba0ecb5a7..040dd7afb5d 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -13,13 +13,18 @@ from typing import Generator from typing import Iterable from typing import Optional +from typing import Union import pytest -from _pytest.compat import CaptureAndPassthroughIO from _pytest.compat import CaptureIO +from _pytest.compat import PassthroughCaptureIO +from _pytest.compat import TYPE_CHECKING from _pytest.config import Config from _pytest.fixtures import FixtureRequest +if TYPE_CHECKING: + from typing import Type + patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"} @@ -89,16 +94,31 @@ def __repr__(self): self._method, self._global_capturing, self._capture_fixture ) - def _getcapture(self, method): - if method == "fd": - return MultiCapture(out=True, err=True, Capture=FDCapture) - elif method == "sys": - return MultiCapture(out=True, err=True, Capture=SysCapture) - elif method == "no": + def _getcapture(self, method: str) -> "MultiCapture": + modes = method.split("-") + + if "no" in modes: + if len(modes) > 1: + raise ValueError("'no' cannot be combined with other modes") return MultiCapture(out=False, err=False, in_=False) - elif method == "tee-sys": - return MultiCapture(out=True, err=True, in_=False, Capture=TeeSysCapture) - raise ValueError("unknown capturing method: %r" % method) # pragma: no cover + + if "tee" in modes: + if "sys" not in modes: + raise ValueError("'tee' only works with 'sys'") + + capture = TeeSysCapture # type: Union[Type[Capture]] + elif "fd" in modes: + if "sys" in modes: + raise ValueError("'fd' and 'sys' cannot be combined") + capture = FDCapture + elif "sys" in modes: + if "fd" in modes: + raise ValueError("'sys' and 'fd' cannot be combined") + capture = SysCapture + else: + raise ValueError("unknown capturing method: {}".format(method)) + + return MultiCapture(out=True, err=True, Capture=capture) def is_capturing(self): if self.is_globally_capturing(): @@ -512,7 +532,11 @@ class NoCapture: __init__ = start = done = suspend = resume = lambda *args: None -class FDCaptureBinary: +class Capture: + pass + + +class FDCaptureBinary(Capture): """Capture IO to/from a given os-level filedescriptor. snap() produces `bytes` @@ -613,8 +637,7 @@ def snap(self): return res -class SysCapture: - +class SysCapture(Capture): EMPTY_BUFFER = str() _state = None @@ -664,16 +687,13 @@ def writeorg(self, data): class TeeSysCapture(SysCapture): - def __init__(self, fd, tmpfile=None): - name = patchsysdict[fd] - self._old = getattr(sys, name) - self.name = name - if tmpfile is None: - if name == "stdin": - tmpfile = DontReadFromInput() - else: - tmpfile = CaptureAndPassthroughIO(self._old) - self.tmpfile = tmpfile + def __init__(self, fd: int) -> None: + old = getattr(sys, patchsysdict[fd]) + if fd == 0: + super().__init__(fd) + else: + super().__init__(fd, PassthroughCaptureIO(old)) + assert self._old == old, (self._old, old) class SysCaptureBinary(SysCapture): diff --git a/src/_pytest/compat.py b/src/_pytest/compat.py index 0fc48bdba44..41b0f8959bd 100644 --- a/src/_pytest/compat.py +++ b/src/_pytest/compat.py @@ -380,7 +380,7 @@ def getvalue(self) -> str: return self.buffer.getvalue().decode("UTF-8") -class CaptureAndPassthroughIO(CaptureIO): +class PassthroughCaptureIO(CaptureIO): def __init__(self, other: IO) -> None: self._other = other super().__init__() diff --git a/testing/test_capture.py b/testing/test_capture.py index 1a70cb1a566..5fc52b44a34 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -822,10 +822,10 @@ def test_write_bytes_to_buffer(self): assert f.getvalue() == "foo\r\n" -class TestCaptureAndPassthroughIO(TestCaptureIO): +class TestPassthroughCaptureIO(TestCaptureIO): def test_text(self): sio = io.StringIO() - f = capture.CaptureAndPassthroughIO(sio) + f = capture.PassthroughCaptureIO(sio) f.write("hello") s1 = f.getvalue() assert s1 == "hello" @@ -836,7 +836,7 @@ def test_text(self): def test_unicode_and_str_mixture(self): sio = io.StringIO() - f = capture.CaptureAndPassthroughIO(sio) + f = capture.PassthroughCaptureIO(sio) f.write("\u00f6") pytest.raises(TypeError, f.write, b"hello")