Skip to content

Commit 24f78eb

Browse files
elkhadiyAnOctopus
authored andcommitted
Fix sys.stdout overriding in mypy.api
Overriding sys.stdout and sys.stderr in mypy.api is not threadsafe. This causes problems sometimes when using the api in pyls for example.
1 parent e0b1329 commit 24f78eb

File tree

7 files changed

+107
-82
lines changed

7 files changed

+107
-82
lines changed

mypy/api.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,36 +44,31 @@
4444

4545
import sys
4646
from io import StringIO
47-
from typing import List, Tuple, Callable
47+
from typing import List, Tuple, Union, TextIO, Callable
48+
from mypy_extensions import DefaultArg
4849

4950

50-
def _run(f: Callable[[], None]) -> Tuple[str, str, int]:
51-
old_stdout = sys.stdout
52-
new_stdout = StringIO()
53-
sys.stdout = new_stdout
51+
def _run(f: Callable[[TextIO, TextIO], None]) -> Tuple[str, str, int]:
5452

55-
old_stderr = sys.stderr
56-
new_stderr = StringIO()
57-
sys.stderr = new_stderr
53+
stdout = StringIO()
54+
stderr = StringIO()
5855

5956
try:
60-
f()
57+
f(stdout, stderr)
6158
exit_status = 0
6259
except SystemExit as system_exit:
6360
exit_status = system_exit.code
64-
finally:
65-
sys.stdout = old_stdout
66-
sys.stderr = old_stderr
6761

68-
return new_stdout.getvalue(), new_stderr.getvalue(), exit_status
62+
return stdout.getvalue(), stderr.getvalue(), exit_status
6963

7064

7165
def run(args: List[str]) -> Tuple[str, str, int]:
7266
# Lazy import to avoid needing to import all of mypy to call run_dmypy
7367
from mypy.main import main
74-
return _run(lambda: main(None, args=args))
68+
return _run(lambda stdout, stderr: main(None, args=args,
69+
stdout=stdout, stderr=stderr))
7570

7671

7772
def run_dmypy(args: List[str]) -> Tuple[str, str, int]:
7873
from mypy.dmypy import main
79-
return _run(lambda: main(args))
74+
return _run(lambda stdout, stderr: main(args))

mypy/build.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import types
2626

2727
from typing import (AbstractSet, Any, Dict, Iterable, Iterator, List,
28-
Mapping, NamedTuple, Optional, Set, Tuple, Union, Callable)
28+
Mapping, NamedTuple, Optional, Set, Tuple, Union, Callable, TextIO)
2929
MYPY = False
3030
if MYPY:
3131
from typing import ClassVar
@@ -128,6 +128,8 @@ def build(sources: List[BuildSource],
128128
alt_lib_path: Optional[str] = None,
129129
flush_errors: Optional[Callable[[List[str], bool], None]] = None,
130130
fscache: Optional[FileSystemCache] = None,
131+
stdout: TextIO = sys.stdout,
132+
stderr: TextIO = sys.stderr,
131133
) -> BuildResult:
132134
"""Analyze a program.
133135
@@ -161,7 +163,7 @@ def default_flush_errors(new_messages: List[str], is_serious: bool) -> None:
161163
flush_errors = flush_errors or default_flush_errors
162164

163165
try:
164-
result = _build(sources, options, alt_lib_path, flush_errors, fscache)
166+
result = _build(sources, options, alt_lib_path, flush_errors, fscache, stdout, stderr)
165167
result.errors = messages
166168
return result
167169
except CompileError as e:
@@ -180,6 +182,8 @@ def _build(sources: List[BuildSource],
180182
alt_lib_path: Optional[str],
181183
flush_errors: Callable[[List[str], bool], None],
182184
fscache: Optional[FileSystemCache],
185+
stdout: TextIO,
186+
stderr: TextIO,
183187
) -> BuildResult:
184188
# This seems the most reasonable place to tune garbage collection.
185189
gc.set_threshold(150 * 1000)
@@ -197,7 +201,7 @@ def _build(sources: List[BuildSource],
197201

198202
source_set = BuildSourceSet(sources)
199203
errors = Errors(options.show_error_context, options.show_column_numbers)
200-
plugin, snapshot = load_plugins(options, errors)
204+
plugin, snapshot = load_plugins(options, errors, stdout)
201205

202206
# Construct a build manager object to hold state during the build.
203207
#
@@ -212,12 +216,14 @@ def _build(sources: List[BuildSource],
212216
plugins_snapshot=snapshot,
213217
errors=errors,
214218
flush_errors=flush_errors,
215-
fscache=fscache)
219+
fscache=fscache,
220+
stdout=stdout,
221+
stderr=stderr)
216222
manager.trace(repr(options))
217223

218224
reset_global_state()
219225
try:
220-
graph = dispatch(sources, manager)
226+
graph = dispatch(sources, manager, stdout)
221227
if not options.fine_grained_incremental:
222228
TypeState.reset_all_subtype_caches()
223229
return BuildResult(manager, graph)
@@ -319,7 +325,7 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int:
319325
return toplevel_priority
320326

321327

322-
def load_plugins(options: Options, errors: Errors) -> Tuple[Plugin, Dict[str, str]]:
328+
def load_plugins(options: Options, errors: Errors, stdout: TextIO = sys.stdout) -> Tuple[Plugin, Dict[str, str]]:
323329
"""Load all configured plugins.
324330
325331
Return a plugin that encapsulates all plugins chained together. Always
@@ -383,7 +389,7 @@ def plugin_error(message: str) -> None:
383389
try:
384390
plugin_type = getattr(module, func_name)(__version__)
385391
except Exception:
386-
print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path))
392+
print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path), file=stdout)
387393
raise # Propagate to display traceback
388394

389395
if not isinstance(plugin_type, type):
@@ -398,7 +404,7 @@ def plugin_error(message: str) -> None:
398404
custom_plugins.append(plugin_type(options))
399405
snapshot[module_name] = take_module_snapshot(module)
400406
except Exception:
401-
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__))
407+
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__), file=stdout)
402408
raise # Propagate to display traceback
403409
# Custom plugins take precedence over the default plugin.
404410
return ChainedPlugin(options, custom_plugins + [default_plugin]), snapshot
@@ -496,8 +502,10 @@ def __init__(self, data_dir: str,
496502
errors: Errors,
497503
flush_errors: Callable[[List[str], bool], None],
498504
fscache: FileSystemCache,
505+
stdout: TextIO,
506+
stderr: TextIO,
499507
) -> None:
500-
super().__init__()
508+
super().__init__(stdout, stderr)
501509
self.start_time = time.time()
502510
self.data_dir = data_dir
503511
self.errors = errors
@@ -558,7 +566,7 @@ def __init__(self, data_dir: str,
558566
self.plugin = plugin
559567
self.plugins_snapshot = plugins_snapshot
560568
self.old_plugins_snapshot = read_plugins_snapshot(self)
561-
self.quickstart_state = read_quickstart_file(options)
569+
self.quickstart_state = read_quickstart_file(options, self.stdout)
562570

563571
def dump_stats(self) -> None:
564572
self.log("Stats:")
@@ -904,7 +912,7 @@ def read_plugins_snapshot(manager: BuildManager) -> Optional[Dict[str, str]]:
904912
return snapshot
905913

906914

907-
def read_quickstart_file(options: Options) -> Optional[Dict[str, Tuple[float, int, str]]]:
915+
def read_quickstart_file(options: Options, stdout: TextIO = sys.stdout) -> Optional[Dict[str, Tuple[float, int, str]]]:
908916
quickstart = None # type: Optional[Dict[str, Tuple[float, int, str]]]
909917
if options.quickstart_file:
910918
# This is very "best effort". If the file is missing or malformed,
@@ -918,7 +926,7 @@ def read_quickstart_file(options: Options) -> Optional[Dict[str, Tuple[float, in
918926
for file, (x, y, z) in raw_quickstart.items():
919927
quickstart[file] = (x, y, z)
920928
except Exception as e:
921-
print("Warning: Failed to load quickstart file: {}\n".format(str(e)))
929+
print("Warning: Failed to load quickstart file: {}\n".format(str(e)), file=stdout)
922930
return quickstart
923931

924932

@@ -1769,7 +1777,7 @@ def wrap_context(self, check_blockers: bool = True) -> Iterator[None]:
17691777
except CompileError:
17701778
raise
17711779
except Exception as err:
1772-
report_internal_error(err, self.path, 0, self.manager.errors, self.options)
1780+
report_internal_error(err, self.path, 0, self.manager.errors, self.options, self.manager.stdout, self.manager.stderr)
17731781
self.manager.errors.set_import_context(save_import_context)
17741782
# TODO: Move this away once we've removed the old semantic analyzer?
17751783
if check_blockers:
@@ -2429,7 +2437,7 @@ def log_configuration(manager: BuildManager) -> None:
24292437
# The driver
24302438

24312439

2432-
def dispatch(sources: List[BuildSource], manager: BuildManager) -> Graph:
2440+
def dispatch(sources: List[BuildSource], manager: BuildManager, stdout: TextIO = sys.stdout) -> Graph:
24332441
log_configuration(manager)
24342442

24352443
t0 = time.time()
@@ -2454,11 +2462,11 @@ def dispatch(sources: List[BuildSource], manager: BuildManager) -> Graph:
24542462
fm_cache_size=len(manager.find_module_cache.results),
24552463
)
24562464
if not graph:
2457-
print("Nothing to do?!")
2465+
print("Nothing to do?!", file=stdout)
24582466
return graph
24592467
manager.log("Loaded graph with %d nodes (%.3f sec)" % (len(graph), t1 - t0))
24602468
if manager.options.dump_graph:
2461-
dump_graph(graph)
2469+
dump_graph(graph, stdout)
24622470
return graph
24632471

24642472
# Fine grained dependencies that didn't have an associated module in the build
@@ -2480,7 +2488,7 @@ def dispatch(sources: List[BuildSource], manager: BuildManager) -> Graph:
24802488
manager.log("Error reading fine-grained dependencies cache -- aborting cache load")
24812489
manager.cache_enabled = False
24822490
manager.log("Falling back to full run -- reloading graph...")
2483-
return dispatch(sources, manager)
2491+
return dispatch(sources, manager, stdout)
24842492

24852493
# If we are loading a fine-grained incremental mode cache, we
24862494
# don't want to do a real incremental reprocess of the
@@ -2528,7 +2536,7 @@ def dumps(self) -> str:
25282536
json.dumps(self.deps))
25292537

25302538

2531-
def dump_graph(graph: Graph) -> None:
2539+
def dump_graph(graph: Graph, stdout: TextIO = sys.stdout) -> None:
25322540
"""Dump the graph as a JSON string to stdout.
25332541
25342542
This copies some of the work by process_graph()
@@ -2562,7 +2570,7 @@ def dump_graph(graph: Graph) -> None:
25622570
if (dep_id != node.node_id and
25632571
(dep_id not in node.deps or pri < node.deps[dep_id])):
25642572
node.deps[dep_id] = pri
2565-
print("[" + ",\n ".join(node.dumps() for node in nodes) + "\n]")
2573+
print("[" + ",\n ".join(node.dumps() for node in nodes) + "\n]", file=stdout)
25662574

25672575

25682576
def load_graph(sources: List[BuildSource], manager: BuildManager,

mypy/errors.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import traceback
44
from collections import OrderedDict, defaultdict
55

6-
from typing import Tuple, List, TypeVar, Set, Dict, Optional
6+
from typing import Tuple, List, TypeVar, Set, Dict, Optional, TextIO
77

88
from mypy.scope import Scope
99
from mypy.options import Options
@@ -586,7 +586,7 @@ def remove_path_prefix(path: str, prefix: Optional[str]) -> str:
586586

587587

588588
def report_internal_error(err: Exception, file: Optional[str], line: int,
589-
errors: Errors, options: Options) -> None:
589+
errors: Errors, options: Options, stdout: TextIO = sys.stdout, stderr: TextIO = sys.stderr) -> None:
590590
"""Report internal error and exit.
591591
592592
This optionally starts pdb or shows a traceback.
@@ -597,7 +597,7 @@ def report_internal_error(err: Exception, file: Optional[str], line: int,
597597
for msg in errors.new_messages():
598598
print(msg)
599599
except Exception as e:
600-
print("Failed to dump errors:", repr(e), file=sys.stderr)
600+
print("Failed to dump errors:", repr(e), file=stderr)
601601

602602
# Compute file:line prefix for official-looking error messages.
603603
if file:
@@ -612,11 +612,11 @@ def report_internal_error(err: Exception, file: Optional[str], line: int,
612612
print('{}error: INTERNAL ERROR --'.format(prefix),
613613
'please report a bug at https://github.com/python/mypy/issues',
614614
'version: {}'.format(mypy_version),
615-
file=sys.stderr)
615+
file=stderr)
616616

617617
# If requested, drop into pdb. This overrides show_tb.
618618
if options.pdb:
619-
print('Dropping into pdb', file=sys.stderr)
619+
print('Dropping into pdb', file=stderr)
620620
import pdb
621621
pdb.post_mortem(sys.exc_info()[2])
622622

@@ -627,15 +627,15 @@ def report_internal_error(err: Exception, file: Optional[str], line: int,
627627
if not options.pdb:
628628
print('{}: note: please use --show-traceback to print a traceback '
629629
'when reporting a bug'.format(prefix),
630-
file=sys.stderr)
630+
file=stderr)
631631
else:
632632
tb = traceback.extract_stack()[:-2]
633633
tb2 = traceback.extract_tb(sys.exc_info()[2])
634634
print('Traceback (most recent call last):')
635635
for s in traceback.format_list(tb + tb2):
636636
print(s.rstrip('\n'))
637-
print('{}: {}'.format(type(err).__name__, err))
638-
print('{}: note: use --pdb to drop into pdb'.format(prefix), file=sys.stderr)
637+
print('{}: {}'.format(type(err).__name__, err), file=stdout)
638+
print('{}: note: use --pdb to drop into pdb'.format(prefix), file=stderr)
639639

640640
# Exit. The caller has nothing more to say.
641641
# We use exit code 2 to signal that this is no ordinary error.

0 commit comments

Comments
 (0)