Skip to content

Commit cc6b425

Browse files
authored
Merge pull request #9792 from dotnet/merges/master-to-release/dev16.8
Merge master to release/dev16.8
2 parents 203e472 + a4e0947 commit cc6b425

File tree

2 files changed

+166
-13
lines changed

2 files changed

+166
-13
lines changed

src/fsharp/FSharp.Core/reflect.fs

Lines changed: 121 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ module internal Impl =
6464
| null -> None
6565
| prop -> Some(fun (obj: obj) -> prop.GetValue (obj, instancePropertyFlags ||| bindingFlags, null, null, null))
6666

67+
//-----------------------------------------------------------------
68+
// EXPRESSION TREE COMPILATION
69+
6770
let compilePropGetterFunc (prop: PropertyInfo) =
6871
let param = Expression.Parameter (typeof<obj>, "param")
6972

@@ -77,6 +80,84 @@ module internal Impl =
7780
param)
7881
expr.Compile ()
7982

83+
let compileRecordOrUnionCaseReaderFunc (typ, props: PropertyInfo[]) =
84+
let param = Expression.Parameter (typeof<obj>, "param")
85+
let typedParam = Expression.Variable typ
86+
87+
let expr =
88+
Expression.Lambda<Func<obj, obj[]>> (
89+
Expression.Block (
90+
[ typedParam ],
91+
Expression.Assign (typedParam, Expression.Convert (param, typ)),
92+
Expression.NewArrayInit (typeof<obj>, [
93+
for prop in props ->
94+
Expression.Convert (Expression.Property (typedParam, prop), typeof<obj>) :> Expression
95+
])
96+
),
97+
param)
98+
expr.Compile ()
99+
100+
let compileRecordConstructorFunc (ctorInfo: ConstructorInfo) =
101+
let ctorParams = ctorInfo.GetParameters ()
102+
let paramArray = Expression.Parameter (typeof<obj[]>, "paramArray")
103+
104+
let expr =
105+
Expression.Lambda<Func<obj[], obj>> (
106+
Expression.Convert (
107+
Expression.New (
108+
ctorInfo,
109+
[
110+
for paramIndex in 0 .. ctorParams.Length - 1 do
111+
let p = ctorParams.[paramIndex]
112+
113+
Expression.Convert (
114+
Expression.ArrayAccess (paramArray, Expression.Constant paramIndex),
115+
p.ParameterType
116+
) :> Expression
117+
]
118+
),
119+
typeof<obj>),
120+
paramArray
121+
)
122+
expr.Compile ()
123+
124+
let compileUnionCaseConstructorFunc (methodInfo: MethodInfo) =
125+
let methodParams = methodInfo.GetParameters ()
126+
let paramArray = Expression.Parameter (typeof<obj[]>, "param")
127+
128+
let expr =
129+
Expression.Lambda<Func<obj[], obj>> (
130+
Expression.Convert (
131+
Expression.Call (
132+
methodInfo,
133+
[
134+
for paramIndex in 0 .. methodParams.Length - 1 do
135+
let p = methodParams.[paramIndex]
136+
137+
Expression.Convert (
138+
Expression.ArrayAccess (paramArray, Expression.Constant paramIndex),
139+
p.ParameterType
140+
) :> Expression
141+
]
142+
),
143+
typeof<obj>),
144+
paramArray
145+
)
146+
expr.Compile ()
147+
148+
let compileUnionTagReaderFunc (info: Choice<MethodInfo, PropertyInfo>) =
149+
let param = Expression.Parameter (typeof<obj>, "param")
150+
let tag =
151+
match info with
152+
| Choice1Of2 info -> Expression.Call (info, Expression.Convert (param, info.DeclaringType)) :> Expression
153+
| Choice2Of2 info -> Expression.Property (Expression.Convert (param, info.DeclaringType), info) :> _
154+
155+
let expr =
156+
Expression.Lambda<Func<obj, int>> (
157+
tag,
158+
param)
159+
expr.Compile ()
160+
80161
//-----------------------------------------------------------------
81162
// ATTRIBUTE DECOMPILATION
82163

@@ -275,6 +356,12 @@ module internal Impl =
275356
let props = fieldsPropsOfUnionCase (typ, tag, bindingFlags)
276357
(fun (obj: obj) -> props |> Array.map (fun prop -> prop.GetValue (obj, bindingFlags, null, null, null)))
277358

359+
let getUnionCaseRecordReaderCompiled (typ: Type, tag: int, bindingFlags) =
360+
let props = fieldsPropsOfUnionCase (typ, tag, bindingFlags)
361+
let caseTyp = getUnionCaseTyp (typ, tag, bindingFlags)
362+
let caseTyp = if isNull caseTyp then typ else caseTyp
363+
compileRecordOrUnionCaseReaderFunc(caseTyp, props).Invoke
364+
278365
let getUnionTagReader (typ: Type, bindingFlags) : (obj -> int) =
279366
if isOptionType typ then
280367
(fun (obj: obj) -> match obj with null -> 0 | _ -> 1)
@@ -286,9 +373,22 @@ module internal Impl =
286373
match getInstancePropertyReader (typ, "Tag", bindingFlags) with
287374
| Some reader -> (fun (obj: obj) -> reader obj :?> int)
288375
| None ->
289-
(fun (obj: obj) ->
290-
let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null)
291-
m2b.Invoke(null, [|obj|]) :?> int)
376+
let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null)
377+
(fun (obj: obj) -> m2b.Invoke(null, [|obj|]) :?> int)
378+
379+
let getUnionTagReaderCompiled (typ: Type, bindingFlags) : (obj -> int) =
380+
if isOptionType typ then
381+
(fun (obj: obj) -> match obj with null -> 0 | _ -> 1)
382+
else
383+
let tagMap = getUnionTypeTagNameMap (typ, bindingFlags)
384+
if tagMap.Length <= 1 then
385+
(fun (_obj: obj) -> 0)
386+
else
387+
match getInstancePropertyInfo (typ, "Tag", bindingFlags) with
388+
| null ->
389+
let m2b = typ.GetMethod("GetTag", BindingFlags.Static ||| bindingFlags, null, [| typ |], null)
390+
compileUnionTagReaderFunc(Choice1Of2 m2b).Invoke
391+
| info -> compileUnionTagReaderFunc(Choice2Of2 info).Invoke
292392

293393
let getUnionTagMemberInfo (typ: Type, bindingFlags) =
294394
match getInstancePropertyInfo (typ, "Tag", bindingFlags) with
@@ -314,6 +414,10 @@ module internal Impl =
314414
(fun args ->
315415
meth.Invoke(null, BindingFlags.Static ||| BindingFlags.InvokeMethod ||| bindingFlags, null, args, null))
316416

417+
let getUnionCaseConstructorCompiled (typ: Type, tag: int, bindingFlags) =
418+
let meth = getUnionCaseConstructorMethod (typ, tag, bindingFlags)
419+
compileUnionCaseConstructorFunc(meth).Invoke
420+
317421
let checkUnionType (unionType, bindingFlags) =
318422
checkNonNull "unionType" unionType
319423
if not (isUnionType (unionType, bindingFlags)) then
@@ -599,9 +703,9 @@ module internal Impl =
599703
let props = fieldPropsOfRecordType(typ, bindingFlags)
600704
(fun (obj: obj) -> props |> Array.map (fun prop -> prop.GetValue (obj, null)))
601705

602-
let getRecordReaderFromFuncs(typ: Type, bindingFlags) =
603-
let props = fieldPropsOfRecordType(typ, bindingFlags) |> Array.map compilePropGetterFunc
604-
(fun (obj: obj) -> props |> Array.map (fun prop -> prop.Invoke obj))
706+
let getRecordReaderCompiled(typ: Type, bindingFlags) =
707+
let props = fieldPropsOfRecordType(typ, bindingFlags)
708+
compileRecordOrUnionCaseReaderFunc(typ, props).Invoke
605709

606710
let getRecordConstructorMethod(typ: Type, bindingFlags) =
607711
let props = fieldPropsOfRecordType(typ, bindingFlags)
@@ -616,6 +720,10 @@ module internal Impl =
616720
(fun (args: obj[]) ->
617721
ctor.Invoke(BindingFlags.InvokeMethod ||| BindingFlags.Instance ||| bindingFlags, null, args, null))
618722

723+
let getRecordConstructorCompiled(typ: Type, bindingFlags) =
724+
let ctor = getRecordConstructorMethod(typ, bindingFlags)
725+
compileRecordConstructorFunc(ctor).Invoke
726+
619727
/// EXCEPTION DECOMPILATION
620728
// Check the base type - if it is also an F# type then
621729
// for the moment we know it is a Discriminated Union
@@ -817,19 +925,19 @@ type FSharpValue =
817925
invalidArg "record" (SR.GetString (SR.objIsNotARecord))
818926
getRecordReader (typ, bindingFlags) record
819927

820-
static member PreComputeRecordFieldReader(info: PropertyInfo) =
928+
static member PreComputeRecordFieldReader(info: PropertyInfo): obj -> obj =
821929
checkNonNull "info" info
822-
(fun (obj: obj) -> info.GetValue (obj, null))
930+
compilePropGetterFunc(info).Invoke
823931

824932
static member PreComputeRecordReader(recordType: Type, ?bindingFlags) : (obj -> obj[]) =
825933
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
826934
checkRecordType ("recordType", recordType, bindingFlags)
827-
getRecordReaderFromFuncs (recordType, bindingFlags)
935+
getRecordReaderCompiled (recordType, bindingFlags)
828936

829937
static member PreComputeRecordConstructor(recordType: Type, ?bindingFlags) =
830938
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
831939
checkRecordType ("recordType", recordType, bindingFlags)
832-
getRecordConstructor (recordType, bindingFlags)
940+
getRecordConstructorCompiled (recordType, bindingFlags)
833941

834942
static member PreComputeRecordConstructorInfo(recordType: Type, ?bindingFlags) =
835943
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
@@ -894,7 +1002,7 @@ type FSharpValue =
8941002
static member PreComputeUnionConstructor (unionCase: UnionCaseInfo, ?bindingFlags) =
8951003
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
8961004
checkNonNull "unionCase" unionCase
897-
getUnionCaseConstructor (unionCase.DeclaringType, unionCase.Tag, bindingFlags)
1005+
getUnionCaseConstructorCompiled (unionCase.DeclaringType, unionCase.Tag, bindingFlags)
8981006

8991007
static member PreComputeUnionConstructorInfo(unionCase: UnionCaseInfo, ?bindingFlags) =
9001008
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
@@ -926,7 +1034,7 @@ type FSharpValue =
9261034
checkNonNull "unionType" unionType
9271035
let unionType = getTypeOfReprType (unionType, bindingFlags)
9281036
checkUnionType (unionType, bindingFlags)
929-
getUnionTagReader (unionType, bindingFlags)
1037+
getUnionTagReaderCompiled (unionType, bindingFlags)
9301038

9311039
static member PreComputeUnionTagMemberInfo(unionType: Type, ?bindingFlags) =
9321040
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
@@ -939,7 +1047,7 @@ type FSharpValue =
9391047
let bindingFlags = defaultArg bindingFlags BindingFlags.Public
9401048
checkNonNull "unionCase" unionCase
9411049
let typ = unionCase.DeclaringType
942-
getUnionCaseRecordReader (typ, unionCase.Tag, bindingFlags)
1050+
getUnionCaseRecordReaderCompiled (typ, unionCase.Tag, bindingFlags)
9431051

9441052
static member GetExceptionFields (exn: obj, ?bindingFlags) =
9451053
let bindingFlags = defaultArg bindingFlags BindingFlags.Public

tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Reflection/FSharpReflection.fs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ type FSharpValueTests() =
103103
let discStructUnionCaseB = DiscStructUnionType.B(1)
104104
let discStructUnionCaseC = DiscStructUnionType.C(1.0, "stringparam")
105105

106+
let optionSome = Some(3)
107+
let optionNone: int option = None
108+
109+
let voptionSome = ValueSome("stringparam")
110+
let voptionNone: string voption = ValueNone
111+
112+
let list1 = [ 1; 2 ]
113+
let list2: int list = []
106114

107115
let fsharpDelegate1 = new FSharpDelegate(fun (x:int) -> "delegate1")
108116
let fsharpDelegate2 = new FSharpDelegate(fun (x:int) -> "delegate2")
@@ -738,6 +746,24 @@ type FSharpValueTests() =
738746
let (discUnionInfo, discvaluearray) = FSharpValue.GetUnionFields(discUnionRecCaseB, typeof<DiscUnionType<int>>)
739747
let discUnionReader = FSharpValue.PreComputeUnionReader(discUnionInfo)
740748
Assert.AreEqual(discUnionReader(box(discUnionRecCaseB)) , [| box 1; box(Some(discUnionCaseB)) |])
749+
750+
// Option
751+
let (optionCaseInfo, _) = FSharpValue.GetUnionFields(optionSome, typeof<int option>)
752+
let optionReader = FSharpValue.PreComputeUnionReader(optionCaseInfo)
753+
Assert.AreEqual(optionReader(box(optionSome)), [| box 3 |])
754+
755+
let (optionCaseInfo, _) = FSharpValue.GetUnionFields(optionNone, typeof<int option>)
756+
let optionReader = FSharpValue.PreComputeUnionReader(optionCaseInfo)
757+
Assert.AreEqual(optionReader(box(optionNone)), [| |])
758+
759+
// List
760+
let (listCaseInfo, _) = FSharpValue.GetUnionFields(list1, typeof<int list>)
761+
let listReader = FSharpValue.PreComputeUnionReader(listCaseInfo)
762+
Assert.AreEqual(listReader(box(list1)), [| box 1; box [ 2 ] |])
763+
764+
let (listCaseInfo, _) = FSharpValue.GetUnionFields(list2, typeof<int list>)
765+
let listReader = FSharpValue.PreComputeUnionReader(listCaseInfo)
766+
Assert.AreEqual(listReader(box(list2)), [| |])
741767

742768
[<Test>]
743769
member __.PreComputeStructUnionReader() =
@@ -751,6 +777,15 @@ type FSharpValueTests() =
751777
let (discUnionInfo, discvaluearray) = FSharpValue.GetUnionFields(discStructUnionCaseB, typeof<DiscStructUnionType<int>>)
752778
let discUnionReader = FSharpValue.PreComputeUnionReader(discUnionInfo)
753779
Assert.AreEqual(discUnionReader(box(discStructUnionCaseB)) , [| box 1|])
780+
781+
// Value Option
782+
let (voptionCaseInfo, _) = FSharpValue.GetUnionFields(voptionSome, typeof<string voption>)
783+
let voptionReader = FSharpValue.PreComputeUnionReader(voptionCaseInfo)
784+
Assert.AreEqual(voptionReader(box(voptionSome)), [| box "stringparam" |])
785+
786+
let (voptionCaseInfo, _) = FSharpValue.GetUnionFields(voptionNone, typeof<string voption>)
787+
let voptionReader = FSharpValue.PreComputeUnionReader(voptionCaseInfo)
788+
Assert.AreEqual(voptionReader(box(voptionNone)), [| |])
754789

755790
[<Test>]
756791
member __.PreComputeUnionTagMemberInfo() =
@@ -790,6 +825,16 @@ type FSharpValueTests() =
790825
// DiscUnion
791826
let discUnionTagReader = FSharpValue.PreComputeUnionTagReader(typeof<DiscUnionType<int>>)
792827
Assert.AreEqual(discUnionTagReader(box(discUnionCaseB)), 1)
828+
829+
// Option
830+
let optionTagReader = FSharpValue.PreComputeUnionTagReader(typeof<int option>)
831+
Assert.AreEqual(optionTagReader(box(optionSome)), 1)
832+
Assert.AreEqual(optionTagReader(box(optionNone)), 0)
833+
834+
// Value Option
835+
let voptionTagReader = FSharpValue.PreComputeUnionTagReader(typeof<string voption>)
836+
Assert.AreEqual(voptionTagReader(box(voptionSome)), 1)
837+
Assert.AreEqual(voptionTagReader(box(voptionNone)), 0)
793838

794839
// null value
795840
CheckThrowsArgumentException(fun () ->FSharpValue.PreComputeUnionTagReader(null)|> ignore)

0 commit comments

Comments
 (0)