diff --git a/changelog/6566.bugfix.rst b/changelog/6566.bugfix.rst new file mode 100644 index 00000000000..4af976f2268 --- /dev/null +++ b/changelog/6566.bugfix.rst @@ -0,0 +1 @@ +Fix ``EncodedFile.writelines`` to call the underlying buffer's ``writelines`` method. diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index c79bfeef024..ccbeb0884e0 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -9,6 +9,8 @@ import sys from io import UnsupportedOperation from tempfile import TemporaryFile +from typing import BinaryIO +from typing import Iterable import pytest from _pytest.compat import CaptureIO @@ -413,30 +415,27 @@ def safe_text_dupfile(f, mode, default_encoding="UTF8"): class EncodedFile: errors = "strict" # possibly needed by py3 code (issue555) - def __init__(self, buffer, encoding): + def __init__(self, buffer: BinaryIO, encoding: str) -> None: self.buffer = buffer self.encoding = encoding - def write(self, obj): - if isinstance(obj, str): - obj = obj.encode(self.encoding, "replace") - else: + def write(self, s: str) -> int: + if not isinstance(s, str): raise TypeError( - "write() argument must be str, not {}".format(type(obj).__name__) + "write() argument must be str, not {}".format(type(s).__name__) ) - return self.buffer.write(obj) + return self.buffer.write(s.encode(self.encoding, "replace")) - def writelines(self, linelist): - data = "".join(linelist) - self.write(data) + def writelines(self, lines: Iterable[str]) -> None: + self.buffer.writelines(x.encode(self.encoding, "replace") for x in lines) @property - def name(self): + def name(self) -> str: """Ensure that file.name is a string.""" return repr(self.buffer) @property - def mode(self): + def mode(self) -> str: return self.buffer.mode.replace("b", "") def __getattr__(self, name): diff --git a/testing/test_capture.py b/testing/test_capture.py index 7d459e91c75..9261c8441ed 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -7,6 +7,8 @@ import textwrap from io import StringIO from io import UnsupportedOperation +from typing import BinaryIO +from typing import Generator from typing import List from typing import TextIO @@ -831,7 +833,7 @@ def test_dontreadfrominput(): @pytest.fixture -def tmpfile(testdir): +def tmpfile(testdir) -> Generator[BinaryIO, None, None]: f = testdir.makepyfile("").open("wb+") yield f if not f.closed: @@ -1497,3 +1499,15 @@ def test_fails(): def test_stderr_write_returns_len(capsys): """Write on Encoded files, namely captured stderr, should return number of characters written.""" assert sys.stderr.write("Foo") == 3 + + +def test_encodedfile_writelines(tmpfile: BinaryIO) -> None: + ef = capture.EncodedFile(tmpfile, "utf-8") + with pytest.raises(AttributeError): + ef.writelines([b"line1", b"line2"]) # type: ignore[list-item] # noqa: F821 + assert ef.writelines(["line1", "line2"]) is None # type: ignore[func-returns-value] # noqa: F821 + tmpfile.seek(0) + assert tmpfile.read() == b"line1line2" + tmpfile.close() + with pytest.raises(ValueError): + ef.read()