Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions example/commands/sample_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ async def async_bad_name():
"""
cprint("This is async!", "green")

@command("completions")
@argument('name', choices=['harry', 'sally', 'dini', 'pinky', 'maya'],
description="the name you seek")
def completions(name: str):
"Check completions"
cprint(f"{name=}")

return 0


@command
@argument("number", type=int)
Expand Down Expand Up @@ -168,7 +177,7 @@ def do_stuff(self, stuff: int):
def test_mac(mac):
"""
Test command for MAC address parsing without quotes.

Examples:
- test_mac 00:01:21:ab:cd:8f
- test_mac 1234.abcd.5678
Expand All @@ -185,7 +194,7 @@ def test_mac(mac):
def test_mac_pos(mac):
"""
Test command for MAC address parsing as positional argument.

Examples:
- test_mac_pos 00:01:21:ab:cd:8f
- test_mac_pos 1234.abcd.5678
Expand Down
17 changes: 12 additions & 5 deletions nubia/internal/cmdbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ async def run_interactive(self, cmd, args, raw):
return 2

sub_inspection = self.subcommand_metadata(subcommand)
instance, remaining_args = self._create_subcommand_obj(args_dict)
instance, remaining_args = self._create_subcommand_obj(
args_dict)
assert instance
args_dict = remaining_args
key_values = copy.copy(args_dict)
Expand All @@ -276,7 +277,8 @@ async def run_interactive(self, cmd, args, raw):
else:
# not a super-command, use use the function instead
fn = self._fn
positionals = parsed_dict["positionals"] if parsed.positionals != "" else []
positionals = parsed_dict["positionals"] if parsed.positionals != "" else [
]
# We only allow positionals for arguments that have positional=True
# ِ We filter out the OrderedDict this way to ensure we don't lose the
# order of the arguments. We also filter out arguments that have
Expand Down Expand Up @@ -397,10 +399,14 @@ async def run_interactive(self, cmd, args, raw):
for arg, value in args_dict.items():
choices = args_metadata[arg].choices
if choices and not isinstance(choices, Callable):
# Import the pattern matching function
from nubia.internal.helpers import matches_choice_pattern

# Validate the choices in the case of values and list of
# values.
if is_list_type(args_metadata[arg].type):
bad_inputs = [v for v in value if v not in choices]
bad_inputs = [
v for v in value if not matches_choice_pattern(str(v), choices)]
if bad_inputs:
cprint(
f"Argument '{arg}' got an unexpected "
Expand All @@ -409,7 +415,7 @@ async def run_interactive(self, cmd, args, raw):
"red",
)
return 4
elif value not in choices:
elif not matches_choice_pattern(str(value), choices):
cprint(
f"Argument '{arg}' got an unexpected value "
f"'{value}'. Expected one of "
Expand All @@ -421,7 +427,8 @@ async def run_interactive(self, cmd, args, raw):
# arguments appear to be fine, time to run the function
try:
# convert argument names back to match the function signature
args_dict = {args_metadata[k].arg: v for k, v in args_dict.items()}
args_dict = {
args_metadata[k].arg: v for k, v in args_dict.items()}
ctx.cmd = cmd
ctx.raw_cmd = raw
ret = await try_await(fn(**args_dict))
Expand Down
7 changes: 6 additions & 1 deletion nubia/internal/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,13 @@ def _prepare_args_completions(
f', got {choices}')
else:
if parsed_token.last_value:
# Import the pattern matching function
from nubia.internal.helpers import matches_choice_pattern

# Filter choices based on pattern matching
choices = [c for c in arg.choices
if str(c).startswith(parsed_token.last_value)]
if matches_choice_pattern(parsed_token.last_value, [str(c)]) or
str(c).startswith(parsed_token.last_value)]
else:
choices = arg.choices

Expand Down
53 changes: 53 additions & 0 deletions nubia/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,56 @@ def suggestions_msg(suggestions: Optional[Iterable[str]]) -> str:
return ""
else:
return f", Did you mean {', '.join(suggestions[:-1])} or {suggestions[-1]}?"


def matches_choice_pattern(value: str, choices: list) -> bool:
"""
Check if a value matches any of the choices, supporting pattern matching and negation.

The function supports two modes:
1. If the input value contains patterns (~, !, !~), it validates the pattern syntax
2. If the input value is literal, it matches against the choices list

Supported input patterns:
- '!pattern' - negation (reject if pattern matches any choice)
- '~pattern' - regex pattern (accept if pattern matches any choice)
- '!~pattern' - negated regex pattern (reject if pattern matches any choice)
- Regular string matching for exact matches

Args:
value: The value to check (can contain patterns)
choices: List of valid choices (literal values)

Returns:
bool: True if the value is valid, False otherwise
"""
# If the input value contains patterns, validate the pattern syntax
if value.startswith('!~'):
# Negated regex pattern: !~pattern
pattern = value[2:]
try:
# Test if the regex is valid
re.compile(pattern)
return True # Valid regex syntax
except re.error:
return False # Invalid regex syntax

elif value.startswith('~'):
# Regex pattern: ~pattern
pattern = value[1:]
try:
# Test if the regex is valid
re.compile(pattern)
return True # Valid regex syntax
except re.error:
return False # Invalid regex syntax

elif value.startswith('!'):
# Negation pattern: !pattern
literal = value[1:]
# Check if the literal after ! exists in choices
return literal in [str(choice) for choice in choices]

else:
# Regular literal matching
return value in [str(choice) for choice in choices]
116 changes: 116 additions & 0 deletions tests/integration_pattern_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

"""
Integration test to verify pattern matching works end-to-end with nubia.
"""

from pattern_matching_example import pattern_demo, file_demo
from tests.util import TestShell
import asyncio
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'example'))


class PatternMatchingIntegrationTest:
"""Integration test for pattern matching functionality."""

async def test_pattern_matching_integration(self):
"""Test pattern matching through the full nubia framework."""
shell = TestShell([pattern_demo, file_demo])

print("Testing pattern matching integration with nubia...")

# Test cases: (command, expected_result, description)
test_cases = [
# Valid patterns (return None or 0 for success)
("pattern-demo pattern=a", [0, None], "literal match 'a'"),
("pattern-demo pattern=a2", [0, None], "regex match '~a.*'"),
("pattern-demo pattern=a2a1", [0, None], "literal match 'a2a1'"),

# Invalid patterns (should return 4 for validation error)
("pattern-demo pattern=a1", 4, "negated by '!a1'"),
("pattern-demo pattern=b1", 4, "negated by '!~b.*'"),
("pattern-demo pattern=c", 4, "no matching pattern"),
]

all_passed = True

for cmd, expected_result, description in test_cases:
try:
result = await shell.run_interactive_line(cmd)
if isinstance(expected_result, list):
if result in expected_result:
print(f"✓ {description}: {cmd}")
else:
print(
f"✗ {description}: {cmd} (expected {expected_result}, got {result})")
all_passed = False
else:
if result == expected_result:
print(f"✓ {description}: {cmd}")
else:
print(
f"✗ {description}: {cmd} (expected {expected_result}, got {result})")
all_passed = False
except Exception as e:
print(f"✗ {description}: {cmd} (exception: {e})")
all_passed = False

# Test file pattern matching
file_test_cases = [
("file-demo files=main.py", [0, None], "regex match '~.*\\.py$'"),
("file-demo files=test_file.py",
[0, None], "regex match '~test_.*'"),
("file-demo files=data.tmp", 4, "negated by '!~.*\\.tmp$'"),
("file-demo files=backup_file", 4, "negated by '!~.*_backup'"),
]

print("\nTesting file pattern matching...")
for cmd, expected_result, description in file_test_cases:
try:
result = await shell.run_interactive_line(cmd)
if isinstance(expected_result, list):
if result in expected_result:
print(f"✓ {description}: {cmd}")
else:
print(
f"✗ {description}: {cmd} (expected {expected_result}, got {result})")
all_passed = False
else:
if result == expected_result:
print(f"✓ {description}: {cmd}")
else:
print(
f"✗ {description}: {cmd} (expected {expected_result}, got {result})")
all_passed = False
except Exception as e:
print(f"✗ {description}: {cmd} (exception: {e})")
all_passed = False

return all_passed


async def main():
"""Run the integration test."""
test = PatternMatchingIntegrationTest()
success = await test.test_pattern_matching_integration()

if success:
print("\n🎉 All pattern matching integration tests passed!")
return 0
else:
print("\n❌ Some pattern matching integration tests failed!")
return 1


if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)
Loading