diff --git a/src/System.CommandLine.ApiCompatibility.Tests/ApiCompatibilityApprovalTests.System_CommandLine_api_is_not_changed.approved.txt b/src/System.CommandLine.ApiCompatibility.Tests/ApiCompatibilityApprovalTests.System_CommandLine_api_is_not_changed.approved.txt index 6b42819629..01429ca129 100644 --- a/src/System.CommandLine.ApiCompatibility.Tests/ApiCompatibilityApprovalTests.System_CommandLine_api_is_not_changed.approved.txt +++ b/src/System.CommandLine.ApiCompatibility.Tests/ApiCompatibilityApprovalTests.System_CommandLine_api_is_not_changed.approved.txt @@ -66,8 +66,8 @@ System.CommandLine public static class CommandExtensions public static System.Int32 Invoke(this Command command, System.String[] args, IConsole console = null) public static System.Int32 Invoke(this Command command, System.String commandLine, IConsole console = null) - public static System.Threading.Tasks.Task InvokeAsync(this Command command, System.String[] args, IConsole console = null) - public static System.Threading.Tasks.Task InvokeAsync(this Command command, System.String commandLine, IConsole console = null) + public static System.Threading.Tasks.Task InvokeAsync(this Command command, System.String[] args, IConsole console = null, System.Threading.CancellationToken cancellationToken = null) + public static System.Threading.Tasks.Task InvokeAsync(this Command command, System.String commandLine, IConsole console = null, System.Threading.CancellationToken cancellationToken = null) public static ParseResult Parse(this Command command, System.String[] args) public static ParseResult Parse(this Command command, System.String commandLine) public class CommandLineBuilder @@ -357,8 +357,8 @@ System.CommandLine.Help System.CommandLine.Invocation public interface IInvocationResult public System.Void Apply(InvocationContext context) - public class InvocationContext - .ctor(System.CommandLine.ParseResult parseResult, System.CommandLine.IConsole console = null) + public class InvocationContext, System.IDisposable + .ctor(System.CommandLine.ParseResult parseResult, System.CommandLine.IConsole console = null, System.Threading.CancellationToken cancellationToken = null) public System.CommandLine.Binding.BindingContext BindingContext { get; } public System.CommandLine.IConsole Console { get; set; } public System.Int32 ExitCode { get; set; } @@ -368,6 +368,7 @@ System.CommandLine.Invocation public System.CommandLine.Parsing.Parser Parser { get; } public System.CommandLine.ParseResult ParseResult { get; set; } public System.Threading.CancellationToken GetCancellationToken() + public System.Void LinkToken(System.Threading.CancellationToken token) public delegate InvocationMiddleware : System.MulticastDelegate, System.ICloneable, System.Runtime.Serialization.ISerializable .ctor(System.Object object, System.IntPtr method) public System.IAsyncResult BeginInvoke(InvocationContext context, System.Func next, System.AsyncCallback callback, System.Object object) @@ -450,12 +451,12 @@ System.CommandLine.Parsing public static System.String Diagram(this System.CommandLine.ParseResult parseResult) public static System.Boolean HasOption(this System.CommandLine.ParseResult parseResult, System.CommandLine.Option option) public static System.Int32 Invoke(this System.CommandLine.ParseResult parseResult, System.CommandLine.IConsole console = null) - public static System.Threading.Tasks.Task InvokeAsync(this System.CommandLine.ParseResult parseResult, System.CommandLine.IConsole console = null) + public static System.Threading.Tasks.Task InvokeAsync(this System.CommandLine.ParseResult parseResult, System.CommandLine.IConsole console = null, System.Threading.CancellationToken cancellationToken = null) public static class ParserExtensions public static System.Int32 Invoke(this Parser parser, System.String commandLine, System.CommandLine.IConsole console = null) public static System.Int32 Invoke(this Parser parser, System.String[] args, System.CommandLine.IConsole console = null) - public static System.Threading.Tasks.Task InvokeAsync(this Parser parser, System.String commandLine, System.CommandLine.IConsole console = null) - public static System.Threading.Tasks.Task InvokeAsync(this Parser parser, System.String[] args, System.CommandLine.IConsole console = null) + public static System.Threading.Tasks.Task InvokeAsync(this Parser parser, System.String commandLine, System.CommandLine.IConsole console = null, System.Threading.CancellationToken cancellationToken = null) + public static System.Threading.Tasks.Task InvokeAsync(this Parser parser, System.String[] args, System.CommandLine.IConsole console = null, System.Threading.CancellationToken cancellationToken = null) public static System.CommandLine.ParseResult Parse(this Parser parser, System.String commandLine) public abstract class SymbolResult public System.Collections.Generic.IReadOnlyList Children { get; } diff --git a/src/System.CommandLine.Tests/Invocation/CancelOnProcessTerminationTests.cs b/src/System.CommandLine.Tests/Invocation/CancelOnProcessTerminationTests.cs index 22f6c05798..aaae8d2b2b 100644 --- a/src/System.CommandLine.Tests/Invocation/CancelOnProcessTerminationTests.cs +++ b/src/System.CommandLine.Tests/Invocation/CancelOnProcessTerminationTests.cs @@ -1,12 +1,13 @@ // Copyright (c) .NET Foundation and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using FluentAssertions; +using System.CommandLine.Invocation; using System.CommandLine.Parsing; using System.CommandLine.Tests.Utility; using System.Diagnostics; using System.Runtime.InteropServices; using System.Threading.Tasks; -using FluentAssertions; using Xunit; using Process = System.Diagnostics.Process; diff --git a/src/System.CommandLine.Tests/Invocation/InvocationContextTests.cs b/src/System.CommandLine.Tests/Invocation/InvocationContextTests.cs new file mode 100644 index 0000000000..5e168e7191 --- /dev/null +++ b/src/System.CommandLine.Tests/Invocation/InvocationContextTests.cs @@ -0,0 +1,64 @@ +using FluentAssertions; +using System.CommandLine; +using System.CommandLine.Invocation; +using System.CommandLine.Parsing; +using System.Threading; +using Xunit; + +namespace System.CommandLine.Tests.Invocation +{ + public class InvocationContextTests + { + [Fact] + public void InvocationContext_with_cancellation_token_returns_it() + { + using CancellationTokenSource cts = new(); + var parseResult = new CommandLineBuilder(new RootCommand()) + .Build() + .Parse(""); + using InvocationContext context = new(parseResult, cancellationToken: cts.Token); + + var token = context.GetCancellationToken(); + + token.IsCancellationRequested.Should().BeFalse(); + cts.Cancel(); + token.IsCancellationRequested.Should().BeTrue(); + } + + [Fact] + public void InvocationContext_with_linked_cancellation_token_can_cancel_by_passed_token() + { + using CancellationTokenSource cts1 = new(); + using CancellationTokenSource cts2 = new(); + var parseResult = new CommandLineBuilder(new RootCommand()) + .Build() + .Parse(""); + using InvocationContext context = new(parseResult, cancellationToken: cts1.Token); + context.LinkToken(cts2.Token); + + var token = context.GetCancellationToken(); + + token.IsCancellationRequested.Should().BeFalse(); + cts1.Cancel(); + token.IsCancellationRequested.Should().BeTrue(); + } + + [Fact] + public void InvocationContext_with_linked_cancellation_token_can_cancel_by_linked_token() + { + using CancellationTokenSource cts1 = new(); + using CancellationTokenSource cts2 = new(); + var parseResult = new CommandLineBuilder(new RootCommand()) + .Build() + .Parse(""); + using InvocationContext context = new(parseResult, cancellationToken: cts1.Token); + context.LinkToken(cts2.Token); + + var token = context.GetCancellationToken(); + + token.IsCancellationRequested.Should().BeFalse(); + cts2.Cancel(); + token.IsCancellationRequested.Should().BeTrue(); + } + } +} diff --git a/src/System.CommandLine.Tests/Invocation/InvocationExtensionsTests.cs b/src/System.CommandLine.Tests/Invocation/InvocationExtensionsTests.cs index 947f94a7c3..4f6d2e7361 100644 --- a/src/System.CommandLine.Tests/Invocation/InvocationExtensionsTests.cs +++ b/src/System.CommandLine.Tests/Invocation/InvocationExtensionsTests.cs @@ -1,7 +1,9 @@ // Copyright (c) .NET Foundation and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.CommandLine.Invocation; using System.CommandLine.IO; +using System.Threading; using System.Threading.Tasks; using FluentAssertions; using Xunit; @@ -149,5 +151,28 @@ public void RootCommand_Invoke_can_set_custom_result_code() resultCode.Should().Be(123); } + + [Fact] + public async Task Command_InvokeAsync_with_cancelation_token_invokes_command_handler() + { + CancellationTokenSource cts = new(); + var command = new Command("test"); + command.SetHandler((InvocationContext context) => + { + CancellationToken cancellationToken = context.GetCancellationToken(); + Assert.True(cancellationToken.CanBeCanceled); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromResult(42); + } + + return Task.FromResult(0); + }); + + cts.Cancel(); + int rv = await command.InvokeAsync("test", cancellationToken: cts.Token); + + rv.Should().Be(42); + } } } diff --git a/src/System.CommandLine.Tests/Invocation/InvocationPipelineTests.cs b/src/System.CommandLine.Tests/Invocation/InvocationPipelineTests.cs index a707ca879b..1938f9548b 100644 --- a/src/System.CommandLine.Tests/Invocation/InvocationPipelineTests.cs +++ b/src/System.CommandLine.Tests/Invocation/InvocationPipelineTests.cs @@ -2,9 +2,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.CommandLine.Help; +using System.CommandLine.Invocation; using System.CommandLine.IO; using System.CommandLine.Parsing; using System.Linq; +using System.Threading; using System.Threading.Tasks; using FluentAssertions; using Xunit; @@ -44,7 +46,7 @@ public async Task InvokeAsync_chooses_the_appropriate_command() var parser = new CommandLineBuilder(new RootCommand { - first, + first, second }) .Build(); @@ -327,5 +329,39 @@ public async Task When_help_builder_factory_is_specified_it_is_used_to_create_th handlerWasCalled.Should().BeTrue(); factoryWasCalled.Should().BeTrue(); } + + [Fact] + public async Task Command_InvokeAsync_can_cancel_from_middleware() + { + var handlerWasCalled = false; + var isCancelRequested = false; + + var command = new Command("the-command"); + command.SetHandler((InvocationContext context) => + { + handlerWasCalled = true; + isCancelRequested = context.GetCancellationToken().IsCancellationRequested; + return Task.FromResult(0); + }); + + + using CancellationTokenSource cts = new(); + var parser = new CommandLineBuilder(new RootCommand + { + command + }) + .AddMiddleware(async (context, next) => + { + context.LinkToken(cts.Token); + cts.Cancel(); + await next(context); + }) + .Build(); + + await parser.InvokeAsync("the-command"); + + handlerWasCalled.Should().BeTrue(); + isCancelRequested.Should().BeTrue(); + } } } diff --git a/src/System.CommandLine.Tests/ParserTests.RootCommandAndArg0.cs b/src/System.CommandLine.Tests/ParserTests.RootCommandAndArg0.cs index 993b8c2d10..1e2819c22c 100644 --- a/src/System.CommandLine.Tests/ParserTests.RootCommandAndArg0.cs +++ b/src/System.CommandLine.Tests/ParserTests.RootCommandAndArg0.cs @@ -42,10 +42,11 @@ public void When_parsing_a_string_array_input_then_a_full_path_to_an_executable_ command.Parse(Split("inner -x hello")).Errors.Should().BeEmpty(); - command.Parse(Split($"{RootCommand.ExecutablePath} inner -x hello")) - .Errors - .Should() - .ContainSingle(e => e.Message == $"{LocalizationResources.Instance.UnrecognizedCommandOrArgument(RootCommand.ExecutablePath)}"); + var parserResult = command.Parse(Split($"\"{RootCommand.ExecutablePath}\" inner -x hello")); + parserResult + .Errors + .Should() + .ContainSingle(e => e.Message == LocalizationResources.Instance.UnrecognizedCommandOrArgument(RootCommand.ExecutablePath)); } [Fact] @@ -76,7 +77,7 @@ public void When_parsing_an_unsplit_string_then_input_a_full_path_to_an_executab } }; - var result2 = command.Parse($"{RootCommand.ExecutablePath} inner -x hello"); + var result2 = command.Parse($"\"{RootCommand.ExecutablePath}\" inner -x hello"); result2.RootCommandResult.Token.Value.Should().Be(RootCommand.ExecutablePath); } diff --git a/src/System.CommandLine/Binding/BindingContext.cs b/src/System.CommandLine/Binding/BindingContext.cs index 79cbe88a5a..13d3bd6e4b 100644 --- a/src/System.CommandLine/Binding/BindingContext.cs +++ b/src/System.CommandLine/Binding/BindingContext.cs @@ -6,8 +6,6 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; -#nullable enable - namespace System.CommandLine.Binding { /// diff --git a/src/System.CommandLine/Builder/CommandLineBuilderExtensions.cs b/src/System.CommandLine/Builder/CommandLineBuilderExtensions.cs index c20f08c841..38868cc915 100644 --- a/src/System.CommandLine/Builder/CommandLineBuilderExtensions.cs +++ b/src/System.CommandLine/Builder/CommandLineBuilderExtensions.cs @@ -52,7 +52,7 @@ public static class CommandLineBuilderExtensions /// /// The same instance of . public static CommandLineBuilder CancelOnProcessTermination( - this CommandLineBuilder builder, + this CommandLineBuilder builder, TimeSpan? timeout = null) { // https://tldp.org/LDP/abs/html/exitcodes.html - 130 - script terminated by ctrl-c @@ -65,89 +65,81 @@ public static CommandLineBuilder CancelOnProcessTermination( builder.AddMiddleware(async (context, next) => { - bool cancellationHandlingAdded = false; - ManualResetEventSlim? blockProcessExit = null; ConsoleCancelEventHandler? consoleHandler = null; EventHandler? processExitHandler = null; + ManualResetEventSlim blockProcessExit = new(initialState: false); - context.CancellationHandlingAdded += (CancellationTokenSource cts) => + processExitHandler = (_, _) => { - blockProcessExit = new ManualResetEventSlim(initialState: false); - cancellationHandlingAdded = true; - // Default limit for ProcesExit handler is 2 seconds - // https://docs.microsoft.com/en-us/dotnet/api/system.appdomain.processexit?view=net-6.0 - processExitHandler = (_, _) => + // Cancel asynchronously not to block the handler (as then the process might possibly run longer then what was the requested timeout) + Task timeoutTask = Task.Delay(timeout.Value); + Task cancelTask = Task.Factory.StartNew(context.Cancel); + + // The process exits as soon as the event handler returns. + // We provide a return value using Environment.ExitCode + // because Main will not finish executing. + // Wait for the invocation to finish. + if (!blockProcessExit.Wait(timeout > TimeSpan.Zero + ? timeout.Value + : Timeout.InfiniteTimeSpan)) { - // Cancel asynchronously not to block the handler (as then the process might possibly run longer then what was the requested timeout) - Task timeoutTask = Task.Delay(timeout.Value); - Task cancelTask = Task.Factory.StartNew(cts.Cancel); - - // The process exits as soon as the event handler returns. - // We provide a return value using Environment.ExitCode - // because Main will not finish executing. - // Wait for the invocation to finish. - if (!blockProcessExit.Wait(timeout > TimeSpan.Zero - ? timeout.Value - : Timeout.InfiniteTimeSpan)) - { - context.ExitCode = SIGINT_EXIT_CODE; - } - // Let's block here (to prevent process bailing out) for the rest of the timeout (if any), for cancellation to finish (if it hasn't yet) - else if (Task.WaitAny(timeoutTask, cancelTask) == 0) - { - // The async cancellation didn't finish in timely manner - context.ExitCode = SIGINT_EXIT_CODE; - } - ExitCode = context.ExitCode; - }; - consoleHandler = (_, args) => + context.ExitCode = SIGINT_EXIT_CODE; + } + // Let's block here (to prevent process bailing out) for the rest of the timeout (if any), for cancellation to finish (if it hasn't yet) + else if (Task.WaitAny(timeoutTask, cancelTask) == 0) { - // Stop the process from terminating. - // Since the context was cancelled, the invocation should - // finish and Main will return. - args.Cancel = true; - - // If timeout was requested - make sure cancellation processing (or any other activity within the current process) - // doesn't keep the process running after the timeout - if (timeout! > TimeSpan.Zero) - { - Task - .Delay(timeout.Value, default) - .ContinueWith(t => - { - // Prevent our ProcessExit from intervene and block the exit - AppDomain.CurrentDomain.ProcessExit -= processExitHandler; - Environment.Exit(SIGINT_EXIT_CODE); - }, (CancellationToken)default); - } + // The async cancellation didn't finish in timely manner + context.ExitCode = SIGINT_EXIT_CODE; + } + ExitCode = context.ExitCode; + }; + // Default limit for ProcesExit handler is 2 seconds + // https://docs.microsoft.com/en-us/dotnet/api/system.appdomain.processexit?view=net-6.0 + consoleHandler = (_, args) => + { + // Stop the process from terminating. + // Since the context was cancelled, the invocation should + // finish and Main will return. + args.Cancel = true; + + // If timeout was requested - make sure cancellation processing (or any other activity within the current process) + // doesn't keep the process running after the timeout + if (timeout! > TimeSpan.Zero) + { + Task + .Delay(timeout.Value, default) + .ContinueWith(t => + { + // Prevent our ProcessExit from intervene and block the exit + AppDomain.CurrentDomain.ProcessExit -= processExitHandler; + Environment.Exit(SIGINT_EXIT_CODE); + }, (CancellationToken)default); + } - // Cancel synchronously here - no need to perform it asynchronously as the timeout is already running (and would kill the process if needed), - // plus we cannot wait only on the cancellation (e.g. via `Task.Factory.StartNew(cts.Cancel).Wait(cancelationProcessingTimeout.Value)`) - // as we need to abort any other possible execution within the process - even outside the context of cancellation processing - cts.Cancel(); - }; - Console.CancelKeyPress += consoleHandler; - AppDomain.CurrentDomain.ProcessExit += processExitHandler; + // Cancel synchronously here - no need to perform it asynchronously as the timeout is already running (and would kill the process if needed), + // plus we cannot wait only on the cancellation (e.g. via `Task.Factory.StartNew(cts.Cancel).Wait(cancelationProcessingTimeout.Value)`) + // as we need to abort any other possible execution within the process - even outside the context of cancellation processing + context.Cancel(); }; + Console.CancelKeyPress += consoleHandler; + AppDomain.CurrentDomain.ProcessExit += processExitHandler; + try { await next(context); } finally { - if (cancellationHandlingAdded) - { - Console.CancelKeyPress -= consoleHandler; - AppDomain.CurrentDomain.ProcessExit -= processExitHandler; - blockProcessExit!.Set(); - } + Console.CancelKeyPress -= consoleHandler; + AppDomain.CurrentDomain.ProcessExit -= processExitHandler; + blockProcessExit?.Set(); } }, MiddlewareOrderInternal.Startup); return builder; } - + /// /// Enables the parser to recognize command line directives. /// @@ -206,7 +198,7 @@ public static CommandLineBuilder EnablePosixBundling( builder.EnablePosixBundling = value; return builder; } - + /// /// Ensures that the application is registered with the dotnet-suggest tool to enable command line completions. /// @@ -418,7 +410,7 @@ public static CommandLineBuilder UseHelp( int? maxWidth = null) { builder.CustomizeHelpLayout(customize); - + if (builder.HelpOption is null) { builder.UseHelp(new HelpOption(() => builder.LocalizationResources), maxWidth); @@ -485,7 +477,7 @@ public static CommandLineBuilder AddMiddleware( return builder; } - + /// /// Adds a middleware delegate to the invocation pipeline called before a command handler is invoked. /// @@ -599,7 +591,7 @@ public static CommandLineBuilder UseSuggestDirective( /// The maximum Levenshtein distance for suggestions based on detected typos in command line input. /// The same instance of . public static CommandLineBuilder UseTypoCorrections( - this CommandLineBuilder builder, + this CommandLineBuilder builder, int maxLevenshteinDistance = 3) { builder.AddMiddleware(async (context, next) => diff --git a/src/System.CommandLine/CommandExtensions.cs b/src/System.CommandLine/CommandExtensions.cs index 1a50cd0873..fb96cc19f8 100644 --- a/src/System.CommandLine/CommandExtensions.cs +++ b/src/System.CommandLine/CommandExtensions.cs @@ -4,6 +4,7 @@ using System.CommandLine.Invocation; using System.CommandLine.Parsing; using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace System.CommandLine @@ -48,13 +49,15 @@ public static int Invoke( /// The command to invoke. /// The arguments to parse. /// The console to which output is written during invocation. + /// A token that can be used to cancel the invocation. /// The exit code for the invocation. public static async Task InvokeAsync( this Command command, string[] args, - IConsole? console = null) + IConsole? console = null, + CancellationToken cancellationToken = default) { - return await GetDefaultInvocationPipeline(command, args).InvokeAsync(console); + return await GetDefaultInvocationPipeline(command, args).InvokeAsync(console, cancellationToken); } /// @@ -64,12 +67,14 @@ public static async Task InvokeAsync( /// The command to invoke. /// The command line to parse. /// The console to which output is written during invocation. + /// A token that can be used to cancel the invocation. /// The exit code for the invocation. public static Task InvokeAsync( this Command command, string commandLine, - IConsole? console = null) => - command.InvokeAsync(CommandLineStringSplitter.Instance.Split(commandLine).ToArray(), console); + IConsole? console = null, + CancellationToken cancellationToken = default) => + command.InvokeAsync(CommandLineStringSplitter.Instance.Split(commandLine).ToArray(), console, cancellationToken); private static InvocationPipeline GetDefaultInvocationPipeline(Command command, string[] args) { diff --git a/src/System.CommandLine/Invocation/InvocationContext.cs b/src/System.CommandLine/Invocation/InvocationContext.cs index 5a14abef22..07e6e024f3 100644 --- a/src/System.CommandLine/Invocation/InvocationContext.cs +++ b/src/System.CommandLine/Invocation/InvocationContext.cs @@ -1,6 +1,7 @@ -// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Copyright (c) .NET Foundation and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Collections.Generic; using System.CommandLine.Binding; using System.CommandLine.Help; using System.CommandLine.IO; @@ -12,22 +13,32 @@ namespace System.CommandLine.Invocation /// /// Supports command invocation by providing access to parse results and other services. /// - public sealed class InvocationContext + public sealed class InvocationContext : IDisposable { - private CancellationTokenSource? _cts; - private Action? _cancellationHandlingAddedEvent; private HelpBuilder? _helpBuilder; private BindingContext? _bindingContext; private IConsole? _console; + private readonly CancellationToken _token; + private readonly LinkedList _registrations = new(); + private volatile CancellationTokenSource? _source; /// The result of the current parse operation. /// The console to which output is to be written. + /// A cancellation token that can be used to cancel and invocation. public InvocationContext( ParseResult parseResult, - IConsole? console = null) + IConsole? console = null, + CancellationToken cancellationToken = default) { ParseResult = parseResult; _console = console; + + _source = new CancellationTokenSource(); + _token = _source.Token; + if (cancellationToken.CanBeCanceled) + { + LinkToken(cancellationToken); + } } /// @@ -40,6 +51,8 @@ public BindingContext BindingContext if (_bindingContext is null) { _bindingContext = new BindingContext(this); + _bindingContext.ServiceProvider.AddService(_ => GetCancellationToken()); + _bindingContext.ServiceProvider.AddService(_ => this); } return _bindingContext; @@ -94,33 +107,30 @@ public IConsole Console /// As the is passed through the invocation pipeline to the associated with the invoked command, only the last value of this property will be the one applied. public IInvocationResult? InvocationResult { get; set; } - internal event Action CancellationHandlingAdded + /// + /// Gets a cancellation token that can be used to check if cancellation has been requested. + /// + public CancellationToken GetCancellationToken() => _token; + + internal void Cancel() { - add - { - if (_cts is not null) - { - throw new InvalidOperationException("Handlers must be added before adding cancellation handling."); - } + using var source = Interlocked.Exchange(ref _source, null); + source?.Cancel(); + } - _cancellationHandlingAddedEvent += value; - } - remove => _cancellationHandlingAddedEvent -= value; + public void LinkToken(CancellationToken token) + { + _registrations.AddLast(token.Register(Cancel)); } - /// - /// Gets token to implement cancellation handling. - /// - /// Token used by the caller to implement cancellation handling. - public CancellationToken GetCancellationToken() + /// + void IDisposable.Dispose() { - if (_cts is null) + Interlocked.Exchange(ref _source, null)?.Dispose(); + foreach (CancellationTokenRegistration registration in _registrations) { - _cts = new CancellationTokenSource(); - _cancellationHandlingAddedEvent?.Invoke(_cts); + registration.Dispose(); } - - return _cts.Token; } } } diff --git a/src/System.CommandLine/Invocation/InvocationPipeline.cs b/src/System.CommandLine/Invocation/InvocationPipeline.cs index b428beba18..ba4c1d60e0 100644 --- a/src/System.CommandLine/Invocation/InvocationPipeline.cs +++ b/src/System.CommandLine/Invocation/InvocationPipeline.cs @@ -3,22 +3,23 @@ using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace System.CommandLine.Invocation { internal class InvocationPipeline { - private readonly ParseResult parseResult; + private readonly ParseResult _parseResult; public InvocationPipeline(ParseResult parseResult) { - this.parseResult = parseResult ?? throw new ArgumentNullException(nameof(parseResult)); + _parseResult = parseResult ?? throw new ArgumentNullException(nameof(parseResult)); } - public Task InvokeAsync(IConsole? console = null) + public Task InvokeAsync(IConsole? console = null, CancellationToken cancellationToken = default) { - var context = new InvocationContext(parseResult, console); + var context = new InvocationContext(_parseResult, console, cancellationToken); if (context.Parser.Configuration.Middleware.Count == 0 && context.ParseResult.CommandResult.Command.Handler is ICommandHandler handler) @@ -40,7 +41,7 @@ static async Task FullInvocationChainAsync(InvocationContext context) public int Invoke(IConsole? console = null) { - var context = new InvocationContext(parseResult, console); + var context = new InvocationContext(_parseResult, console); if (context.Parser.Configuration.Middleware.Count == 0 && context.ParseResult.CommandResult.Command.Handler is ICommandHandler handler) diff --git a/src/System.CommandLine/Invocation/ServiceProvider.cs b/src/System.CommandLine/Invocation/ServiceProvider.cs index acf3091947..852a1c9913 100644 --- a/src/System.CommandLine/Invocation/ServiceProvider.cs +++ b/src/System.CommandLine/Invocation/ServiceProvider.cs @@ -6,8 +6,6 @@ using System.CommandLine.Help; using System.Threading; -#nullable enable - namespace System.CommandLine.Invocation { internal class ServiceProvider : IServiceProvider diff --git a/src/System.CommandLine/Parsing/ParseResultExtensions.cs b/src/System.CommandLine/Parsing/ParseResultExtensions.cs index 0f233fe4f7..7cd3ebb8d6 100644 --- a/src/System.CommandLine/Parsing/ParseResultExtensions.cs +++ b/src/System.CommandLine/Parsing/ParseResultExtensions.cs @@ -6,6 +6,7 @@ using System.CommandLine.Invocation; using System.Linq; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace System.CommandLine.Parsing @@ -20,11 +21,13 @@ public static class ParseResultExtensions /// /// A parse result on which the invocation is based. /// A console to which output can be written. By default, is used. + /// A token that can be used to cancel an invocation. /// A task whose result can be used as a process exit code. public static async Task InvokeAsync( this ParseResult parseResult, - IConsole? console = null) => - await new InvocationPipeline(parseResult).InvokeAsync(console); + IConsole? console = null, + CancellationToken cancellationToken = default) => + await new InvocationPipeline(parseResult).InvokeAsync(console, cancellationToken); /// /// Invokes the appropriate command handler for a parsed command line input. diff --git a/src/System.CommandLine/Parsing/ParserExtensions.cs b/src/System.CommandLine/Parsing/ParserExtensions.cs index 0f51f66ec7..4c90a1783b 100644 --- a/src/System.CommandLine/Parsing/ParserExtensions.cs +++ b/src/System.CommandLine/Parsing/ParserExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace System.CommandLine.Parsing @@ -40,8 +41,9 @@ public static int Invoke( public static Task InvokeAsync( this Parser parser, string commandLine, - IConsole? console = null) => - parser.InvokeAsync(CommandLineStringSplitter.Instance.Split(commandLine).ToArray(), console); + IConsole? console = null, + CancellationToken cancellationToken = default) => + parser.InvokeAsync(CommandLineStringSplitter.Instance.Split(commandLine).ToArray(), console, cancellationToken); /// /// Parses a command line string array and invokes the handler for the indicated command. @@ -50,8 +52,9 @@ public static Task InvokeAsync( public static async Task InvokeAsync( this Parser parser, string[] args, - IConsole? console = null) => - await parser.Parse(args).InvokeAsync(console); + IConsole? console = null, + CancellationToken cancellationToken = default) => + await parser.Parse(args).InvokeAsync(console, cancellationToken); /// /// Parses a command line string.