2424import argparse
2525import inspect
2626import textwrap
27- from typing import Any , Callable , Dict , Iterable , List , Optional , Type , TypeVar
27+ from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Type , TypeVar
2828
2929import drgn
3030
@@ -244,6 +244,31 @@ def _call(self,
244244 """
245245 raise NotImplementedError ()
246246
247+ def _valid_input_types (self ) -> Set [str ]:
248+ """
249+ Returns a set of strings which are the canonicalized names of valid input types
250+ for this command
251+ """
252+ assert self .input_type is not None
253+ return {type_canonicalize_name (self .input_type )}
254+
255+ def __input_type_check (
256+ self , objs : Iterable [drgn .Object ]) -> Iterable [drgn .Object ]:
257+ valid_input_types = self ._valid_input_types ()
258+ print ("valid types: " )
259+ print (valid_input_types )
260+ prev_type = None
261+ for obj in objs :
262+ cur_type = type_canonical_name (obj .type_ )
263+ if cur_type not in valid_input_types or (prev_type and
264+ cur_type != prev_type ):
265+ raise CommandError (
266+ self .name ,
267+ f'expected input of type { self .input_type } , but received '
268+ f'type { obj .type_ } ' )
269+ prev_type = cur_type
270+ yield obj
271+
247272 def __invalid_memory_objects_check (self , objs : Iterable [drgn .Object ],
248273 fatal : bool ) -> Iterable [drgn .Object ]:
249274 """
@@ -281,15 +306,19 @@ def call(self, objs: Iterable[drgn.Object]) -> Iterable[drgn.Object]:
281306 # the command is running.
282307 #
283308 try :
284- result = self ._call (objs )
309+ if self .input_type and objs :
310+ result = self ._call (self .__input_type_check (objs ))
311+ else :
312+ result = self ._call (objs )
313+
285314 if result is not None :
286315 #
287316 # The whole point of the SingleInputCommands are that
288317 # they don't stop executing in the first encounter of
289318 # a bad dereference. That's why we check here whether
290319 # the command that we are running is a subclass of
291320 # SingleInputCommand and we set the `fatal` flag
292- # accordinly .
321+ # accordingly .
293322 #
294323 yield from self .__invalid_memory_objects_check (
295324 result , not issubclass (self .__class__ , SingleInputCommand ))
@@ -634,22 +663,6 @@ def pretty_print(self, objs: Iterable[drgn.Object]) -> None:
634663 # pylint: disable=missing-docstring
635664 raise NotImplementedError
636665
637- def check_input_type (self ,
638- objs : Iterable [drgn .Object ]) -> Iterable [drgn .Object ]:
639- """
640- This function acts as a generator, checking that each passed object
641- matches the input type for the command
642- """
643- assert self .input_type is not None
644- type_name = type_canonicalize_name (self .input_type )
645- for obj in objs :
646- if type_canonical_name (obj .type_ ) != type_name :
647- raise CommandError (
648- self .name ,
649- f'expected input of type { self .input_type } , but received '
650- f'type { obj .type_ } ' )
651- yield obj
652-
653666 def _call ( # type: ignore[return]
654667 self ,
655668 objs : Iterable [drgn .Object ]) -> Optional [Iterable [drgn .Object ]]:
@@ -658,7 +671,7 @@ def _call( # type: ignore[return]
658671 verifying the types as we go.
659672 """
660673 assert self .input_type is not None
661- self .pretty_print (self . check_input_type ( objs ) )
674+ self .pretty_print (objs )
662675
663676
664677class Locator (Command ):
@@ -673,6 +686,25 @@ class Locator(Command):
673686
674687 output_type : Optional [str ] = None
675688
689+ def _valid_input_types (self ) -> Set [str ]:
690+ """
691+ Some Locators support multiple input types. Check for InputHandler
692+ implementations to expand the set of valid input types.
693+ """
694+ valid_types = [type_canonicalize_name (self .input_type )]
695+
696+ for (_ , method ) in inspect .getmembers (self , inspect .ismethod ):
697+ if hasattr (method , "input_typename_handled" ):
698+ valid_types .append (
699+ type_canonicalize_name (method .input_typename_handled ))
700+
701+ valid_types += [
702+ type_canonicalize_name (type_ )
703+ for type_ , class_ in Walker .allWalkers .items ()
704+ ]
705+
706+ return set (valid_types )
707+
676708 def no_input (self ) -> Iterable [drgn .Object ]:
677709 # pylint: disable=missing-docstring
678710 raise CommandError (self .name , 'command requires an input' )
0 commit comments