Skip to content

Commit 62091c9

Browse files
author
Ivan Matantsev
committed
support nullable arrays.
1 parent 581cb42 commit 62091c9

File tree

3 files changed

+107
-13
lines changed

3 files changed

+107
-13
lines changed

src/Microsoft.ML.Api/DataViewConstructionUtils.cs

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,36 +129,64 @@ private Delegate CreateGetter(int index)
129129
if (outputType.GetElementType() == typeof(string))
130130
{
131131
Ch.Assert(colType.ItemType.IsText);
132-
return CreateArrayGetterDelegate<String, DvText>(index, (x) => new DvText(x));
132+
return CreateArrayGetterDelegate<String, DvText>(index, x => new DvText(x));
133133
}
134134
else if (outputType.GetElementType() == typeof(int))
135135
{
136136
Ch.Assert(colType.ItemType == NumberType.I4);
137-
return CreateArrayGetterDelegate<int, DvInt4>(index, (x) => x);
137+
return CreateArrayGetterDelegate<int, DvInt4>(index, x => x);
138+
}
139+
else if (outputType.GetElementType() == typeof(int?))
140+
{
141+
Ch.Assert(colType.ItemType == NumberType.I4);
142+
return CreateArrayGetterDelegate<int?, DvInt4>(index, x => x ?? DvInt4.NA);
138143
}
139144
else if (outputType.GetElementType() == typeof(long))
140145
{
141146
Ch.Assert(colType.ItemType == NumberType.I8);
142-
return CreateArrayGetterDelegate<long, DvInt8>(index, (x) => x);
147+
return CreateArrayGetterDelegate<long, DvInt8>(index, x => x);
148+
}
149+
else if (outputType.GetElementType() == typeof(long?))
150+
{
151+
Ch.Assert(colType.ItemType == NumberType.I8);
152+
return CreateArrayGetterDelegate<long?, DvInt8>(index, x => x ?? DvInt8.NA);
143153
}
144154
else if (outputType.GetElementType() == typeof(short))
145155
{
146156
Ch.Assert(colType.ItemType == NumberType.I2);
147-
return CreateArrayGetterDelegate<short, DvInt2>(index, (x) => x);
157+
return CreateArrayGetterDelegate<short, DvInt2>(index, x => x);
158+
}
159+
else if (outputType.GetElementType() == typeof(short?))
160+
{
161+
Ch.Assert(colType.ItemType == NumberType.I2);
162+
return CreateArrayGetterDelegate<short?, DvInt2>(index, x => x ?? DvInt2.NA);
148163
}
149164
else if (outputType.GetElementType() == typeof(sbyte))
150165
{
151166
Ch.Assert(colType.ItemType == NumberType.I1);
152-
return CreateArrayGetterDelegate<sbyte, DvInt1>(index, (x) => x);
167+
return CreateArrayGetterDelegate<sbyte, DvInt1>(index, x => x);
168+
}
169+
else if (outputType.GetElementType() == typeof(sbyte?))
170+
{
171+
Ch.Assert(colType.ItemType == NumberType.I1);
172+
return CreateArrayGetterDelegate<sbyte?, DvInt1>(index, x => x ?? DvInt1.NA);
153173
}
154174
else if (outputType.GetElementType() == typeof(bool))
155175
{
156176
Ch.Assert(colType.ItemType.IsBool);
157-
return CreateArrayGetterDelegate<bool, DvBool>(index, (x)=>x);
177+
return CreateArrayGetterDelegate<bool, DvBool>(index, x => x);
178+
}
179+
else if (outputType.GetElementType() == typeof(bool?))
180+
{
181+
Ch.Assert(colType.ItemType.IsBool);
182+
return CreateArrayGetterDelegate<bool?, DvBool>(index, x => x ?? DvBool.NA);
158183
}
159184

160185
// T[] -> VBuffer<T>
161-
Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType);
186+
if (outputType.GetElementType().IsGenericType && outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>))
187+
Ch.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == colType.ItemType.RawType);
188+
else
189+
Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType);
162190
del = CreateDirectArrayGetterDelegate<int>;
163191
genericType = outputType.GetElementType();
164192
}

src/Microsoft.ML.Api/TypedCursor.cs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,29 +287,57 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit
287287
Ch.Assert(colType.ItemType.IsBool);
288288
return CreateVBufferSetter<DvBool, bool>(input, index, poke, peek, x => Convert.ToBoolean(x.RawValue));
289289
}
290+
else if (fieldType.GetElementType() == typeof(bool?))
291+
{
292+
Ch.Assert(colType.ItemType.IsBool);
293+
return CreateVBufferSetter<DvBool, bool?>(input, index, poke, peek, x => (bool?)x);
294+
}
290295
else if (fieldType.GetElementType() == typeof(int))
291296
{
292297
Ch.Assert(colType.ItemType == NumberType.I4);
293298
return CreateVBufferSetter<DvInt4, int>(input, index, poke, peek, x => (int)x);
294299
}
300+
else if (fieldType.GetElementType() == typeof(int?))
301+
{
302+
Ch.Assert(colType.ItemType == NumberType.I4);
303+
return CreateVBufferSetter<DvInt4, int?>(input, index, poke, peek, x => (int?)x);
304+
}
295305
else if (fieldType.GetElementType() == typeof(short))
296306
{
297307
Ch.Assert(colType.ItemType == NumberType.I2);
298308
return CreateVBufferSetter<DvInt2, short>(input, index, poke, peek, x => (short)x);
299309
}
310+
else if (fieldType.GetElementType() == typeof(short?))
311+
{
312+
Ch.Assert(colType.ItemType == NumberType.I2);
313+
return CreateVBufferSetter<DvInt2, short?>(input, index, poke, peek, x => (short?)x);
314+
}
300315
else if (fieldType.GetElementType() == typeof(long))
301316
{
302317
Ch.Assert(colType.ItemType == NumberType.I8);
303318
return CreateVBufferSetter<DvInt8, long>(input, index, poke, peek, x => (long)x);
304319
}
320+
else if (fieldType.GetElementType() == typeof(long?))
321+
{
322+
Ch.Assert(colType.ItemType == NumberType.I8);
323+
return CreateVBufferSetter<DvInt8, long?>(input, index, poke, peek, x => (long?)x);
324+
}
305325
else if (fieldType.GetElementType() == typeof(sbyte))
306326
{
307327
Ch.Assert(colType.ItemType == NumberType.I1);
308328
return CreateVBufferSetter<DvInt1, sbyte>(input, index, poke, peek, x => (sbyte)x);
309329
}
330+
else if (fieldType.GetElementType() == typeof(sbyte?))
331+
{
332+
Ch.Assert(colType.ItemType == NumberType.I1);
333+
return CreateVBufferSetter<DvInt1, sbyte?>(input, index, poke, peek, x => (sbyte?)x);
334+
}
310335

311336
// VBuffer<T> -> T[]
312-
Ch.Assert(fieldType.GetElementType() == colType.ItemType.RawType);
337+
if (fieldType.GetElementType().IsGenericType && fieldType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>))
338+
Ch.Assert(colType.ItemType.RawType == Nullable.GetUnderlyingType(fieldType.GetElementType()));
339+
else
340+
Ch.Assert(colType.ItemType.RawType == fieldType.GetElementType());
313341
del = CreateVBufferDirectSetter<int>;
314342
genericType = fieldType.GetElementType();
315343
}
@@ -353,7 +381,7 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit
353381
{
354382
Ch.Assert(colType == NumberType.I4);
355383
Ch.Assert(peek == null);
356-
return CreateActionSetter<DvInt4, int?>(input, index, poke, x => x.IsNA ? (int?)null : (int)x);
384+
return CreateActionSetter<DvInt4, int?>(input, index, poke, x => (int?)x);
357385
}
358386
else if (fieldType == typeof(short))
359387
{
@@ -365,7 +393,7 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit
365393
{
366394
Ch.Assert(colType == NumberType.I2);
367395
Ch.Assert(peek == null);
368-
return CreateActionSetter<DvInt2, short?>(input, index, poke, x => x.IsNA ? (short?)null : (short)x);
396+
return CreateActionSetter<DvInt2, short?>(input, index, poke, x => (short?)x);
369397
}
370398
else if (fieldType == typeof(long))
371399
{
@@ -377,7 +405,7 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit
377405
{
378406
Ch.Assert(colType == NumberType.I8);
379407
Ch.Assert(peek == null);
380-
return CreateActionSetter<DvInt8, long?>(input, index, poke, x => x.IsNA ? (long?)null : (long)x);
408+
return CreateActionSetter<DvInt8, long?>(input, index, poke, x => (long?)x);
381409
}
382410
else if (fieldType == typeof(sbyte))
383411
{
@@ -389,7 +417,7 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit
389417
{
390418
Ch.Assert(colType == NumberType.I1);
391419
Ch.Assert(peek == null);
392-
return CreateActionSetter<DvInt1, sbyte?>(input, index, poke, x => x.IsNA ? (sbyte?)null : (sbyte)x);
420+
return CreateActionSetter<DvInt1, sbyte?>(input, index, poke, x => (sbyte?)x);
393421
}
394422
// T -> T
395423
if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>))

test/Microsoft.ML.Tests/CollectionDataSourceTests.cs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,22 @@ public class ClassWithArrays
459459
public double[] fDouble;
460460
public bool[] fBool;
461461
}
462+
public class ClassWithNullableArrays
463+
{
464+
public string[] fString;
465+
public int?[] fInt;
466+
public uint?[] fuInt;
467+
public short?[] fShort;
468+
public ushort?[] fuShort;
469+
public sbyte?[] fsByte;
470+
public byte?[] fByte;
471+
public long?[] fLong;
472+
public ulong?[] fuLong;
473+
public float?[] fFloat;
474+
public double?[] fDouble;
475+
public bool?[] fBool;
476+
}
477+
462478

463479
[Fact]
464480
public void BackAndForthConversionWithArrays()
@@ -470,9 +486,20 @@ public void BackAndForthConversionWithArrays()
470486
fsByte = new sbyte[3]{ -127,127,0}, fShort = new short[3]{ 0, 1225, 32767 }, fuInt =new uint[2]{ 0, uint.MaxValue},
471487
fuLong = new ulong[2]{ ulong.MaxValue, 0}, fuShort = new ushort[2]{ 0, ushort.MaxValue}
472488
},
473-
new ClassWithArrays(){ fInt = new int[3]{ -2,1,0}, fFloat = new float[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{ "lola","hola"} }
489+
new ClassWithArrays(){ fInt = new int[3]{ -2,1,0}, fFloat = new float[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{"",null} },
490+
new ClassWithArrays()
474491
};
475492

493+
var nullableData = new List<ClassWithNullableArrays>()
494+
{
495+
new ClassWithNullableArrays(){ fInt = new int?[3]{ null,-1,1}, fFloat = new float?[3]{ -0.99f, null, 0.99f}, fString =new string[2]{ null, ""},
496+
fBool =new bool?[3]{true,null, false }, fByte = new byte?[4]{ 0,125,null,255}, fDouble=new double?[3]{ -1,null, 1}, fLong = new long?[]{null,-1,1} ,
497+
fsByte = new sbyte?[3]{ -127,127,null}, fShort = new short?[3]{ 0, null, 32767 }, fuInt =new uint?[4]{null,42 ,0, uint.MaxValue},
498+
fuLong = new ulong?[3]{ ulong.MaxValue, null, 0}, fuShort = new ushort?[3]{ 0,null, ushort.MaxValue}
499+
},
500+
new ClassWithNullableArrays(){ fInt = new int?[3]{ -2,1,0}, fFloat = new float?[3]{ 0.99f, 0f, -0.99f}, fString =new string[2]{ "lola","hola"} },
501+
new ClassWithNullableArrays()
502+
};
476503
using (var env = new TlcEnvironment())
477504
{
478505
var dataView = ComponentCreation.CreateDataView(env, data);
@@ -483,7 +510,18 @@ public void BackAndForthConversionWithArrays()
483510
Assert.True(CompareThrougReflection(enumeratorSimple.Current, originalEnumerator.Current));
484511
}
485512
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
513+
514+
var nullableDataView = ComponentCreation.CreateDataView(env, nullableData);
515+
var enumeratorNullable = nullableDataView.AsEnumerable<ClassWithNullableArrays>(env, false).GetEnumerator();
516+
var originalNullalbleEnumerator = nullableData.GetEnumerator();
517+
while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext())
518+
{
519+
Assert.True(CompareThrougReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current));
520+
}
521+
Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext());
486522
}
487523
}
524+
525+
488526
}
489527
}

0 commit comments

Comments
 (0)