Skip to content

Commit e24e0bc

Browse files
committed
Address PR feedback
1 parent f8c8829 commit e24e0bc

File tree

8 files changed

+108
-35
lines changed

8 files changed

+108
-35
lines changed

src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ internal partial class Http1Connection : HttpProtocol, IRequestProcessor
3737
private int _remainingRequestHeadersBytesAllowed;
3838

3939
public Http1Connection(HttpConnectionContext context)
40-
: base(context, isHttp1: true)
40+
: base(context)
4141
{
4242
_context = context;
4343
_parser = ServiceContext.HttpParser;

src/Servers/Kestrel/Core/src/Internal/Http/HttpHeaders.Generated.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6105,7 +6105,7 @@ public unsafe void Append(Span<byte> name, Span<byte> value)
61056105
}
61066106

61076107
// We didn't have a previous matching header value, or have already added a header, so get the string for this value.
6108-
var valueStr = value.GetAsciiOrUTF8StringNonNullCharacters(_useLatin1);
6108+
var valueStr = value.GetRequestHeaderString(_useLatin1);
61096109
if ((_bits & flag) == 0)
61106110
{
61116111
// We didn't already have a header set, so add a new one.
@@ -6123,7 +6123,7 @@ public unsafe void Append(Span<byte> name, Span<byte> value)
61236123
// The header was not one of the "known" headers.
61246124
// Convert value to string first, because passing two spans causes 8 bytes stack zeroing in
61256125
// this method with rep stosd, which is slower than necessary.
6126-
var valueStr = value.GetAsciiOrUTF8StringNonNullCharacters(_useLatin1);
6126+
var valueStr = value.GetRequestHeaderString(_useLatin1);
61276127
AppendUnknownHeaders(name, valueStr);
61286128
}
61296129
}

src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,15 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
7272
private Stream _requestStreamInternal;
7373
private Stream _responseStreamInternal;
7474

75-
private readonly bool _useLatin1;
76-
77-
public HttpProtocol(HttpConnectionContext context, bool isHttp1)
75+
public HttpProtocol(HttpConnectionContext context)
7876
{
7977
_context = context;
8078

8179
ServerOptions = ServiceContext.ServerOptions;
82-
_useLatin1 = isHttp1 && ServerOptions.Latin1RequestHeaders;
8380

8481
HttpRequestHeaders = new HttpRequestHeaders(
8582
reuseHeaderValues: !ServerOptions.DisableStringReuse,
86-
useLatin1: _useLatin1);
83+
useLatin1: ServerOptions.Latin1RequestHeaders);
8784

8885
HttpResponseControl = this;
8986
}
@@ -520,7 +517,7 @@ public void OnTrailer(Span<byte> name, Span<byte> value)
520517
}
521518

522519
string key = name.GetHeaderName();
523-
var valueStr = value.GetAsciiOrUTF8StringNonNullCharacters(_useLatin1);
520+
var valueStr = value.GetRequestHeaderString(ServerOptions.Latin1RequestHeaders);
524521
RequestTrailers.Append(key, valueStr);
525522
}
526523

src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ private void AppendContentLength(Span<byte> value)
8787
parsed < 0 ||
8888
consumed != value.Length)
8989
{
90-
BadHttpRequestException.Throw(RequestRejectionReason.InvalidContentLength, value.GetAsciiOrUTF8StringNonNullCharacters(_useLatin1));
90+
BadHttpRequestException.Throw(RequestRejectionReason.InvalidContentLength, value.GetRequestHeaderString(_useLatin1));
9191
}
9292

9393
_contentLength = parsed;

src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ internal abstract partial class Http2Stream : HttpProtocol, IThreadPoolWorkItem
3333
private readonly object _completionLock = new object();
3434

3535
public Http2Stream(Http2StreamContext context)
36-
: base(context, isHttp1: false)
36+
: base(context)
3737
{
3838
_context = context;
3939

src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.cs

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,7 @@ public static unsafe string GetAsciiStringNonNullCharacters(this Span<byte> span
130130
return asciiString;
131131
}
132132

133-
public static string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span)
134-
{
135-
return GetAsciiOrUTF8StringNonNullCharacters(span, useLatin1: false);
136-
}
137-
138-
public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span, bool useLatin1)
133+
public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span)
139134
{
140135
if (span.IsEmpty)
141136
{
@@ -147,7 +142,7 @@ public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte
147142
fixed (char* output = resultString)
148143
fixed (byte* buffer = span)
149144
{
150-
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
145+
// StringUtilities.TryGetAsciiString returns null if there are any null (0 byte) characters
151146
// in the string
152147
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
153148
{
@@ -157,27 +152,55 @@ public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte
157152
throw new InvalidOperationException();
158153
}
159154

160-
if (useLatin1)
155+
try
161156
{
162-
StringUtilities.GetLatin1String(buffer, output, span.Length);
157+
resultString = HeaderValueEncoding.GetString(buffer, span.Length);
163158
}
164-
else
159+
catch (DecoderFallbackException)
165160
{
166-
try
167-
{
168-
resultString = HeaderValueEncoding.GetString(buffer, span.Length);
169-
}
170-
catch (DecoderFallbackException)
171-
{
172-
throw new InvalidOperationException();
173-
}
161+
throw new InvalidOperationException();
174162
}
175163
}
176164
}
177165

178166
return resultString;
179167
}
180168

169+
public static unsafe string GetLatin1StringNonNullCharacters(this Span<byte> span)
170+
{
171+
if (span.IsEmpty)
172+
{
173+
return string.Empty;
174+
}
175+
176+
var resultString = new string('\0', span.Length);
177+
178+
fixed (char* output = resultString)
179+
fixed (byte* buffer = span)
180+
{
181+
// This returns false if there are any null (0 byte) characters in the string.
182+
if (!StringUtilities.TryGetLatin1String(buffer, output, span.Length))
183+
{
184+
// null characters are considered invalid
185+
throw new InvalidOperationException();
186+
}
187+
}
188+
189+
return resultString;
190+
}
191+
192+
public static unsafe string GetRequestHeaderString(this Span<byte> span, bool useLatin1)
193+
{
194+
if (useLatin1)
195+
{
196+
return GetLatin1StringNonNullCharacters(span);
197+
}
198+
else
199+
{
200+
return GetAsciiOrUTF8StringNonNullCharacters(span);
201+
}
202+
}
203+
181204
public static string GetAsciiStringEscaped(this Span<byte> span, int maxChars)
182205
{
183206
var sb = new StringBuilder();

src/Servers/Kestrel/Core/src/Internal/Infrastructure/StringUtilities.cs

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5-
using System.Buffers.Binary;
65
using System.Diagnostics;
76
using System.Numerics;
87
using System.Runtime.CompilerServices;
@@ -17,6 +16,9 @@ internal class StringUtilities
1716
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
1817
public static unsafe bool TryGetAsciiString(byte* input, char* output, int count)
1918
{
19+
Debug.Assert(input != null);
20+
Debug.Assert(output != null);
21+
2022
// Calculate end position
2123
var end = input + count;
2224
// Start as valid
@@ -116,10 +118,15 @@ out Unsafe.AsRef<Vector<short>>(output),
116118
}
117119

118120
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
119-
public static unsafe void GetLatin1String(byte* input, char* output, int count)
121+
public static unsafe bool TryGetLatin1String(byte* input, char* output, int count)
120122
{
123+
Debug.Assert(input != null);
124+
Debug.Assert(output != null);
125+
121126
// Calculate end position
122127
var end = input + count;
128+
// Start as valid
129+
var isValid = true;
123130

124131
do
125132
{
@@ -131,6 +138,8 @@ public static unsafe void GetLatin1String(byte* input, char* output, int count)
131138
// 64-bit: Loop longs by default
132139
while (input <= end - sizeof(long))
133140
{
141+
isValid &= CheckBytesNotNull(((long*)input)[0]);
142+
134143
output[0] = (char)input[0];
135144
output[1] = (char)input[1];
136145
output[2] = (char)input[2];
@@ -145,6 +154,8 @@ public static unsafe void GetLatin1String(byte* input, char* output, int count)
145154
}
146155
if (input <= end - sizeof(int))
147156
{
157+
isValid &= CheckBytesNotNull(((int*)input)[0]);
158+
148159
output[0] = (char)input[0];
149160
output[1] = (char)input[1];
150161
output[2] = (char)input[2];
@@ -159,6 +170,8 @@ public static unsafe void GetLatin1String(byte* input, char* output, int count)
159170
// 32-bit: Loop ints by default
160171
while (input <= end - sizeof(int))
161172
{
173+
isValid &= CheckBytesNotNull(((int*)input)[0]);
174+
162175
output[0] = (char)input[0];
163176
output[1] = (char)input[1];
164177
output[2] = (char)input[2];
@@ -170,6 +183,8 @@ public static unsafe void GetLatin1String(byte* input, char* output, int count)
170183
}
171184
if (input <= end - sizeof(short))
172185
{
186+
isValid &= CheckBytesNotNull(((short*)input)[0]);
187+
173188
output[0] = (char)input[0];
174189
output[1] = (char)input[1];
175190

@@ -178,16 +193,18 @@ public static unsafe void GetLatin1String(byte* input, char* output, int count)
178193
}
179194
if (input < end)
180195
{
196+
isValid &= CheckBytesNotNull(((sbyte*)input)[0]);
181197
output[0] = (char)input[0];
182198
}
183199

184-
return;
200+
return isValid;
185201
}
186202

187203
// do/while as entry condition already checked
188204
do
189205
{
190206
var vector = Unsafe.AsRef<Vector<sbyte>>(input);
207+
isValid &= CheckBytesNotNull(vector);
191208
Vector.Widen(
192209
vector,
193210
out Unsafe.AsRef<Vector<short>>(output),
@@ -200,6 +217,8 @@ out Unsafe.AsRef<Vector<short>>(output),
200217
// Vector path done, loop back to do non-Vector
201218
// If is a exact multiple of vector size, bail now
202219
} while (input < end);
220+
221+
return isValid;
203222
}
204223

205224
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
@@ -508,7 +527,7 @@ private static bool CheckBytesInAsciiRange(Vector<sbyte> check)
508527
// Validate: bytes != 0 && bytes <= 127
509528
// Subtract 1 from all bytes to move 0 to high bits
510529
// bitwise or with self to catch all > 127 bytes
511-
// mask off high bits and check if 0
530+
// mask off non high bits and check if 0
512531

513532
[MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
514533
private static bool CheckBytesInAsciiRange(long check)
@@ -531,5 +550,39 @@ private static bool CheckBytesInAsciiRange(short check)
531550

532551
private static bool CheckBytesInAsciiRange(sbyte check)
533552
=> check > 0;
553+
554+
[MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
555+
private static bool CheckBytesNotNull(Vector<sbyte> check)
556+
{
557+
// Vectorized byte range check, signed byte != null
558+
return !Vector.EqualsAny(check, Vector<sbyte>.Zero);
559+
}
560+
561+
// Validate: bytes != 0
562+
// Subtract 1 from all bytes to move 0 to high bits
563+
// bitwise and with ~check so high bits are only set for bytes that were originally 0
564+
// mask off non high bits and check if 0
565+
566+
[MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
567+
private static bool CheckBytesNotNull(long check)
568+
{
569+
const long HighBits = unchecked((long)0x8080808080808080L);
570+
return ((check - 0x0101010101010101L) & ~check & HighBits) == 0;
571+
}
572+
573+
private static bool CheckBytesNotNull(int check)
574+
{
575+
const int HighBits = unchecked((int)0x80808080);
576+
return ((check - 0x01010101) & ~check & HighBits) == 0;
577+
}
578+
579+
private static bool CheckBytesNotNull(short check)
580+
{
581+
const short HighBits = unchecked((short)0x8080);
582+
return ((check - 0x0101) & ~check & HighBits) == 0;
583+
}
584+
585+
private static bool CheckBytesNotNull(sbyte check)
586+
=> check != 0;
534587
}
535588
}

src/Servers/Kestrel/shared/KnownHeaders.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ public unsafe void Append(Span<byte> name, Span<byte> value)
985985
}}
986986
987987
// We didn't have a previous matching header value, or have already added a header, so get the string for this value.
988-
var valueStr = value.GetAsciiOrUTF8StringNonNullCharacters(_useLatin1);
988+
var valueStr = value.GetRequestHeaderString(_useLatin1);
989989
if ((_bits & flag) == 0)
990990
{{
991991
// We didn't already have a header set, so add a new one.
@@ -1003,7 +1003,7 @@ public unsafe void Append(Span<byte> name, Span<byte> value)
10031003
// The header was not one of the ""known"" headers.
10041004
// Convert value to string first, because passing two spans causes 8 bytes stack zeroing in
10051005
// this method with rep stosd, which is slower than necessary.
1006-
var valueStr = value.GetAsciiOrUTF8StringNonNullCharacters(_useLatin1);
1006+
var valueStr = value.GetRequestHeaderString(_useLatin1);
10071007
AppendUnknownHeaders(name, valueStr);
10081008
}}
10091009
}}" : "")}

0 commit comments

Comments
 (0)