diff --git a/CSharpHTTPClient/Client.cs b/CSharpHTTPClient/Client.cs index f481771..94d954c 100644 --- a/CSharpHTTPClient/Client.cs +++ b/CSharpHTTPClient/Client.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using System.Web.Script.Serialization; using System.Web; +using System.Threading; namespace SendGrid.CSharp.HTTP.Client { @@ -266,6 +267,7 @@ public override bool TryInvokeMember(InvokeMemberBinder binder, object[] args, o if( Enum.IsDefined(typeof(Methods), binder.Name.ToUpper())) { + CancellationToken cancellationToken = CancellationToken.None; string queryParams = null; string requestBody = null; int i = 0; @@ -286,9 +288,13 @@ public override bool TryInvokeMember(InvokeMemberBinder binder, object[] args, o { AddRequestHeader((Dictionary)obj); } + else if (name == "cancellationToken") + { + cancellationToken = (CancellationToken)obj; + } i++; } - result = RequestAsync(binder.Name.ToUpper(), requestBody: requestBody, queryParams: queryParams).ConfigureAwait(false); + result = RequestAsync(binder.Name.ToUpper(), requestBody: requestBody, queryParams: queryParams, cancellationToken: cancellationToken).ConfigureAwait(false); return true; } else @@ -304,11 +310,12 @@ public override bool TryInvokeMember(InvokeMemberBinder binder, object[] args, o /// /// Client object ready for communication with API /// The parameters for the API call + /// A token that allows cancellation of the http request /// Response object - public async virtual Task MakeRequest(HttpClient client, HttpRequestMessage request) + public async virtual Task MakeRequest(HttpClient client, HttpRequestMessage request, CancellationToken cancellationToken = default(CancellationToken)) { - HttpResponseMessage response = await client.SendAsync(request).ConfigureAwait(false); + HttpResponseMessage response = await client.SendAsync(request, cancellationToken).ConfigureAwait(false); return new Response(response.StatusCode, response.Content, response.Headers); } @@ -316,10 +323,11 @@ public async virtual Task MakeRequest(HttpClient client, HttpRequestMe /// Prepare for async call to the API server /// /// HTTP verb + /// A token that allows cancellation of the http request /// JSON formatted string /// JSON formatted queary paramaters /// Response object - private async Task RequestAsync(string method, String requestBody = null, String queryParams = null) + private async Task RequestAsync(string method, String requestBody = null, String queryParams = null, CancellationToken cancellationToken = default(CancellationToken)) { using (var client = new HttpClient()) { @@ -367,9 +375,13 @@ private async Task RequestAsync(string method, String requestBody = nu RequestUri = new Uri(endpoint), Content = content }; - return await MakeRequest(client, request).ConfigureAwait(false); + return await MakeRequest(client, request, cancellationToken).ConfigureAwait(false); } + catch(TaskCanceledException) + { + throw; + } catch (Exception ex) { HttpResponseMessage response = new HttpResponseMessage(); diff --git a/UnitTest/UnitTest.cs b/UnitTest/UnitTest.cs index f8c1a2c..d0a224a 100644 --- a/UnitTest/UnitTest.cs +++ b/UnitTest/UnitTest.cs @@ -6,6 +6,7 @@ using System.Net.Http; using System.Text; using System.Net; +using System.Threading; namespace UnitTest { @@ -17,12 +18,19 @@ public MockClient(string host, Dictionary requestHeaders = null, { } - public async override Task MakeRequest(HttpClient client, HttpRequestMessage request) + public override Task MakeRequest(HttpClient client, HttpRequestMessage request, CancellationToken cancellationToken) { - HttpResponseMessage response = new HttpResponseMessage(); - response.Content = new StringContent("{'test': 'test_content'}", Encoding.UTF8, "application/json"); - response.StatusCode = HttpStatusCode.OK; - return new Response(response.StatusCode, response.Content, response.Headers); + return Task.Factory.StartNew(() => + { + + HttpResponseMessage response = new HttpResponseMessage(); + response.Content = new StringContent("{'test': 'test_content'}", Encoding.UTF8, "application/json"); + response.StatusCode = HttpStatusCode.OK; + + cancellationToken.ThrowIfCancellationRequested(); + + return new Response(response.StatusCode, response.Content, response.Headers); + }, cancellationToken); } } @@ -74,5 +82,17 @@ public async void TestMethodCall() var content = new StringContent("{'test': 'test_content'}", Encoding.UTF8, "application/json"); Assert.AreEqual(response.Body.ReadAsStringAsync().Result, content.ReadAsStringAsync().Result); } + + [Test] + [ExpectedException(typeof(TaskCanceledException))] + public async void TestMethodCallWithCancellationToken() + { + var cancellationTokenSource = new CancellationTokenSource(); + cancellationTokenSource.Cancel(); + + var host = "http://api.test.com"; + dynamic test_client = new MockClient(host: host); + Response response = await test_client.get(cancellationToken: cancellationTokenSource.Token); + } } }