Skip to content

Commit c465807

Browse files
committed
Fix sys.stdout overriding in mypy.api (#6125)
Overriding sys.stdout and sys.stderr in mypy.api is not threadsafe. This causes problems sometimes when using the api in pyls for example.
1 parent fa41d19 commit c465807

File tree

2 files changed

+52
-43
lines changed

2 files changed

+52
-43
lines changed

mypy/api.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,36 +44,31 @@
4444

4545
import sys
4646
from io import StringIO
47-
from typing import List, Tuple, Callable
47+
from typing import List, Tuple, Union, TextIO, Callable
48+
from mypy_extensions import DefaultArg
4849

4950

50-
def _run(f: Callable[[], None]) -> Tuple[str, str, int]:
51-
old_stdout = sys.stdout
52-
new_stdout = StringIO()
53-
sys.stdout = new_stdout
51+
def _run(f: Callable[[TextIO, TextIO], None]) -> Tuple[str, str, int]:
5452

55-
old_stderr = sys.stderr
56-
new_stderr = StringIO()
57-
sys.stderr = new_stderr
53+
stdout = StringIO()
54+
stderr = StringIO()
5855

5956
try:
60-
f()
57+
f(stdout, stderr)
6158
exit_status = 0
6259
except SystemExit as system_exit:
6360
exit_status = system_exit.code
64-
finally:
65-
sys.stdout = old_stdout
66-
sys.stderr = old_stderr
6761

68-
return new_stdout.getvalue(), new_stderr.getvalue(), exit_status
62+
return stdout.getvalue(), stderr.getvalue(), exit_status
6963

7064

7165
def run(args: List[str]) -> Tuple[str, str, int]:
7266
# Lazy import to avoid needing to import all of mypy to call run_dmypy
7367
from mypy.main import main
74-
return _run(lambda: main(None, args=args))
68+
return _run(lambda stdout, stderr: main(None, args=args,
69+
stdout=stdout, stderr=stderr))
7570

7671

7772
def run_dmypy(args: List[str]) -> Tuple[str, str, int]:
7873
from mypy.dmypy import main
79-
return _run(lambda: main(args))
74+
return _run(lambda stdout, stderr: main(args))

mypy/main.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import sys
1010
import time
1111

12-
from typing import Any, Dict, List, Mapping, Optional, Tuple
12+
from typing import Any, Dict, List, Mapping, Optional, Tuple, TextIO
13+
from io import StringIO
1314

1415
from mypy import build
1516
from mypy import defaults
@@ -46,7 +47,9 @@ def stat_proxy(path: str) -> os.stat_result:
4647
return st
4748

4849

49-
def main(script_path: Optional[str], args: Optional[List[str]] = None) -> None:
50+
def main(script_path: Optional[str], args: Optional[List[str]] = None,
51+
stdout: TextIO = sys.stdout,
52+
stderr: TextIO = sys.stderr) -> None:
5053
"""Main entry point to the type checker.
5154
5255
Args:
@@ -69,13 +72,14 @@ def main(script_path: Optional[str], args: Optional[List[str]] = None) -> None:
6972
args = sys.argv[1:]
7073

7174
fscache = FileSystemCache()
72-
sources, options = process_options(args, fscache=fscache)
75+
sources, options = process_options(args, stdout=stdout, stderr=stderr,
76+
fscache=fscache)
7377

7478
messages = []
7579

7680
def flush_errors(new_messages: List[str], serious: bool) -> None:
7781
messages.extend(new_messages)
78-
f = sys.stderr if serious else sys.stdout
82+
f = stderr if serious else stdout
7983
try:
8084
for msg in new_messages:
8185
f.write(msg + '\n')
@@ -99,7 +103,7 @@ def flush_errors(new_messages: List[str], serious: bool) -> None:
99103
(options.config_file,
100104
", ".join("[mypy-%s]" % glob for glob in options.per_module_options.keys()
101105
if glob in options.unused_configs)),
102-
file=sys.stderr)
106+
file=stderr)
103107
if options.junit_xml:
104108
t1 = time.time()
105109
util.write_junit_xml(t1 - t0, serious, messages, options.junit_xml)
@@ -295,6 +299,8 @@ def infer_python_executable(options: Options,
295299

296300

297301
def process_options(args: List[str],
302+
stdout: TextIO = sys.stdout,
303+
stderr: TextIO = sys.stderr,
298304
require_targets: bool = True,
299305
server_options: bool = False,
300306
fscache: Optional[FileSystemCache] = None,
@@ -703,7 +709,7 @@ def add_invertible_flag(flag: str,
703709

704710
# Parse config file first, so command line can override.
705711
options = Options()
706-
parse_config_file(options, config_file)
712+
parse_config_file(options, config_file, stdout, stderr)
707713

708714
# Set strict flags before parsing (if strict mode enabled), so other command
709715
# line options can override.
@@ -785,10 +791,11 @@ def add_invertible_flag(flag: str,
785791
cache = FindModuleCache(search_paths, fscache)
786792
for p in special_opts.packages:
787793
if os.sep in p or os.altsep and os.altsep in p:
788-
fail("Package name '{}' cannot have a slash in it.".format(p))
794+
fail("Package name '{}' cannot have a slash in it.".format(p),
795+
stderr)
789796
p_targets = cache.find_modules_recursive(p)
790797
if not p_targets:
791-
fail("Can't find package '{}'".format(p))
798+
fail("Can't find package '{}'".format(p), stderr)
792799
targets.extend(p_targets)
793800
for m in special_opts.modules:
794801
targets.append(BuildSource(None, m, None))
@@ -801,7 +808,7 @@ def add_invertible_flag(flag: str,
801808
try:
802809
targets = create_source_list(special_opts.files, options, fscache)
803810
except InvalidSourceList as e:
804-
fail(str(e))
811+
fail(str(e), stderr)
805812
return targets, options
806813

807814

@@ -883,7 +890,9 @@ def process_cache_map(parser: argparse.ArgumentParser,
883890
} # type: Final
884891

885892

886-
def parse_config_file(options: Options, filename: Optional[str]) -> None:
893+
def parse_config_file(options: Options, filename: Optional[str],
894+
stdout: TextIO = sys.stdout,
895+
stderr: TextIO = sys.stderr) -> None:
887896
"""Parse a config file into an Options object.
888897
889898
Errors are written to stderr but are not fatal.
@@ -903,7 +912,7 @@ def parse_config_file(options: Options, filename: Optional[str]) -> None:
903912
try:
904913
parser.read(config_file)
905914
except configparser.Error as err:
906-
print("%s: %s" % (config_file, err), file=sys.stderr)
915+
print("%s: %s" % (config_file, err), file=stderr)
907916
else:
908917
file_read = config_file
909918
options.config_file = file_read
@@ -913,27 +922,29 @@ def parse_config_file(options: Options, filename: Optional[str]) -> None:
913922

914923
if 'mypy' not in parser:
915924
if filename or file_read not in defaults.SHARED_CONFIG_FILES:
916-
print("%s: No [mypy] section in config file" % file_read, file=sys.stderr)
925+
print("%s: No [mypy] section in config file" % file_read, file=stderr)
917926
else:
918927
section = parser['mypy']
919928
prefix = '%s: [%s]' % (file_read, 'mypy')
920-
updates, report_dirs = parse_section(prefix, options, section)
929+
updates, report_dirs = parse_section(prefix, options, section,
930+
stdout, stderr)
921931
for k, v in updates.items():
922932
setattr(options, k, v)
923933
options.report_dirs.update(report_dirs)
924934

925935
for name, section in parser.items():
926936
if name.startswith('mypy-'):
927937
prefix = '%s: [%s]' % (file_read, name)
928-
updates, report_dirs = parse_section(prefix, options, section)
938+
updates, report_dirs = parse_section(prefix, options, section,
939+
stdout, stderr)
929940
if report_dirs:
930941
print("%s: Per-module sections should not specify reports (%s)" %
931942
(prefix, ', '.join(s + '_report' for s in sorted(report_dirs))),
932-
file=sys.stderr)
943+
file=stderr)
933944
if set(updates) - PER_MODULE_OPTIONS:
934945
print("%s: Per-module sections should only specify per-module flags (%s)" %
935946
(prefix, ', '.join(sorted(set(updates) - PER_MODULE_OPTIONS))),
936-
file=sys.stderr)
947+
file=stderr)
937948
updates = {k: v for k, v in updates.items() if k in PER_MODULE_OPTIONS}
938949
globs = name[5:]
939950
for glob in globs.split(','):
@@ -947,13 +958,16 @@ def parse_config_file(options: Options, filename: Optional[str]) -> None:
947958
print("%s: Patterns must be fully-qualified module names, optionally "
948959
"with '*' in some components (e.g spam.*.eggs.*)"
949960
% prefix,
950-
file=sys.stderr)
961+
file=stderr)
951962
else:
952963
options.per_module_options[glob] = updates
953964

954965

955966
def parse_section(prefix: str, template: Options,
956-
section: Mapping[str, str]) -> Tuple[Dict[str, object], Dict[str, str]]:
967+
section: Mapping[str, str],
968+
stdout: TextIO = sys.stdout,
969+
stderr: TextIO = sys.stderr
970+
) -> Tuple[Dict[str, object], Dict[str, str]]:
957971
"""Parse one section of a config file.
958972
959973
Returns a dict of option values encountered, and a dict of report directories.
@@ -972,17 +986,17 @@ def parse_section(prefix: str, template: Options,
972986
report_dirs[report_type] = section[key]
973987
else:
974988
print("%s: Unrecognized report type: %s" % (prefix, key),
975-
file=sys.stderr)
989+
file=stderr)
976990
continue
977991
if key.startswith('x_'):
978992
continue # Don't complain about `x_blah` flags
979993
elif key == 'strict':
980994
print("%s: Strict mode is not supported in configuration files: specify "
981995
"individual flags instead (see 'mypy -h' for the list of flags enabled "
982-
"in strict mode)" % prefix, file=sys.stderr)
996+
"in strict mode)" % prefix, file=stderr)
983997
else:
984998
print("%s: Unrecognized option: %s = %s" % (prefix, key, section[key]),
985-
file=sys.stderr)
999+
file=stderr)
9861000
continue
9871001
ct = type(dv)
9881002
v = None # type: Any
@@ -993,32 +1007,32 @@ def parse_section(prefix: str, template: Options,
9931007
try:
9941008
v = ct(section.get(key))
9951009
except argparse.ArgumentTypeError as err:
996-
print("%s: %s: %s" % (prefix, key, err), file=sys.stderr)
1010+
print("%s: %s: %s" % (prefix, key, err), file=stderr)
9971011
continue
9981012
else:
999-
print("%s: Don't know what type %s should have" % (prefix, key), file=sys.stderr)
1013+
print("%s: Don't know what type %s should have" % (prefix, key), file=stderr)
10001014
continue
10011015
except ValueError as err:
1002-
print("%s: %s: %s" % (prefix, key, err), file=sys.stderr)
1016+
print("%s: %s: %s" % (prefix, key, err), file=stderr)
10031017
continue
10041018
if key == 'silent_imports':
10051019
print("%s: silent_imports has been replaced by "
1006-
"ignore_missing_imports=True; follow_imports=skip" % prefix, file=sys.stderr)
1020+
"ignore_missing_imports=True; follow_imports=skip" % prefix, file=stderr)
10071021
if v:
10081022
if 'ignore_missing_imports' not in results:
10091023
results['ignore_missing_imports'] = True
10101024
if 'follow_imports' not in results:
10111025
results['follow_imports'] = 'skip'
10121026
if key == 'almost_silent':
10131027
print("%s: almost_silent has been replaced by "
1014-
"follow_imports=error" % prefix, file=sys.stderr)
1028+
"follow_imports=error" % prefix, file=stderr)
10151029
if v:
10161030
if 'follow_imports' not in results:
10171031
results['follow_imports'] = 'error'
10181032
results[key] = v
10191033
return results, report_dirs
10201034

10211035

1022-
def fail(msg: str) -> None:
1023-
sys.stderr.write('%s\n' % msg)
1036+
def fail(msg: str, stderr: TextIO) -> None:
1037+
stderr.write('%s\n' % msg)
10241038
sys.exit(1)

0 commit comments

Comments
 (0)