From e8ea5c4d9c299184ac03129d4a73231375e870a6 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 18 Jul 2018 18:45:11 -0700 Subject: [PATCH 01/10] Don't fail in case of const field in Collection source. Extend support for basic C# types for DataVIew<->collection conversion. --- src/Microsoft.ML.Api/ApiUtils.cs | 7 +- .../DataViewConstructionUtils.cs | 175 +++++++++----- src/Microsoft.ML.Api/SchemaDefinition.cs | 3 +- src/Microsoft.ML.Api/TypedCursor.cs | 144 +++++++++--- src/Microsoft.ML.Core/Data/DataKind.cs | 214 +++++++++--------- .../CollectionDataSourceTests.cs | 203 +++++++++++++++++ 6 files changed, 538 insertions(+), 208 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 5b5936b5a8..8b85c24d63 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -21,9 +21,10 @@ private static OpCode GetAssignmentOpCode(Type t) // REVIEW: This should be a Dictionary based solution. // DvTexts, strings, arrays, and VBuffers. if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) || - t == typeof(DvBool) || t==typeof(bool?) 
|| t == typeof(DvText) || t == typeof(string) || t.IsArray || - (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || t == typeof(DvDateTime) || - t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128)) + t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray || + (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || + (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || + t == typeof(DvDateTime) || t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128)) { return OpCodes.Stobj; } diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index f185dc6c0b..962481f51b 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -119,7 +119,7 @@ private Delegate CreateGetter(int index) var column = DataView._schema.SchemaDefn.Columns[index]; var outputType = column.IsComputed ? 
column.ReturnType : column.FieldInfo.FieldType; - + var genericType = outputType; Func del; if (outputType.IsArray) @@ -129,20 +129,48 @@ private Delegate CreateGetter(int index) if (outputType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateStringArrayToVBufferGetter(index); + return CreateArrayGetterDelegate(index, (x) => new DvText(x)); + } + else if (outputType.GetElementType() == typeof(int)) + { + Ch.Assert(colType.ItemType == NumberType.I4); + return CreateArrayGetterDelegate(index, (x) => x); + } + else if (outputType.GetElementType() == typeof(long)) + { + Ch.Assert(colType.ItemType == NumberType.I8); + return CreateArrayGetterDelegate(index, (x) => x); + } + else if (outputType.GetElementType() == typeof(short)) + { + Ch.Assert(colType.ItemType == NumberType.I2); + return CreateArrayGetterDelegate(index, (x) => x); } + else if (outputType.GetElementType() == typeof(sbyte)) + { + Ch.Assert(colType.ItemType == NumberType.I1); + return CreateArrayGetterDelegate(index, (x) => x); + } + else if (outputType.GetElementType() == typeof(bool)) + { + Ch.Assert(colType.ItemType.IsBool); + return CreateArrayGetterDelegate(index, (x)=>x); + } + // T[] -> VBuffer Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType); - del = CreateArrayToVBufferGetter; + del = CreateDirectArrayGetterDelegate; + genericType = outputType.GetElementType(); } else if (colType.IsVector) { // VBuffer -> VBuffer // REVIEW: Do we care about accomodating VBuffer -> VBuffer? + // REVIEW: why it's int and not long? 
Ch.Assert(outputType.IsGenericType); Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>)); Ch.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType); - del = CreateVBufferToVBufferDelegate; + del = CreateDirectVBufferGetterDelegate; } else if (colType.IsPrimitive) { @@ -150,24 +178,74 @@ private Delegate CreateGetter(int index) { // String -> DvText Ch.Assert(colType.IsText); - return CreateStringToTextGetter(index); + return CreateGetterDelegate(index, (x) => x == null ? DvText.NA : new DvText(x)); } else if (outputType == typeof(bool)) { // Bool -> DvBool Ch.Assert(colType.IsBool); - return CreateBooleanToDvBoolGetter(index); + return CreateGetterDelegate(index, (x) => x); } else if (outputType == typeof(bool?)) { // Bool? -> DvBool Ch.Assert(colType.IsBool); - return CreateNullableBooleanToDvBoolGetter(index); + return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvBool.NA); + } + else if (outputType == typeof(int)) + { + // int -> DvInt4 + Ch.Assert(colType == NumberType.I4); + return CreateGetterDelegate(index, (x) => x); + } + else if (outputType == typeof(int?)) + { + // int -> DvInt4 + Ch.Assert(colType == NumberType.I4); + return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvInt4.NA); + } + else if (outputType == typeof(short)) + { + // short -> DvInt2 + Ch.Assert(colType == NumberType.I2); + return CreateGetterDelegate(index, (x) => x); + } + else if (outputType == typeof(short?)) + { + // short? -> DvInt2 + Ch.Assert(colType == NumberType.I2); + return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvInt2.NA); + } + else if (outputType == typeof(long)) + { + // long -> DvInt8 + Ch.Assert(colType == NumberType.I8); + return CreateGetterDelegate(index, (x) => x); + } + else if (outputType == typeof(long?)) + { + // long? -> DvInt8 + Ch.Assert(colType == NumberType.I8); + return CreateGetterDelegate(index, (x) => x.HasValue ? 
x.Value : DvInt8.NA); + } + else if (outputType == typeof(sbyte)) + { + // sbyte -> DvInt1 + Ch.Assert(colType == NumberType.I1); + return CreateGetterDelegate(index, (x) => (DvInt1)x); + } + else if (outputType == typeof(sbyte?)) + { + // sbyte -> DvInt1 + Ch.Assert(colType == NumberType.I1); + return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvInt1.NA); } - // T -> T - Ch.Assert(colType.RawType == outputType); - del = CreateDirectGetter; + if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) + Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType)); + else + Ch.Assert(colType.RawType == outputType); + del = CreateDirectGetterDelegate; } else { @@ -175,66 +253,40 @@ private Delegate CreateGetter(int index) throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", outputType.FullName); } MethodInfo meth = - del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.ItemType.RawType); + del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType); return (Delegate)meth.Invoke(this, new object[] { index }); } - private Delegate CreateStringArrayToVBufferGetter(int index) + private Delegate CreateArrayGetterDelegate(int index, Func convert) { - var peek = DataView._peeks[index] as Peek; + var peek = DataView._peeks[index] as Peek; Ch.AssertValue(peek); - - string[] buf = null; - - return (ValueGetter>)((ref VBuffer dst) => + TSrc[] buf = default; + return (ValueGetter>)((ref VBuffer dst) => { peek(GetCurrentRowObject(), Position, ref buf); var n = Utils.Size(buf); - dst = new VBuffer(n, Utils.Size(dst.Values) < n - ? new DvText[n] + dst = new VBuffer(n, Utils.Size(dst.Values) < n + ? 
new TDst[n] : dst.Values, dst.Indices); for (int i = 0; i < n; i++) - dst.Values[i] = new DvText(buf[i]); + dst.Values[i] = convert(buf[i]); }); } - private Delegate CreateStringToTextGetter(int index) + private Delegate CreateGetterDelegate(int index, Func convert) { - var peek = DataView._peeks[index] as Peek; + var peek = DataView._peeks[index] as Peek; Ch.AssertValue(peek); - string buf = null; - return (ValueGetter)((ref DvText dst) => - { - peek(GetCurrentRowObject(), Position, ref buf); - dst = new DvText(buf); - }); - } - - private Delegate CreateBooleanToDvBoolGetter(int index) - { - var peek = DataView._peeks[index] as Peek; - Ch.AssertValue(peek); - bool buf = false; - return (ValueGetter)((ref DvBool dst) => + TSrc buf = default; + return (ValueGetter)((ref TDst dst) => { peek(GetCurrentRowObject(), Position, ref buf); - dst = (DvBool)buf; + dst = convert(buf); }); } - private Delegate CreateNullableBooleanToDvBoolGetter(int index) - { - var peek = DataView._peeks[index] as Peek; - Ch.AssertValue(peek); - bool? buf = null; - return (ValueGetter)((ref DvBool dst) => - { - peek(GetCurrentRowObject(), Position, ref buf); - dst = buf.HasValue ? (DvBool)buf.Value : DvBool.NA; - }); - } - - private Delegate CreateArrayToVBufferGetter(int index) + private Delegate CreateDirectArrayGetterDelegate(int index) { var peek = DataView._peeks[index] as Peek; Ch.AssertValue(peek); @@ -250,26 +302,29 @@ private Delegate CreateArrayToVBufferGetter(int index) }); } - private Delegate CreateVBufferToVBufferDelegate(int index) + private Delegate CreateDirectVBufferGetterDelegate(int index) { var peek = DataView._peeks[index] as Peek>; Ch.AssertValue(peek); VBuffer buf = default(VBuffer); return (ValueGetter>)((ref VBuffer dst) => - { - // The peek for a VBuffer is just a simple assignment, so there is - // no copy going on in the peek, so we must do that as a second - // step to the destination. 
- peek(GetCurrentRowObject(), Position, ref buf); - buf.CopyTo(ref dst); - }); + { + // The peek for a VBuffer is just a simple assignment, so there is + // no copy going on in the peek, so we must do that as a second + // step to the destination. + peek(GetCurrentRowObject(), Position, ref buf); + buf.CopyTo(ref dst); + }); } - private Delegate CreateDirectGetter(int index) + private Delegate CreateDirectGetterDelegate(int index) { var peek = DataView._peeks[index] as Peek; Ch.AssertValue(peek); - return (ValueGetter)((ref TDst dst) => { peek(GetCurrentRowObject(), Position, ref dst); }); + return (ValueGetter)((ref TDst dst) => + { + peek(GetCurrentRowObject(), Position, ref dst); + }); } protected abstract TRow GetCurrentRowObject(); diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index 5f84712625..522fb3510c 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -327,7 +327,8 @@ public static SchemaDefinition Create(Type userType) // This field does not need a column. // REVIEW: maybe validate the channel attribute now, instead // of later at cursor creation. - if (fieldInfo.FieldType == typeof(IChannel)) + // Const fields not need to be mapped. 
+ if (fieldInfo.FieldType == typeof(IChannel) || fieldInfo.IsLiteral) continue; if (fieldInfo.GetCustomAttribute() != null) diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 29fee77a02..31a4c99bf3 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -271,7 +271,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit var colType = input.Schema.GetColumnType(index); var fieldInfo = column.FieldInfo; var fieldType = fieldInfo.FieldType; - + var genericType = fieldType; Func> del; if (fieldType.IsArray) { @@ -280,11 +280,38 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit if (fieldType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateVBufferToStringArraySetter(input, index, poke, peek); + return CreateVBufferSetter(input, index, poke, peek, (x) => x.ToString()); + } + else if (fieldType.GetElementType() == typeof(bool)) + { + Ch.Assert(colType.ItemType.IsBool); + return CreateVBufferSetter(input, index, poke, peek, (x) => Convert.ToBoolean(x.RawValue)); + } + else if (fieldType.GetElementType() == typeof(int)) + { + Ch.Assert(colType.ItemType == NumberType.I4); + return CreateVBufferSetter(input, index, poke, peek, (x) => x.RawValue); + } + else if (fieldType.GetElementType() == typeof(short)) + { + Ch.Assert(colType.ItemType == NumberType.I2); + return CreateVBufferSetter(input, index, poke, peek, (x) => x.RawValue); } + else if (fieldType.GetElementType() == typeof(long)) + { + Ch.Assert(colType.ItemType == NumberType.I8); + return CreateVBufferSetter(input, index, poke, peek, (x) => x.RawValue); + } + else if (fieldType.GetElementType() == typeof(sbyte)) + { + Ch.Assert(colType.ItemType == NumberType.I1); + return CreateVBufferSetter(input, index, poke, peek, (x) => x.RawValue); + } + // VBuffer -> T[] Ch.Assert(fieldType.GetElementType() == colType.ItemType.RawType); - del = 
CreateVBufferToArraySetter; + del = CreateVBufferDirectSetter; + genericType = fieldType.GetElementType(); } else if (colType.IsVector) { @@ -302,53 +329,108 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit // DvText -> String Ch.Assert(colType.IsText); Ch.Assert(peek == null); - return CreateTextToStringSetter(input, index, poke); + return CreateActionSetter(input, index, poke, (x) => x.ToString()); } else if (fieldType == typeof(bool)) { Ch.Assert(colType.IsBool); Ch.Assert(peek == null); - return CreateDvBoolToBoolSetter(input, index, poke); + return CreateActionSetter(input, index, poke, (x) => Convert.ToBoolean(x.RawValue)); } - else + else if (fieldType == typeof(bool?)) { - // T -> T - Ch.Assert(colType.RawType == fieldType); - del = CreateDirectSetter; + Ch.Assert(colType.IsBool); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (bool?)null : (x.IsTrue ? true : false)); + } + else if (fieldType == typeof(int)) + { + Ch.Assert(colType == NumberType.I4); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.RawValue); + } + else if (fieldType == typeof(int?)) + { + Ch.Assert(colType == NumberType.I4); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (int?)null : x.RawValue); + } + else if (fieldType == typeof(short)) + { + Ch.Assert(colType == NumberType.I2); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.RawValue); + } + else if (fieldType == typeof(short?)) + { + Ch.Assert(colType == NumberType.I2); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.IsNA ? 
(short?)null : x.RawValue); } + else if (fieldType == typeof(long)) + { + Ch.Assert(colType == NumberType.I8); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.RawValue); + } + else if (fieldType == typeof(long?)) + { + Ch.Assert(colType == NumberType.I8); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (long?)null : x.RawValue); + } + else if (fieldType == typeof(sbyte)) + { + Ch.Assert(colType == NumberType.I1); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.RawValue); + } + else if (fieldType == typeof(sbyte?)) + { + Ch.Assert(colType == NumberType.I1); + Ch.Assert(peek == null); + return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (sbyte?)null : x.RawValue); + } + // T -> T + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) + Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(fieldType)); + else + Ch.Assert(colType.RawType == fieldType); + + del = CreateDirectSetter; } else { // REVIEW: Is this even possible? 
throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", fieldInfo.FieldType.FullName); } - MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.ItemType.RawType); + MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType); return (Action)meth.Invoke(this, new object[] { input, index, poke, peek }); } - private Action CreateVBufferToStringArraySetter(IRow input, int col, Delegate poke, Delegate peek) + private Action CreateVBufferSetter(IRow input, int col, Delegate poke, Delegate peek, Func convert) { - var getter = input.GetGetter>(col); - var typedPoke = poke as Poke; - var typedPeek = peek as Peek; + var getter = input.GetGetter>(col); + var typedPoke = poke as Poke; + var typedPeek = peek as Peek; Contracts.AssertValue(typedPoke); Contracts.AssertValue(typedPeek); - VBuffer value = default(VBuffer); - string[] buf = null; + VBuffer value = default; + TDst[] buf = null; return row => { getter(ref value); typedPeek(row, Position, ref buf); if (Utils.Size(buf) != value.Length) - buf = new string[value.Length]; + buf = new TDst[value.Length]; foreach (var pair in value.Items(true)) - buf[pair.Key] = pair.Value.ToString(); + buf[pair.Key] = convert(pair.Value); typedPoke(row, buf); }; } - private Action CreateVBufferToArraySetter(IRow input, int col, Delegate poke, Delegate peek) + private Action CreateVBufferDirectSetter(IRow input, int col, Delegate poke, Delegate peek) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke; @@ -386,29 +468,17 @@ private Action CreateVBufferToArraySetter(IRow input, int col, Deleg }; } - private static Action CreateTextToStringSetter(IRow input, int col, Delegate poke) - { - var getter = input.GetGetter(col); - var typedPoke = poke as Poke; - Contracts.AssertValue(typedPoke); - DvText value = default(DvText); - return row => - { - getter(ref value); - typedPoke(row, value.ToString()); - }; - } - - private static Action 
CreateDvBoolToBoolSetter(IRow input, int col, Delegate poke) + private static Action CreateActionSetter(IRow input, int col, Delegate poke, Func convert) { - var getter = input.GetGetter(col); - var typedPoke = poke as Poke; + var getter = input.GetGetter(col); + var typedPoke = poke as Poke; Contracts.AssertValue(typedPoke); - DvBool value = default(DvBool); + TSrc value = default; return row => { getter(ref value); - typedPoke(row, Convert.ToBoolean(value.RawValue)); + var toPoke = convert(value); + typedPoke(row, toPoke); }; } diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 5ed5ded1c1..358227399b 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -83,22 +83,22 @@ public static ulong ToMaxInt(this DataKind kind) { switch (kind) { - case DataKind.I1: - return (ulong)sbyte.MaxValue; - case DataKind.U1: - return byte.MaxValue; - case DataKind.I2: - return (ulong)short.MaxValue; - case DataKind.U2: - return ushort.MaxValue; - case DataKind.I4: - return int.MaxValue; - case DataKind.U4: - return uint.MaxValue; - case DataKind.I8: - return long.MaxValue; - case DataKind.U8: - return ulong.MaxValue; + case DataKind.I1: + return (ulong)sbyte.MaxValue; + case DataKind.U1: + return byte.MaxValue; + case DataKind.I2: + return (ulong)short.MaxValue; + case DataKind.U2: + return ushort.MaxValue; + case DataKind.I4: + return int.MaxValue; + case DataKind.U4: + return uint.MaxValue; + case DataKind.I8: + return long.MaxValue; + case DataKind.U8: + return ulong.MaxValue; } return 0; @@ -112,22 +112,22 @@ public static long ToMinInt(this DataKind kind) { switch (kind) { - case DataKind.I1: - return sbyte.MinValue; - case DataKind.U1: - return byte.MinValue; - case DataKind.I2: - return short.MinValue; - case DataKind.U2: - return ushort.MinValue; - case DataKind.I4: - return int.MinValue; - case DataKind.U4: - return uint.MinValue; - case DataKind.I8: - return long.MinValue; - case 
DataKind.U8: - return 0; + case DataKind.I1: + return sbyte.MinValue; + case DataKind.U1: + return byte.MinValue; + case DataKind.I2: + return short.MinValue; + case DataKind.U2: + return ushort.MinValue; + case DataKind.I4: + return int.MinValue; + case DataKind.U4: + return uint.MinValue; + case DataKind.I8: + return long.MinValue; + case DataKind.U8: + return 0; } return 1; @@ -140,38 +140,38 @@ public static Type ToType(this DataKind kind) { switch (kind) { - case DataKind.I1: - return typeof(DvInt1); - case DataKind.U1: - return typeof(byte); - case DataKind.I2: - return typeof(DvInt2); - case DataKind.U2: - return typeof(ushort); - case DataKind.I4: - return typeof(DvInt4); - case DataKind.U4: - return typeof(uint); - case DataKind.I8: - return typeof(DvInt8); - case DataKind.U8: - return typeof(ulong); - case DataKind.R4: - return typeof(Single); - case DataKind.R8: - return typeof(Double); - case DataKind.TX: - return typeof(DvText); - case DataKind.BL: - return typeof(DvBool); - case DataKind.TS: - return typeof(DvTimeSpan); - case DataKind.DT: - return typeof(DvDateTime); - case DataKind.DZ: - return typeof(DvDateTimeZone); - case DataKind.UG: - return typeof(UInt128); + case DataKind.I1: + return typeof(DvInt1); + case DataKind.U1: + return typeof(byte); + case DataKind.I2: + return typeof(DvInt2); + case DataKind.U2: + return typeof(ushort); + case DataKind.I4: + return typeof(DvInt4); + case DataKind.U4: + return typeof(uint); + case DataKind.I8: + return typeof(DvInt8); + case DataKind.U8: + return typeof(ulong); + case DataKind.R4: + return typeof(Single); + case DataKind.R8: + return typeof(Double); + case DataKind.TX: + return typeof(DvText); + case DataKind.BL: + return typeof(DvBool); + case DataKind.TS: + return typeof(DvTimeSpan); + case DataKind.DT: + return typeof(DvDateTime); + case DataKind.DZ: + return typeof(DvDateTimeZone); + case DataKind.UG: + return typeof(UInt128); } return null; @@ -185,29 +185,29 @@ public static bool 
TryGetDataKind(this Type type, out DataKind kind) Contracts.CheckValueOrNull(type); // REVIEW: Make this more efficient. Should we have a global dictionary? - if (type == typeof(DvInt1)) + if (type == typeof(DvInt1) || type == typeof(sbyte) || type == typeof(sbyte?)) kind = DataKind.I1; - else if (type == typeof(byte)) + else if (type == typeof(byte) || type == typeof(byte?)) kind = DataKind.U1; - else if (type == typeof(DvInt2)) + else if (type == typeof(DvInt2)|| type== typeof(short) || type == typeof(short?)) kind = DataKind.I2; - else if (type == typeof(ushort)) + else if (type == typeof(ushort)|| type == typeof(ushort?)) kind = DataKind.U2; - else if (type == typeof(DvInt4)) + else if (type == typeof(DvInt4) || type == typeof(int)|| type == typeof(int?)) kind = DataKind.I4; - else if (type == typeof(uint)) + else if (type == typeof(uint)|| type == typeof(uint?)) kind = DataKind.U4; - else if (type == typeof(DvInt8)) + else if (type == typeof(DvInt8) || type==typeof(long)|| type == typeof(long?)) kind = DataKind.I8; - else if (type == typeof(ulong)) + else if (type == typeof(ulong)|| type == typeof(ulong?)) kind = DataKind.U8; - else if (type == typeof(Single)) + else if (type == typeof(Single)|| type == typeof(Single?)) kind = DataKind.R4; - else if (type == typeof(Double)) + else if (type == typeof(Double)|| type == typeof(Double?)) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(DvBool) || type == typeof(bool) ||type ==typeof(bool?)) + else if (type == typeof(DvBool) || type == typeof(bool) || type == typeof(bool?)) kind = DataKind.BL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; @@ -234,38 +234,38 @@ public static string GetString(this DataKind kind) { switch (kind) { - case DataKind.I1: - return "I1"; - case DataKind.I2: - return "I2"; - case DataKind.I4: - return "I4"; - case DataKind.I8: - return "I8"; - case DataKind.U1: - return "U1"; - case DataKind.U2: - return "U2"; - case DataKind.U4: - 
return "U4"; - case DataKind.U8: - return "U8"; - case DataKind.R4: - return "R4"; - case DataKind.R8: - return "R8"; - case DataKind.BL: - return "BL"; - case DataKind.TX: - return "TX"; - case DataKind.TS: - return "TS"; - case DataKind.DT: - return "DT"; - case DataKind.DZ: - return "DZ"; - case DataKind.UG: - return "UG"; + case DataKind.I1: + return "I1"; + case DataKind.I2: + return "I2"; + case DataKind.I4: + return "I4"; + case DataKind.I8: + return "I8"; + case DataKind.U1: + return "U1"; + case DataKind.U2: + return "U2"; + case DataKind.U4: + return "U4"; + case DataKind.U8: + return "U8"; + case DataKind.R4: + return "R4"; + case DataKind.R8: + return "R8"; + case DataKind.BL: + return "BL"; + case DataKind.TX: + return "TX"; + case DataKind.TS: + return "TS"; + case DataKind.DT: + return "DT"; + case DataKind.DZ: + return "DZ"; + case DataKind.UG: + return "UG"; } return ""; } diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 42b85ae20f..cc2c948206 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -9,6 +9,7 @@ using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; +using System; using System.Collections.Generic; using System.Linq; using Xunit; @@ -205,5 +206,207 @@ public class IrisPrediction public float[] PredictedLabels; } + public class ConversionSimpleClass + { + public int fInt; + public uint fuInt; + public short fShort; + public ushort fuShort; + public sbyte fsByte; + public byte fByte; + public long fLong; + public ulong fuLong; + public float fFloat; + public double fDouble; + public bool fBool; + public string fString; + } + + public class ConversionNullalbeClass + { + public int? fInt; + public uint? fuInt; + public short? fShort; + public ushort? fuShort; + public sbyte? fsByte; + public byte? fByte; + public long? fLong; + public ulong? 
fuLong; + public float? fFloat; + public double? fDouble; + public bool? fBool; + public string fString; + } + + public bool CompareObjectValues(object x, object y) + { + //handle string conversion. + //by default behaviour for DvText is to be empty string, while for string is null. + //so if we do roundtrip string-> DvText -> string all null string become empty strings. + //therefore replace all null values to empty string if field is string. + if (x.GetType() == typeof(string) && x == null) + x = ""; + if (y.GetType() == typeof(string) && y == null) + y = ""; + if (x == null && y == null) + return true; + if (x == null && y != null) + return false; + return (x.Equals(y)); + } + + public bool CompareThrougReflection(T x, T y) + { + foreach (var field in typeof(T).GetFields()) + { + var xvalue = field.GetValue(x); + var yvalue = field.GetValue(y); + if (field.FieldType.IsArray) + { + if (!CompareArrayValues(xvalue as Array, yvalue as Array)) + return false; + } + else + { + if (!CompareObjectValues(xvalue, yvalue)) + return false; + } + + } + return true; + } + + public bool CompareArrayValues(Array x, Array y) + { + if (x == null && y == null) return true; + if ((x == null && y != null) || (y == null && x != null)) + return false; + if (x.Length != y.Length) + return false; + for (int i = 0; i < x.Length; i++) + if (!CompareObjectValues(x.GetValue(i), y.GetValue(i))) + return false; + return true; + } + + public class ClassWithConstField + { + public const string ConstString = "N"; + public string fString; + public const int ConstInt = 100; + public int fInt; + } + + [Fact] + public void BackAndForthConversionWithBasicTypes() + { + var data = new List() + { + new ConversionSimpleClass(){ fInt=int.MaxValue-1, fuInt=uint.MaxValue-1, fBool=true, fsByte=sbyte.MaxValue-1, fByte = byte.MaxValue-1, + fDouble =double.MaxValue-1, fFloat=float.MaxValue-1, fLong=long.MaxValue-1, fuLong = ulong.MaxValue-1, + fShort =short.MaxValue-1, fuShort = ushort.MaxValue-1, 
fString="ha"}, + new ConversionSimpleClass(){ fInt=int.MaxValue, fuInt=uint.MaxValue, fBool=true, fsByte=sbyte.MaxValue, fByte = byte.MaxValue, + fDouble =double.MaxValue, fFloat=float.MaxValue, fLong=long.MaxValue, fuLong = ulong.MaxValue, + fShort =short.MaxValue, fuShort = ushort.MaxValue, fString="ooh"}, + new ConversionSimpleClass(){}, + }; + + var dataNullable = new List() + { + new ConversionNullalbeClass(){ fInt=int.MaxValue-1, fuInt=uint.MaxValue-1, fBool=true, fsByte=sbyte.MaxValue-1, fByte = byte.MaxValue-1, + fDouble =double.MaxValue-1, fFloat=float.MaxValue-1, fLong=long.MaxValue-1, fuLong = ulong.MaxValue-1, + fShort =short.MaxValue-1, fuShort = ushort.MaxValue-1, fString="ha"}, + new ConversionNullalbeClass(){ fInt=int.MaxValue, fuInt=uint.MaxValue, fBool=true, fsByte=sbyte.MaxValue, fByte = byte.MaxValue, + fDouble =double.MaxValue, fFloat=float.MaxValue, fLong=long.MaxValue, fuLong = ulong.MaxValue, + fShort =short.MaxValue, fuShort = ushort.MaxValue, fString="ooh"}, + new ConversionNullalbeClass() + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current)); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + + dataView = ComponentCreation.CreateDataView(env, dataNullable); + var enumeratorNullable = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalNullableEnumerator = dataNullable.GetEnumerator(); + while (enumeratorNullable.MoveNext() && originalNullableEnumerator.MoveNext()) + { + Assert.True(CompareThrougReflection(enumeratorNullable.Current, originalNullableEnumerator.Current)); + } + Assert.True(!enumeratorNullable.MoveNext() && 
!originalNullableEnumerator.MoveNext()); + } + } + + [Fact] + public void ClassWithConstFieldsConversion() + { + var data = new List() + { + new ClassWithConstField(){ fInt=1, fString ="lala" }, + new ClassWithConstField(){ fInt=-1, fString ="" }, + new ClassWithConstField(){ fInt=0, fString =null } + }; + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current)); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } + + public class ClassWithArrays + { + public string[] fString; + public int[] fInt; + public uint[] fuInt; + public short[] fShort; + public ushort[] fuShort; + public sbyte[] fsByte; + public byte[] fByte; + public long[] fLong; + public ulong[] fuLong; + public float[] fFloat; + public double[] fDouble; + public bool[] fBool; + } + + [Fact] + public void BackAndForthConversionWithArrays() + { + var data = new List() + { + new ClassWithArrays(){ fInt = new int[3]{ 0,1,2}, fFloat = new float[3]{ -0.99f, 0f, 0.99f}, fString =new string[2]{ "hola", "lola"}, + fBool =new bool[2]{true, false }, fByte = new byte[3]{ 0,124,255}, fDouble=new double[3]{ -1,0, 1}, fLong = new long[]{ 0,1,2} , + fsByte = new sbyte[3]{ -128,127,0}, fShort = new short[3]{ 0, 1225, 32767 }, fuInt =new uint[2]{ 0, uint.MaxValue}, + fuLong = new ulong[2]{ ulong.MaxValue, 0}, fuShort = new ushort[2]{ 0, ushort.MaxValue} + }, + new ClassWithArrays(){ fInt = new int[3]{ -2,1,0}, fFloat = new float[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{ "lola","hola"} } + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = 
dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current)); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } } } From 581cb42d5b4c594470ae9cd115a0d54e35932b4d Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 11:58:55 -0700 Subject: [PATCH 02/10] fix tests, address some comments --- .../DataViewConstructionUtils.cs | 14 +-- src/Microsoft.ML.Api/TypedCursor.cs | 34 +++---- .../CollectionDataSourceTests.cs | 93 +++++++++++++++++-- 3 files changed, 109 insertions(+), 32 deletions(-) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 962481f51b..bbda2a18c0 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -184,43 +184,43 @@ private Delegate CreateGetter(int index) { // Bool -> DvBool Ch.Assert(colType.IsBool); - return CreateGetterDelegate(index, (x) => x); + return CreateGetterDelegate(index, x => x); } else if (outputType == typeof(bool?)) { // Bool? -> DvBool Ch.Assert(colType.IsBool); - return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvBool.NA); + return CreateGetterDelegate(index, x => x ?? DvBool.NA); } else if (outputType == typeof(int)) { // int -> DvInt4 Ch.Assert(colType == NumberType.I4); - return CreateGetterDelegate(index, (x) => x); + return CreateGetterDelegate(index, x => x); } else if (outputType == typeof(int?)) { // int -> DvInt4 Ch.Assert(colType == NumberType.I4); - return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvInt4.NA); + return CreateGetterDelegate(index, x => x ?? 
DvInt4.NA); } else if (outputType == typeof(short)) { // short -> DvInt2 Ch.Assert(colType == NumberType.I2); - return CreateGetterDelegate(index, (x) => x); + return CreateGetterDelegate(index, x => x); } else if (outputType == typeof(short?)) { // short? -> DvInt2 Ch.Assert(colType == NumberType.I2); - return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvInt2.NA); + return CreateGetterDelegate(index, x => x ?? DvInt2.NA); } else if (outputType == typeof(long)) { // long -> DvInt8 Ch.Assert(colType == NumberType.I8); - return CreateGetterDelegate(index, (x) => x); + return CreateGetterDelegate(index, x => x); } else if (outputType == typeof(long?)) { diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 31a4c99bf3..6298067477 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -280,32 +280,32 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit if (fieldType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateVBufferSetter(input, index, poke, peek, (x) => x.ToString()); + return CreateVBufferSetter(input, index, poke, peek, x => x.ToString()); } else if (fieldType.GetElementType() == typeof(bool)) { Ch.Assert(colType.ItemType.IsBool); - return CreateVBufferSetter(input, index, poke, peek, (x) => Convert.ToBoolean(x.RawValue)); + return CreateVBufferSetter(input, index, poke, peek, x => Convert.ToBoolean(x.RawValue)); } else if (fieldType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateVBufferSetter(input, index, poke, peek, (x) => x.RawValue); + return CreateVBufferSetter(input, index, poke, peek, x => (int)x); } else if (fieldType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateVBufferSetter(input, index, poke, peek, (x) => x.RawValue); + return CreateVBufferSetter(input, index, poke, peek, x => (short)x); } else 
if (fieldType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateVBufferSetter(input, index, poke, peek, (x) => x.RawValue); + return CreateVBufferSetter(input, index, poke, peek, x => (long)x); } else if (fieldType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateVBufferSetter(input, index, poke, peek, (x) => x.RawValue); + return CreateVBufferSetter(input, index, poke, peek, x => (sbyte)x); } // VBuffer -> T[] @@ -329,67 +329,67 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit // DvText -> String Ch.Assert(colType.IsText); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.ToString()); + return CreateActionSetter(input, index, poke, x => x.ToString()); } else if (fieldType == typeof(bool)) { Ch.Assert(colType.IsBool); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => Convert.ToBoolean(x.RawValue)); + return CreateActionSetter(input, index, poke, x => Convert.ToBoolean(x.RawValue)); } else if (fieldType == typeof(bool?)) { Ch.Assert(colType.IsBool); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (bool?)null : (x.IsTrue ? true : false)); + return CreateActionSetter(input, index, poke, x => (bool?)x); } else if (fieldType == typeof(int)) { Ch.Assert(colType == NumberType.I4); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.RawValue); + return CreateActionSetter(input, index, poke, x => (int)x); } else if (fieldType == typeof(int?)) { Ch.Assert(colType == NumberType.I4); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (int?)null : x.RawValue); + return CreateActionSetter(input, index, poke, x => x.IsNA ? 
(int?)null : (int)x); } else if (fieldType == typeof(short)) { Ch.Assert(colType == NumberType.I2); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.RawValue); + return CreateActionSetter(input, index, poke, x => (short)x); } else if (fieldType == typeof(short?)) { Ch.Assert(colType == NumberType.I2); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (short?)null : x.RawValue); + return CreateActionSetter(input, index, poke, x => x.IsNA ? (short?)null : (short)x); } else if (fieldType == typeof(long)) { Ch.Assert(colType == NumberType.I8); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.RawValue); + return CreateActionSetter(input, index, poke, x => (long)x); } else if (fieldType == typeof(long?)) { Ch.Assert(colType == NumberType.I8); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (long?)null : x.RawValue); + return CreateActionSetter(input, index, poke, x => x.IsNA ? (long?)null : (long)x); } else if (fieldType == typeof(sbyte)) { Ch.Assert(colType == NumberType.I1); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.RawValue); + return CreateActionSetter(input, index, poke, x => (sbyte)x); } else if (fieldType == typeof(sbyte?)) { Ch.Assert(colType == NumberType.I1); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, (x) => x.IsNA ? (sbyte?)null : x.RawValue); + return CreateActionSetter(input, index, poke, x => x.IsNA ? 
(sbyte?)null : (sbyte)x); } // T -> T if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index cc2c948206..0ed4d7cdbb 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -12,6 +12,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using Xunit; using Xunit.Abstractions; @@ -238,15 +239,15 @@ public class ConversionNullalbeClass public string fString; } - public bool CompareObjectValues(object x, object y) + public bool CompareObjectValues(object x, object y, Type type) { //handle string conversion. //by default behaviour for DvText is to be empty string, while for string is null. //so if we do roundtrip string-> DvText -> string all null string become empty strings. //therefore replace all null values to empty string if field is string. 
- if (x.GetType() == typeof(string) && x == null) + if (type == typeof(string) && x == null) x = ""; - if (y.GetType() == typeof(string) && y == null) + if (type == typeof(string) && y == null) y = ""; if (x == null && y == null) return true; @@ -268,7 +269,7 @@ public bool CompareThrougReflection(T x, T y) } else { - if (!CompareObjectValues(xvalue, yvalue)) + if (!CompareObjectValues(xvalue, yvalue, field.FieldType)) return false; } @@ -284,7 +285,7 @@ public bool CompareArrayValues(Array x, Array y) if (x.Length != y.Length) return false; for (int i = 0; i < x.Length; i++) - if (!CompareObjectValues(x.GetValue(i), y.GetValue(i))) + if (!CompareObjectValues(x.GetValue(i), y.GetValue(i), x.GetType().GetElementType())) return false; return true; } @@ -304,10 +305,13 @@ public void BackAndForthConversionWithBasicTypes() { new ConversionSimpleClass(){ fInt=int.MaxValue-1, fuInt=uint.MaxValue-1, fBool=true, fsByte=sbyte.MaxValue-1, fByte = byte.MaxValue-1, fDouble =double.MaxValue-1, fFloat=float.MaxValue-1, fLong=long.MaxValue-1, fuLong = ulong.MaxValue-1, - fShort =short.MaxValue-1, fuShort = ushort.MaxValue-1, fString="ha"}, + fShort =short.MaxValue-1, fuShort = ushort.MaxValue-1, fString=null}, new ConversionSimpleClass(){ fInt=int.MaxValue, fuInt=uint.MaxValue, fBool=true, fsByte=sbyte.MaxValue, fByte = byte.MaxValue, fDouble =double.MaxValue, fFloat=float.MaxValue, fLong=long.MaxValue, fuLong = ulong.MaxValue, - fShort =short.MaxValue, fuShort = ushort.MaxValue, fString="ooh"}, + fShort =short.MaxValue, fuShort = ushort.MaxValue, fString="ooh"}, + new ConversionSimpleClass(){ fInt=int.MinValue+1, fuInt=uint.MinValue+1, fBool=true, fsByte=sbyte.MinValue+1, fByte = byte.MinValue+1, + fDouble =double.MinValue+1, fFloat=float.MinValue+1, fLong=long.MinValue+1, fuLong = ulong.MinValue+1, + fShort =short.MinValue+1, fuShort = ushort.MinValue+1, fString=""}, new ConversionSimpleClass(){}, }; @@ -319,6 +323,9 @@ public void BackAndForthConversionWithBasicTypes() new 
ConversionNullalbeClass(){ fInt=int.MaxValue, fuInt=uint.MaxValue, fBool=true, fsByte=sbyte.MaxValue, fByte = byte.MaxValue, fDouble =double.MaxValue, fFloat=float.MaxValue, fLong=long.MaxValue, fuLong = ulong.MaxValue, fShort =short.MaxValue, fuShort = ushort.MaxValue, fString="ooh"}, + new ConversionNullalbeClass(){ fInt=int.MinValue+1, fuInt=uint.MinValue+1, fBool=true, fsByte=sbyte.MinValue+1, fByte = byte.MinValue+1, + fDouble =double.MinValue+1, fFloat=float.MinValue+1, fLong=long.MinValue+1, fuLong = ulong.MinValue+1, + fShort =short.MinValue+1, fuShort = ushort.MinValue+1, fString=""}, new ConversionNullalbeClass() }; @@ -344,6 +351,76 @@ public void BackAndForthConversionWithBasicTypes() } } + public class ConversionNotSupportedMinValueClass + { + public int fInt; + public long fLong; + public short fShort; + public sbyte fSByte; + } + + [Fact] + public void ConversionExceptionsBehavior() + { + using (var env = new TlcEnvironment()) + { + var data = new ConversionNotSupportedMinValueClass[1]; + foreach (var field in typeof(ConversionNotSupportedMinValueClass).GetFields()) + { + data[0] = new ConversionNotSupportedMinValueClass(); + bool gotException = false; + FieldInfo fi; + if ((fi = field.FieldType.GetField("MinValue")) != null) + { + field.SetValue(data[0], fi.GetValue(null)); + } + var dataView = ComponentCreation.CreateDataView(env, data); + var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); + try + { + enumerator.MoveNext(); + } + catch + { + gotException = true; + } + Assert.True(gotException); + } + } + } + + public class ConversionLossMinValueClass + { + public int? fInt; + public long? fLong; + public short? fShort; + public sbyte? 
fSByte; + } + + [Fact] + public void ConversionMinValueToNullBehavior() + { + using (var env = new TlcEnvironment()) + { + var data = new List(){ + new ConversionLossMinValueClass(){ fSByte = null,fInt = null,fLong = null,fShort = null}, + new ConversionLossMinValueClass(){fSByte = sbyte.MinValue,fInt = int.MinValue,fLong = long.MinValue,fShort = short.MinValue} + }; + + foreach (var field in typeof(ConversionLossMinValueClass).GetFields()) + { + + var dataView = ComponentCreation.CreateDataView(env, data); + var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); + while (enumerator.MoveNext()) + { + Assert.True(enumerator.Current.fInt == null && enumerator.Current.fLong == null && + enumerator.Current.fSByte == null && enumerator.Current.fShort == null); + } + } + } + } + [Fact] public void ClassWithConstFieldsConversion() { @@ -390,7 +467,7 @@ public void BackAndForthConversionWithArrays() { new ClassWithArrays(){ fInt = new int[3]{ 0,1,2}, fFloat = new float[3]{ -0.99f, 0f, 0.99f}, fString =new string[2]{ "hola", "lola"}, fBool =new bool[2]{true, false }, fByte = new byte[3]{ 0,124,255}, fDouble=new double[3]{ -1,0, 1}, fLong = new long[]{ 0,1,2} , - fsByte = new sbyte[3]{ -128,127,0}, fShort = new short[3]{ 0, 1225, 32767 }, fuInt =new uint[2]{ 0, uint.MaxValue}, + fsByte = new sbyte[3]{ -127,127,0}, fShort = new short[3]{ 0, 1225, 32767 }, fuInt =new uint[2]{ 0, uint.MaxValue}, fuLong = new ulong[2]{ ulong.MaxValue, 0}, fuShort = new ushort[2]{ 0, ushort.MaxValue} }, new ClassWithArrays(){ fInt = new int[3]{ -2,1,0}, fFloat = new float[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{ "lola","hola"} } From 62091c90fd7c1ef09f664ef7055976c726e3033a Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 13:54:55 -0700 Subject: [PATCH 03/10] support nullable arrays. 
--- .../DataViewConstructionUtils.cs | 42 +++++++++++++++---- src/Microsoft.ML.Api/TypedCursor.cs | 38 ++++++++++++++--- .../CollectionDataSourceTests.cs | 40 +++++++++++++++++- 3 files changed, 107 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index bbda2a18c0..7e3d841b93 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -129,36 +129,64 @@ private Delegate CreateGetter(int index) if (outputType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateArrayGetterDelegate(index, (x) => new DvText(x)); + return CreateArrayGetterDelegate(index, x => new DvText(x)); } else if (outputType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateArrayGetterDelegate(index, (x) => x); + return CreateArrayGetterDelegate(index, x => x); + } + else if (outputType.GetElementType() == typeof(int?)) + { + Ch.Assert(colType.ItemType == NumberType.I4); + return CreateArrayGetterDelegate(index, x => x ?? DvInt4.NA); } else if (outputType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateArrayGetterDelegate(index, (x) => x); + return CreateArrayGetterDelegate(index, x => x); + } + else if (outputType.GetElementType() == typeof(long?)) + { + Ch.Assert(colType.ItemType == NumberType.I8); + return CreateArrayGetterDelegate(index, x => x ?? DvInt8.NA); } else if (outputType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateArrayGetterDelegate(index, (x) => x); + return CreateArrayGetterDelegate(index, x => x); + } + else if (outputType.GetElementType() == typeof(short?)) + { + Ch.Assert(colType.ItemType == NumberType.I2); + return CreateArrayGetterDelegate(index, x => x ?? 
DvInt2.NA); } else if (outputType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateArrayGetterDelegate(index, (x) => x); + return CreateArrayGetterDelegate(index, x => x); + } + else if (outputType.GetElementType() == typeof(sbyte?)) + { + Ch.Assert(colType.ItemType == NumberType.I1); + return CreateArrayGetterDelegate(index, x => x ?? DvInt1.NA); } else if (outputType.GetElementType() == typeof(bool)) { Ch.Assert(colType.ItemType.IsBool); - return CreateArrayGetterDelegate(index, (x)=>x); + return CreateArrayGetterDelegate(index, x => x); + } + else if (outputType.GetElementType() == typeof(bool?)) + { + Ch.Assert(colType.ItemType.IsBool); + return CreateArrayGetterDelegate(index, x => x ?? DvBool.NA); } // T[] -> VBuffer - Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType); + if (outputType.GetElementType().IsGenericType && outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>)) + Ch.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == colType.ItemType.RawType); + else + Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType); del = CreateDirectArrayGetterDelegate; genericType = outputType.GetElementType(); } diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 6298067477..80ee900377 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -287,29 +287,57 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit Ch.Assert(colType.ItemType.IsBool); return CreateVBufferSetter(input, index, poke, peek, x => Convert.ToBoolean(x.RawValue)); } + else if (fieldType.GetElementType() == typeof(bool?)) + { + Ch.Assert(colType.ItemType.IsBool); + return CreateVBufferSetter(input, index, poke, peek, x => (bool?)x); + } else if (fieldType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); return CreateVBufferSetter(input, index, poke, peek, x 
=> (int)x); } + else if (fieldType.GetElementType() == typeof(int?)) + { + Ch.Assert(colType.ItemType == NumberType.I4); + return CreateVBufferSetter(input, index, poke, peek, x => (int?)x); + } else if (fieldType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); return CreateVBufferSetter(input, index, poke, peek, x => (short)x); } + else if (fieldType.GetElementType() == typeof(short?)) + { + Ch.Assert(colType.ItemType == NumberType.I2); + return CreateVBufferSetter(input, index, poke, peek, x => (short?)x); + } else if (fieldType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); return CreateVBufferSetter(input, index, poke, peek, x => (long)x); } + else if (fieldType.GetElementType() == typeof(long?)) + { + Ch.Assert(colType.ItemType == NumberType.I8); + return CreateVBufferSetter(input, index, poke, peek, x => (long?)x); + } else if (fieldType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); return CreateVBufferSetter(input, index, poke, peek, x => (sbyte)x); } + else if (fieldType.GetElementType() == typeof(sbyte?)) + { + Ch.Assert(colType.ItemType == NumberType.I1); + return CreateVBufferSetter(input, index, poke, peek, x => (sbyte?)x); + } // VBuffer -> T[] - Ch.Assert(fieldType.GetElementType() == colType.ItemType.RawType); + if (fieldType.GetElementType().IsGenericType && fieldType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>)) + Ch.Assert(colType.ItemType.RawType == Nullable.GetUnderlyingType(fieldType.GetElementType())); + else + Ch.Assert(colType.ItemType.RawType == fieldType.GetElementType()); del = CreateVBufferDirectSetter; genericType = fieldType.GetElementType(); } @@ -353,7 +381,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit { Ch.Assert(colType == NumberType.I4); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => x.IsNA ? 
(int?)null : (int)x); + return CreateActionSetter(input, index, poke, x => (int?)x); } else if (fieldType == typeof(short)) { @@ -365,7 +393,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit { Ch.Assert(colType == NumberType.I2); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => x.IsNA ? (short?)null : (short)x); + return CreateActionSetter(input, index, poke, x => (short?)x); } else if (fieldType == typeof(long)) { @@ -377,7 +405,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit { Ch.Assert(colType == NumberType.I8); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => x.IsNA ? (long?)null : (long)x); + return CreateActionSetter(input, index, poke, x => (long?)x); } else if (fieldType == typeof(sbyte)) { @@ -389,7 +417,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit { Ch.Assert(colType == NumberType.I1); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => x.IsNA ? 
(sbyte?)null : (sbyte)x); + return CreateActionSetter(input, index, poke, x => (sbyte?)x); } // T -> T if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 0ed4d7cdbb..758573cf94 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -459,6 +459,22 @@ public class ClassWithArrays public double[] fDouble; public bool[] fBool; } + public class ClassWithNullableArrays + { + public string[] fString; + public int?[] fInt; + public uint?[] fuInt; + public short?[] fShort; + public ushort?[] fuShort; + public sbyte?[] fsByte; + public byte?[] fByte; + public long?[] fLong; + public ulong?[] fuLong; + public float?[] fFloat; + public double?[] fDouble; + public bool?[] fBool; + } + [Fact] public void BackAndForthConversionWithArrays() @@ -470,9 +486,20 @@ public void BackAndForthConversionWithArrays() fsByte = new sbyte[3]{ -127,127,0}, fShort = new short[3]{ 0, 1225, 32767 }, fuInt =new uint[2]{ 0, uint.MaxValue}, fuLong = new ulong[2]{ ulong.MaxValue, 0}, fuShort = new ushort[2]{ 0, ushort.MaxValue} }, - new ClassWithArrays(){ fInt = new int[3]{ -2,1,0}, fFloat = new float[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{ "lola","hola"} } + new ClassWithArrays(){ fInt = new int[3]{ -2,1,0}, fFloat = new float[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{"",null} }, + new ClassWithArrays() }; + var nullableData = new List() + { + new ClassWithNullableArrays(){ fInt = new int?[3]{ null,-1,1}, fFloat = new float?[3]{ -0.99f, null, 0.99f}, fString =new string[2]{ null, ""}, + fBool =new bool?[3]{true,null, false }, fByte = new byte?[4]{ 0,125,null,255}, fDouble=new double?[3]{ -1,null, 1}, fLong = new long?[]{null,-1,1} , + fsByte = new sbyte?[3]{ -127,127,null}, fShort = new short?[3]{ 0, null, 32767 }, fuInt =new uint?[4]{null,42 ,0, 
uint.MaxValue}, + fuLong = new ulong?[3]{ ulong.MaxValue, null, 0}, fuShort = new ushort?[3]{ 0,null, ushort.MaxValue} + }, + new ClassWithNullableArrays(){ fInt = new int?[3]{ -2,1,0}, fFloat = new float?[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{ "lola","hola"} }, + new ClassWithNullableArrays() + }; using (var env = new TlcEnvironment()) { var dataView = ComponentCreation.CreateDataView(env, data); @@ -483,7 +510,18 @@ public void BackAndForthConversionWithArrays() Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + + var nullableDataView = ComponentCreation.CreateDataView(env, nullableData); + var enumeratorNullable = nullableDataView.AsEnumerable(env, false).GetEnumerator(); + var originalNullalbleEnumerator = nullableData.GetEnumerator(); + while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) + { + Assert.True(CompareThrougReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); + } + Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } } + + } } From ef950ebd5261d800f6caf154876eb086d2e039a7 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 13:56:55 -0700 Subject: [PATCH 04/10] bool case --- src/Microsoft.ML.Api/TypedCursor.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 80ee900377..44b0352919 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -285,7 +285,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit else if (fieldType.GetElementType() == typeof(bool)) { Ch.Assert(colType.ItemType.IsBool); - return CreateVBufferSetter(input, index, poke, peek, x => Convert.ToBoolean(x.RawValue)); + return CreateVBufferSetter(input, index, poke, peek, x => (bool)x); } else 
if (fieldType.GetElementType() == typeof(bool?)) { @@ -363,7 +363,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit { Ch.Assert(colType.IsBool); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => Convert.ToBoolean(x.RawValue)); + return CreateActionSetter(input, index, poke, x => (bool)x); } else if (fieldType == typeof(bool?)) { From 49323f26d04ae250308c46e97b4980f013dd39ab Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 14:01:50 -0700 Subject: [PATCH 05/10] clean test code --- .../CollectionDataSourceTests.cs | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 758573cf94..b2fbba1c74 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -241,10 +241,9 @@ public class ConversionNullalbeClass public bool CompareObjectValues(object x, object y, Type type) { - //handle string conversion. - //by default behaviour for DvText is to be empty string, while for string is null. - //so if we do roundtrip string-> DvText -> string all null string become empty strings. - //therefore replace all null values to empty string if field is string. + // By default behaviour for DvText is to be empty string, while for string is null. + // So if we do roundtrip string-> DvText -> string all null string become empty strings. + // Therefore replace all null values to empty string if field is string. 
if (type == typeof(string) && x == null) x = ""; if (type == typeof(string) && y == null) @@ -272,7 +271,6 @@ public bool CompareThrougReflection(T x, T y) if (!CompareObjectValues(xvalue, yvalue, field.FieldType)) return false; } - } return true; } @@ -323,9 +321,9 @@ public void BackAndForthConversionWithBasicTypes() new ConversionNullalbeClass(){ fInt=int.MaxValue, fuInt=uint.MaxValue, fBool=true, fsByte=sbyte.MaxValue, fByte = byte.MaxValue, fDouble =double.MaxValue, fFloat=float.MaxValue, fLong=long.MaxValue, fuLong = ulong.MaxValue, fShort =short.MaxValue, fuShort = ushort.MaxValue, fString="ooh"}, - new ConversionNullalbeClass(){ fInt=int.MinValue+1, fuInt=uint.MinValue+1, fBool=true, fsByte=sbyte.MinValue+1, fByte = byte.MinValue+1, - fDouble =double.MinValue+1, fFloat=float.MinValue+1, fLong=long.MinValue+1, fuLong = ulong.MinValue+1, - fShort =short.MinValue+1, fuShort = ushort.MinValue+1, fString=""}, + new ConversionNullalbeClass(){ fInt=int.MinValue+1, fuInt=uint.MinValue, fBool=false, fsByte=sbyte.MinValue+1, fByte = byte.MinValue, + fDouble =double.MinValue+1, fFloat=float.MinValue+1, fLong=long.MinValue+1, fuLong = ulong.MinValue, + fShort =short.MinValue+1, fuShort = ushort.MinValue, fString=""}, new ConversionNullalbeClass() }; @@ -406,10 +404,8 @@ public void ConversionMinValueToNullBehavior() new ConversionLossMinValueClass(){ fSByte = null,fInt = null,fLong = null,fShort = null}, new ConversionLossMinValueClass(){fSByte = sbyte.MinValue,fInt = int.MinValue,fLong = long.MinValue,fShort = short.MinValue} }; - foreach (var field in typeof(ConversionLossMinValueClass).GetFields()) { - var dataView = ComponentCreation.CreateDataView(env, data); var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); while (enumerator.MoveNext()) @@ -430,16 +426,14 @@ public void ClassWithConstFieldsConversion() new ClassWithConstField(){ fInt=-1, fString ="" }, new ClassWithConstField(){ fInt=0, fString =null } }; + using (var env = new TlcEnvironment()) 
{ var dataView = ComponentCreation.CreateDataView(env, data); var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - { Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current)); - } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } } @@ -459,6 +453,7 @@ public class ClassWithArrays public double[] fDouble; public bool[] fBool; } + public class ClassWithNullableArrays { public string[] fString; @@ -475,7 +470,6 @@ public class ClassWithNullableArrays public bool?[] fBool; } - [Fact] public void BackAndForthConversionWithArrays() { @@ -500,6 +494,7 @@ public void BackAndForthConversionWithArrays() new ClassWithNullableArrays(){ fInt = new int?[3]{ -2,1,0}, fFloat = new float?[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{ "lola","hola"} }, new ClassWithNullableArrays() }; + using (var env = new TlcEnvironment()) { var dataView = ComponentCreation.CreateDataView(env, data); @@ -521,7 +516,5 @@ public void BackAndForthConversionWithArrays() Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } } - - } } From d556d1338641051ad6e0b7d94f761f09119e0ece Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 14:15:01 -0700 Subject: [PATCH 06/10] address comments --- src/Microsoft.ML.Api/ApiUtils.cs | 3 +-- src/Microsoft.ML.Api/DataViewConstructionUtils.cs | 5 ++--- src/Microsoft.ML.Api/SchemaDefinition.cs | 6 ++++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 8b85c24d63..4b7daf0a74 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -6,7 +6,6 @@ using System.Reflection; using System.Reflection.Emit; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; namespace 
Microsoft.ML.Runtime.Api { @@ -19,7 +18,7 @@ internal static class ApiUtils private static OpCode GetAssignmentOpCode(Type t) { // REVIEW: This should be a Dictionary based solution. - // DvTexts, strings, arrays, and VBuffers. + // DvTypes, strings, arrays, all nullable types, VBuffers and UInt128. if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) || t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 7e3d841b93..5dac6a6384 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -194,7 +194,6 @@ private Delegate CreateGetter(int index) { // VBuffer -> VBuffer // REVIEW: Do we care about accomodating VBuffer -> VBuffer? - // REVIEW: why it's int and not long? Ch.Assert(outputType.IsGenericType); Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>)); Ch.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType); @@ -228,7 +227,7 @@ private Delegate CreateGetter(int index) } else if (outputType == typeof(int?)) { - // int -> DvInt4 + // int? -> DvInt4 Ch.Assert(colType == NumberType.I4); return CreateGetterDelegate(index, x => x ?? DvInt4.NA); } @@ -264,7 +263,7 @@ private Delegate CreateGetter(int index) } else if (outputType == typeof(sbyte?)) { - // sbyte -> DvInt1 + // sbyte? -> DvInt1 Ch.Assert(colType == NumberType.I1); return CreateGetterDelegate(index, (x) => x.HasValue ? 
x.Value : DvInt1.NA); } diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index 522fb3510c..559e3a81ee 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -327,8 +327,10 @@ public static SchemaDefinition Create(Type userType) // This field does not need a column. // REVIEW: maybe validate the channel attribute now, instead // of later at cursor creation. - // Const fields not need to be mapped. - if (fieldInfo.FieldType == typeof(IChannel) || fieldInfo.IsLiteral) + if (fieldInfo.FieldType == typeof(IChannel)) + continue; + // Const fields do not need to be mapped. + if (fieldInfo.IsLiteral) continue; if (fieldInfo.GetCustomAttribute() != null) From a7bb6e171730756ae62db39a55c8bf8efea619af Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 14:23:43 -0700 Subject: [PATCH 07/10] more comments! --- src/Microsoft.ML.Api/DataViewConstructionUtils.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 5dac6a6384..b13e114dee 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -129,7 +129,7 @@ private Delegate CreateGetter(int index) if (outputType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateArrayGetterDelegate(index, x => new DvText(x)); + return CreateArrayGetterDelegate(index, x => x == null ? DvText.NA : new DvText(x)); } else if (outputType.GetElementType() == typeof(int)) { @@ -205,7 +205,7 @@ private Delegate CreateGetter(int index) { // String -> DvText Ch.Assert(colType.IsText); - return CreateGetterDelegate(index, (x) => x == null ? DvText.NA : new DvText(x)); + return CreateGetterDelegate(index, x => x == null ? 
DvText.NA : new DvText(x)); } else if (outputType == typeof(bool)) { @@ -253,19 +253,19 @@ private Delegate CreateGetter(int index) { // long? -> DvInt8 Ch.Assert(colType == NumberType.I8); - return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvInt8.NA); + return CreateGetterDelegate(index, x => x ?? DvInt8.NA); } else if (outputType == typeof(sbyte)) { // sbyte -> DvInt1 Ch.Assert(colType == NumberType.I1); - return CreateGetterDelegate(index, (x) => (DvInt1)x); + return CreateGetterDelegate(index, x => x); } else if (outputType == typeof(sbyte?)) { // sbyte? -> DvInt1 Ch.Assert(colType == NumberType.I1); - return CreateGetterDelegate(index, (x) => x.HasValue ? x.Value : DvInt1.NA); + return CreateGetterDelegate(index, x => x ?? DvInt1.NA); } // T -> T if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) From d5143e6aced0c55aa58b138a11fb51013aa31ba1 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 14:26:49 -0700 Subject: [PATCH 08/10] give functions a better name --- .../DataViewConstructionUtils.cs | 48 ++++++++--------- src/Microsoft.ML.Api/TypedCursor.cs | 52 +++++++++---------- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index b13e114dee..96151c616b 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -129,57 +129,57 @@ private Delegate CreateGetter(int index) if (outputType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateArrayGetterDelegate(index, x => x == null ? DvText.NA : new DvText(x)); + return CreateConvertingArrayGetterDelegate(index, x => x == null ? 
DvText.NA : new DvText(x)); } else if (outputType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(int?)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateArrayGetterDelegate(index, x => x ?? DvInt4.NA); + return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt4.NA); } else if (outputType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(long?)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateArrayGetterDelegate(index, x => x ?? DvInt8.NA); + return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt8.NA); } else if (outputType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(short?)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateArrayGetterDelegate(index, x => x ?? DvInt2.NA); + return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt2.NA); } else if (outputType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(sbyte?)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateArrayGetterDelegate(index, x => x ?? DvInt1.NA); + return CreateConvertingArrayGetterDelegate(index, x => x ?? 
DvInt1.NA); } else if (outputType.GetElementType() == typeof(bool)) { Ch.Assert(colType.ItemType.IsBool); - return CreateArrayGetterDelegate(index, x => x); + return CreateConvertingArrayGetterDelegate(index, x => x); } else if (outputType.GetElementType() == typeof(bool?)) { Ch.Assert(colType.ItemType.IsBool); - return CreateArrayGetterDelegate(index, x => x ?? DvBool.NA); + return CreateConvertingArrayGetterDelegate(index, x => x ?? DvBool.NA); } // T[] -> VBuffer @@ -205,67 +205,67 @@ private Delegate CreateGetter(int index) { // String -> DvText Ch.Assert(colType.IsText); - return CreateGetterDelegate(index, x => x == null ? DvText.NA : new DvText(x)); + return CreateConvertingGetterDelegate(index, x => x == null ? DvText.NA : new DvText(x)); } else if (outputType == typeof(bool)) { // Bool -> DvBool Ch.Assert(colType.IsBool); - return CreateGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(bool?)) { // Bool? -> DvBool Ch.Assert(colType.IsBool); - return CreateGetterDelegate(index, x => x ?? DvBool.NA); + return CreateConvertingGetterDelegate(index, x => x ?? DvBool.NA); } else if (outputType == typeof(int)) { // int -> DvInt4 Ch.Assert(colType == NumberType.I4); - return CreateGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(int?)) { // int? -> DvInt4 Ch.Assert(colType == NumberType.I4); - return CreateGetterDelegate(index, x => x ?? DvInt4.NA); + return CreateConvertingGetterDelegate(index, x => x ?? DvInt4.NA); } else if (outputType == typeof(short)) { // short -> DvInt2 Ch.Assert(colType == NumberType.I2); - return CreateGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(short?)) { // short? -> DvInt2 Ch.Assert(colType == NumberType.I2); - return CreateGetterDelegate(index, x => x ?? DvInt2.NA); + return CreateConvertingGetterDelegate(index, x => x ?? 
DvInt2.NA); } else if (outputType == typeof(long)) { // long -> DvInt8 Ch.Assert(colType == NumberType.I8); - return CreateGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(long?)) { // long? -> DvInt8 Ch.Assert(colType == NumberType.I8); - return CreateGetterDelegate(index, x => x ?? DvInt8.NA); + return CreateConvertingGetterDelegate(index, x => x ?? DvInt8.NA); } else if (outputType == typeof(sbyte)) { // sbyte -> DvInt1 Ch.Assert(colType == NumberType.I1); - return CreateGetterDelegate(index, x => x); + return CreateConvertingGetterDelegate(index, x => x); } else if (outputType == typeof(sbyte?)) { // sbyte? -> DvInt1 Ch.Assert(colType == NumberType.I1); - return CreateGetterDelegate(index, x => x ?? DvInt1.NA); + return CreateConvertingGetterDelegate(index, x => x ?? DvInt1.NA); } // T -> T if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) @@ -284,7 +284,7 @@ private Delegate CreateGetter(int index) return (Delegate)meth.Invoke(this, new object[] { index }); } - private Delegate CreateArrayGetterDelegate(int index, Func convert) + private Delegate CreateConvertingArrayGetterDelegate(int index, Func convert) { var peek = DataView._peeks[index] as Peek; Ch.AssertValue(peek); @@ -301,7 +301,7 @@ private Delegate CreateArrayGetterDelegate(int index, Func(int index, Func convert) + private Delegate CreateConvertingGetterDelegate(int index, Func convert) { var peek = DataView._peeks[index] as Peek; Ch.AssertValue(peek); diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 44b0352919..2a325d8101 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -280,57 +280,57 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit if (fieldType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateVBufferSetter(input, index, poke, 
peek, x => x.ToString()); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => x.ToString()); } else if (fieldType.GetElementType() == typeof(bool)) { Ch.Assert(colType.ItemType.IsBool); - return CreateVBufferSetter(input, index, poke, peek, x => (bool)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (bool)x); } else if (fieldType.GetElementType() == typeof(bool?)) { Ch.Assert(colType.ItemType.IsBool); - return CreateVBufferSetter(input, index, poke, peek, x => (bool?)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (bool?)x); } else if (fieldType.GetElementType() == typeof(int)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateVBufferSetter(input, index, poke, peek, x => (int)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x); } else if (fieldType.GetElementType() == typeof(int?)) { Ch.Assert(colType.ItemType == NumberType.I4); - return CreateVBufferSetter(input, index, poke, peek, x => (int?)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int?)x); } else if (fieldType.GetElementType() == typeof(short)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateVBufferSetter(input, index, poke, peek, x => (short)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (short)x); } else if (fieldType.GetElementType() == typeof(short?)) { Ch.Assert(colType.ItemType == NumberType.I2); - return CreateVBufferSetter(input, index, poke, peek, x => (short?)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (short?)x); } else if (fieldType.GetElementType() == typeof(long)) { Ch.Assert(colType.ItemType == NumberType.I8); - return CreateVBufferSetter(input, index, poke, peek, x => (long)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (long)x); } else if (fieldType.GetElementType() == typeof(long?)) { Ch.Assert(colType.ItemType == NumberType.I8); - return 
CreateVBufferSetter(input, index, poke, peek, x => (long?)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (long?)x); } else if (fieldType.GetElementType() == typeof(sbyte)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateVBufferSetter(input, index, poke, peek, x => (sbyte)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (sbyte)x); } else if (fieldType.GetElementType() == typeof(sbyte?)) { Ch.Assert(colType.ItemType == NumberType.I1); - return CreateVBufferSetter(input, index, poke, peek, x => (sbyte?)x); + return CreateConvertingVBufferSetter(input, index, poke, peek, x => (sbyte?)x); } // VBuffer -> T[] @@ -338,7 +338,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit Ch.Assert(colType.ItemType.RawType == Nullable.GetUnderlyingType(fieldType.GetElementType())); else Ch.Assert(colType.ItemType.RawType == fieldType.GetElementType()); - del = CreateVBufferDirectSetter; + del = CreateDirectVBufferSetter; genericType = fieldType.GetElementType(); } else if (colType.IsVector) @@ -357,67 +357,67 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit // DvText -> String Ch.Assert(colType.IsText); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => x.ToString()); + return CreateConvertingActionSetter(input, index, poke, x => x.ToString()); } else if (fieldType == typeof(bool)) { Ch.Assert(colType.IsBool); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (bool)x); + return CreateConvertingActionSetter(input, index, poke, x => (bool)x); } else if (fieldType == typeof(bool?)) { Ch.Assert(colType.IsBool); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (bool?)x); + return CreateConvertingActionSetter(input, index, poke, x => (bool?)x); } else if (fieldType == typeof(int)) { Ch.Assert(colType == NumberType.I4); Ch.Assert(peek == null); - return CreateActionSetter(input, index, 
poke, x => (int)x); + return CreateConvertingActionSetter(input, index, poke, x => (int)x); } else if (fieldType == typeof(int?)) { Ch.Assert(colType == NumberType.I4); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (int?)x); + return CreateConvertingActionSetter(input, index, poke, x => (int?)x); } else if (fieldType == typeof(short)) { Ch.Assert(colType == NumberType.I2); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (short)x); + return CreateConvertingActionSetter(input, index, poke, x => (short)x); } else if (fieldType == typeof(short?)) { Ch.Assert(colType == NumberType.I2); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (short?)x); + return CreateConvertingActionSetter(input, index, poke, x => (short?)x); } else if (fieldType == typeof(long)) { Ch.Assert(colType == NumberType.I8); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (long)x); + return CreateConvertingActionSetter(input, index, poke, x => (long)x); } else if (fieldType == typeof(long?)) { Ch.Assert(colType == NumberType.I8); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (long?)x); + return CreateConvertingActionSetter(input, index, poke, x => (long?)x); } else if (fieldType == typeof(sbyte)) { Ch.Assert(colType == NumberType.I1); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (sbyte)x); + return CreateConvertingActionSetter(input, index, poke, x => (sbyte)x); } else if (fieldType == typeof(sbyte?)) { Ch.Assert(colType == NumberType.I1); Ch.Assert(peek == null); - return CreateActionSetter(input, index, poke, x => (sbyte?)x); + return CreateConvertingActionSetter(input, index, poke, x => (sbyte?)x); } // T -> T if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) @@ -436,7 +436,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit return 
(Action)meth.Invoke(this, new object[] { input, index, poke, peek }); } - private Action CreateVBufferSetter(IRow input, int col, Delegate poke, Delegate peek, Func convert) + private Action CreateConvertingVBufferSetter(IRow input, int col, Delegate poke, Delegate peek, Func convert) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke; @@ -458,7 +458,7 @@ private Action CreateVBufferSetter(IRow input, int col, Delega }; } - private Action CreateVBufferDirectSetter(IRow input, int col, Delegate poke, Delegate peek) + private Action CreateDirectVBufferSetter(IRow input, int col, Delegate poke, Delegate peek) { var getter = input.GetGetter>(col); var typedPoke = poke as Poke; @@ -496,7 +496,7 @@ private Action CreateVBufferDirectSetter(IRow input, int col, Delega }; } - private static Action CreateActionSetter(IRow input, int col, Delegate poke, Func convert) + private static Action CreateConvertingActionSetter(IRow input, int col, Delegate poke, Func convert) { var getter = input.GetGetter(col); var typedPoke = poke as Poke; From 5b7dc4aaf33067b88664afe54d276621bddd7a34 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 14:35:43 -0700 Subject: [PATCH 09/10] add //Review comment regarding conversion --- src/Microsoft.ML.Api/DataViewConstructionUtils.cs | 3 +++ src/Microsoft.ML.Api/TypedCursor.cs | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 96151c616b..6ecff5b204 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -284,6 +284,9 @@ private Delegate CreateGetter(int index) return (Delegate)meth.Invoke(this, new object[] { index }); } + // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower + // than the 'direct' getter. 
We don't have a good indication of this to the user, and the selection + // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats). private Delegate CreateConvertingArrayGetterDelegate(int index, Func convert) { var peek = DataView._peeks[index] as Peek; diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 2a325d8101..2ba9eeb23a 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -436,6 +436,9 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit return (Action)meth.Invoke(this, new object[] { input, index, poke, peek }); } + // REVIEW: The converting setter invokes a type conversion delegate on every call, so it's inherently slower + // than the 'direct' setter. We don't have a good indication of this to the user, and the selection + // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats). private Action CreateConvertingVBufferSetter(IRow input, int col, Delegate poke, Delegate peek, Func convert) { var getter = input.GetGetter>(col); From 5f69a4c3c4ab179dc07c1cfa80372136782a0daa Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 19 Jul 2018 15:56:34 -0700 Subject: [PATCH 10/10] address pete comments and format tests properly --- .../CollectionDataSourceTests.cs | 189 +++++++++++++----- 1 file changed, 141 insertions(+), 48 deletions(-) diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index b2fbba1c74..87e23952d6 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -252,10 +252,10 @@ public bool CompareObjectValues(object x, object y, Type type) return true; if (x == null && y != null) return false; - return (x.Equals(y)); + return x.Equals(y); } - public bool CompareThrougReflection(T x, T y) + public bool CompareThroughReflection(T x, T 
y) { foreach (var field in typeof(T).GetFields()) { @@ -297,33 +297,105 @@ public class ClassWithConstField } [Fact] - public void BackAndForthConversionWithBasicTypes() + public void RoundTripConversionWithBasicTypes() { - var data = new List() + var data = new List { - new ConversionSimpleClass(){ fInt=int.MaxValue-1, fuInt=uint.MaxValue-1, fBool=true, fsByte=sbyte.MaxValue-1, fByte = byte.MaxValue-1, - fDouble =double.MaxValue-1, fFloat=float.MaxValue-1, fLong=long.MaxValue-1, fuLong = ulong.MaxValue-1, - fShort =short.MaxValue-1, fuShort = ushort.MaxValue-1, fString=null}, - new ConversionSimpleClass(){ fInt=int.MaxValue, fuInt=uint.MaxValue, fBool=true, fsByte=sbyte.MaxValue, fByte = byte.MaxValue, - fDouble =double.MaxValue, fFloat=float.MaxValue, fLong=long.MaxValue, fuLong = ulong.MaxValue, - fShort =short.MaxValue, fuShort = ushort.MaxValue, fString="ooh"}, - new ConversionSimpleClass(){ fInt=int.MinValue+1, fuInt=uint.MinValue+1, fBool=true, fsByte=sbyte.MinValue+1, fByte = byte.MinValue+1, - fDouble =double.MinValue+1, fFloat=float.MinValue+1, fLong=long.MinValue+1, fuLong = ulong.MinValue+1, - fShort =short.MinValue+1, fuShort = ushort.MinValue+1, fString=""}, - new ConversionSimpleClass(){}, + new ConversionSimpleClass() + { + fInt = int.MaxValue - 1, + fuInt = uint.MaxValue - 1, + fBool = true, + fsByte = sbyte.MaxValue - 1, + fByte = byte.MaxValue - 1, + fDouble = double.MaxValue - 1, + fFloat = float.MaxValue - 1, + fLong = long.MaxValue - 1, + fuLong = ulong.MaxValue - 1, + fShort = short.MaxValue - 1, + fuShort = ushort.MaxValue - 1, + fString = null + }, + new ConversionSimpleClass() + { + fInt = int.MaxValue, + fuInt = uint.MaxValue, + fBool = true, + fsByte = sbyte.MaxValue, + fByte = byte.MaxValue, + fDouble = double.MaxValue, + fFloat = float.MaxValue, + fLong = long.MaxValue, + fuLong = ulong.MaxValue, + fShort = short.MaxValue, + fuShort = ushort.MaxValue, + fString = "ooh" + }, + new ConversionSimpleClass() + { + fInt = int.MinValue + 1, + 
fuInt = uint.MinValue + 1, + fBool = false, + fsByte = sbyte.MinValue + 1, + fByte = byte.MinValue + 1, + fDouble = double.MinValue + 1, + fFloat = float.MinValue + 1, + fLong = long.MinValue + 1, + fuLong = ulong.MinValue + 1, + fShort = short.MinValue + 1, + fuShort = ushort.MinValue + 1, + fString = "" + }, + new ConversionSimpleClass() }; - var dataNullable = new List() + var dataNullable = new List { - new ConversionNullalbeClass(){ fInt=int.MaxValue-1, fuInt=uint.MaxValue-1, fBool=true, fsByte=sbyte.MaxValue-1, fByte = byte.MaxValue-1, - fDouble =double.MaxValue-1, fFloat=float.MaxValue-1, fLong=long.MaxValue-1, fuLong = ulong.MaxValue-1, - fShort =short.MaxValue-1, fuShort = ushort.MaxValue-1, fString="ha"}, - new ConversionNullalbeClass(){ fInt=int.MaxValue, fuInt=uint.MaxValue, fBool=true, fsByte=sbyte.MaxValue, fByte = byte.MaxValue, - fDouble =double.MaxValue, fFloat=float.MaxValue, fLong=long.MaxValue, fuLong = ulong.MaxValue, - fShort =short.MaxValue, fuShort = ushort.MaxValue, fString="ooh"}, - new ConversionNullalbeClass(){ fInt=int.MinValue+1, fuInt=uint.MinValue, fBool=false, fsByte=sbyte.MinValue+1, fByte = byte.MinValue, - fDouble =double.MinValue+1, fFloat=float.MinValue+1, fLong=long.MinValue+1, fuLong = ulong.MinValue, - fShort =short.MinValue+1, fuShort = ushort.MinValue, fString=""}, + new ConversionNullalbeClass() + { + fInt = int.MaxValue - 1, + fuInt = uint.MaxValue - 1, + fBool = true, + fsByte = sbyte.MaxValue - 1, + fByte = byte.MaxValue - 1, + fDouble = double.MaxValue - 1, + fFloat = float.MaxValue - 1, + fLong = long.MaxValue - 1, + fuLong = ulong.MaxValue - 1, + fShort = short.MaxValue - 1, + fuShort = ushort.MaxValue - 1, + fString = "ha" + }, + new ConversionNullalbeClass() + { + fInt = int.MaxValue, + fuInt = uint.MaxValue, + fBool = true, + fsByte = sbyte.MaxValue, + fByte = byte.MaxValue, + fDouble = double.MaxValue, + fFloat = float.MaxValue, + fLong = long.MaxValue, + fuLong = ulong.MaxValue, + fShort = short.MaxValue, + 
fuShort = ushort.MaxValue, + fString = "ooh" + }, + new ConversionNullalbeClass() + { + fInt = int.MinValue + 1, + fuInt = uint.MinValue, + fBool = false, + fsByte = sbyte.MinValue + 1, + fByte = byte.MinValue, + fDouble = double.MinValue + 1, + fFloat = float.MinValue + 1, + fLong = long.MinValue + 1, + fuLong = ulong.MinValue, + fShort = short.MinValue + 1, + fuShort = ushort.MinValue, + fString = "" + }, new ConversionNullalbeClass() }; @@ -334,7 +406,7 @@ public void BackAndForthConversionWithBasicTypes() var originalEnumerator = data.GetEnumerator(); while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) { - Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); @@ -343,7 +415,7 @@ public void BackAndForthConversionWithBasicTypes() var originalNullableEnumerator = dataNullable.GetEnumerator(); while (enumeratorNullable.MoveNext() && originalNullableEnumerator.MoveNext()) { - Assert.True(CompareThrougReflection(enumeratorNullable.Current, originalNullableEnumerator.Current)); + Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullableEnumerator.Current)); } Assert.True(!enumeratorNullable.MoveNext() && !originalNullableEnumerator.MoveNext()); } @@ -366,7 +438,6 @@ public void ConversionExceptionsBehavior() foreach (var field in typeof(ConversionNotSupportedMinValueClass).GetFields()) { data[0] = new ConversionNotSupportedMinValueClass(); - bool gotException = false; FieldInfo fi; if ((fi = field.FieldType.GetField("MinValue")) != null) { @@ -377,12 +448,11 @@ public void ConversionExceptionsBehavior() try { enumerator.MoveNext(); + Assert.True(false); } catch { - gotException = true; } - Assert.True(gotException); } } } @@ -400,9 +470,11 @@ public void ConversionMinValueToNullBehavior() { using (var env = new 
TlcEnvironment()) { - var data = new List(){ - new ConversionLossMinValueClass(){ fSByte = null,fInt = null,fLong = null,fShort = null}, - new ConversionLossMinValueClass(){fSByte = sbyte.MinValue,fInt = int.MinValue,fLong = long.MinValue,fShort = short.MinValue} + + var data = new List + { + new ConversionLossMinValueClass() { fSByte = null, fInt = null, fLong = null, fShort = null }, + new ConversionLossMinValueClass() { fSByte = sbyte.MinValue, fInt = int.MinValue, fLong = long.MinValue, fShort = short.MinValue } }; foreach (var field in typeof(ConversionLossMinValueClass).GetFields()) { @@ -433,7 +505,7 @@ public void ClassWithConstFieldsConversion() var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); var originalEnumerator = data.GetEnumerator(); while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } } @@ -471,27 +543,48 @@ public class ClassWithNullableArrays } [Fact] - public void BackAndForthConversionWithArrays() + public void RoundTripConversionWithArrays() { - var data = new List() + + var data = new List { - new ClassWithArrays(){ fInt = new int[3]{ 0,1,2}, fFloat = new float[3]{ -0.99f, 0f, 0.99f}, fString =new string[2]{ "hola", "lola"}, - fBool =new bool[2]{true, false }, fByte = new byte[3]{ 0,124,255}, fDouble=new double[3]{ -1,0, 1}, fLong = new long[]{ 0,1,2} , - fsByte = new sbyte[3]{ -127,127,0}, fShort = new short[3]{ 0, 1225, 32767 }, fuInt =new uint[2]{ 0, uint.MaxValue}, - fuLong = new ulong[2]{ ulong.MaxValue, 0}, fuShort = new ushort[2]{ 0, ushort.MaxValue} + new ClassWithArrays() + { + fInt = new int[3] { 0, 1, 2 }, + fFloat = new float[3] { -0.99f, 0f, 0.99f }, + fString = new string[2] { "hola", "lola" }, + fBool = new bool[2] { true, 
false }, + fByte = new byte[3] { 0, 124, 255 }, + fDouble = new double[3] { -1, 0, 1 }, + fLong = new long[] { 0, 1, 2 }, + fsByte = new sbyte[3] { -127, 127, 0 }, + fShort = new short[3] { 0, 1225, 32767 }, + fuInt = new uint[2] { 0, uint.MaxValue }, + fuLong = new ulong[2] { ulong.MaxValue, 0 }, + fuShort = new ushort[2] { 0, ushort.MaxValue } }, - new ClassWithArrays(){ fInt = new int[3]{ -2,1,0}, fFloat = new float[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{"",null} }, + new ClassWithArrays() { fInt = new int[3] { -2, 1, 0 }, fFloat = new float[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "", null } }, new ClassWithArrays() }; - var nullableData = new List() + var nullableData = new List { - new ClassWithNullableArrays(){ fInt = new int?[3]{ null,-1,1}, fFloat = new float?[3]{ -0.99f, null, 0.99f}, fString =new string[2]{ null, ""}, - fBool =new bool?[3]{true,null, false }, fByte = new byte?[4]{ 0,125,null,255}, fDouble=new double?[3]{ -1,null, 1}, fLong = new long?[]{null,-1,1} , - fsByte = new sbyte?[3]{ -127,127,null}, fShort = new short?[3]{ 0, null, 32767 }, fuInt =new uint?[4]{null,42 ,0, uint.MaxValue}, - fuLong = new ulong?[3]{ ulong.MaxValue, null, 0}, fuShort = new ushort?[3]{ 0,null, ushort.MaxValue} + new ClassWithNullableArrays() + { + fInt = new int?[3] { null, -1, 1 }, + fFloat = new float?[3] { -0.99f, null, 0.99f }, + fString = new string[2] { null, "" }, + fBool = new bool?[3] { true, null, false }, + fByte = new byte?[4] { 0, 125, null, 255 }, + fDouble = new double?[3] { -1, null, 1 }, + fLong = new long?[] { null, -1, 1 }, + fsByte = new sbyte?[3] { -127, 127, null }, + fShort = new short?[3] { 0, null, 32767 }, + fuInt = new uint?[4] { null, 42, 0, uint.MaxValue }, + fuLong = new ulong?[3] { ulong.MaxValue, null, 0 }, + fuShort = new ushort?[3] { 0, null, ushort.MaxValue } }, - new ClassWithNullableArrays(){ fInt = new int?[3]{ -2,1,0}, fFloat = new float?[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{ "lola","hola"} }, + new 
ClassWithNullableArrays() { fInt = new int?[3] { -2, 1, 0 }, fFloat = new float?[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "lola", "hola" } }, new ClassWithNullableArrays() }; @@ -502,7 +595,7 @@ public void BackAndForthConversionWithArrays() var originalEnumerator = data.GetEnumerator(); while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) { - Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); @@ -511,7 +604,7 @@ public void BackAndForthConversionWithArrays() var originalNullalbleEnumerator = nullableData.GetEnumerator(); while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) { - Assert.True(CompareThrougReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); + Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); } Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); }