diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs index 461badeda3..2b6cde89c0 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.cs @@ -1955,7 +1955,7 @@ internal PrimitiveDataFrameColumn ElementwiseEqualsImplementation(U val { throw new NotSupportedException(); } - return new BooleanDataFrameColumn(Name, (this as PrimitiveDataFrameColumn)._columnContainer.ElementwiseEquals(Unsafe.As(ref value))); + return new BooleanDataFrameColumn(Name, (this as PrimitiveDataFrameColumn)._columnContainer.ElementwiseEquals(Unsafe.As(ref value))); case Type byteType when byteType == typeof(byte): case Type charType when charType == typeof(char): case Type doubleType when doubleType == typeof(double): @@ -2102,7 +2102,7 @@ internal PrimitiveDataFrameColumn ElementwiseNotEqualsImplementation(U { throw new NotSupportedException(); } - return new BooleanDataFrameColumn(Name, (this as PrimitiveDataFrameColumn)._columnContainer.ElementwiseNotEquals(Unsafe.As(ref value))); + return new BooleanDataFrameColumn(Name, (this as PrimitiveDataFrameColumn)._columnContainer.ElementwiseNotEquals(Unsafe.As(ref value))); case Type byteType when byteType == typeof(byte): case Type charType when charType == typeof(char): case Type doubleType when doubleType == typeof(double): diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.tt b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.tt index 66d04d02ec..356b70eb72 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.tt +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.BinaryOperations.tt @@ -190,7 +190,7 @@ namespace Microsoft.Data.Analysis throw new NotSupportedException(); } <# if (method.MethodType == MethodType.ComparisonScalar) { #> - return new BooleanDataFrameColumn(Name, (this as PrimitiveDataFrameColumn)._columnContainer.<#=method.MethodName#>(Unsafe.As(ref value))); + return new BooleanDataFrameColumn(Name, (this as PrimitiveDataFrameColumn<<#=type.TypeName#>>)._columnContainer.<#=method.MethodName#>(Unsafe.As>(ref value))); <# } else if (method.MethodType == MethodType.Comparison) { #> return new BooleanDataFrameColumn(Name, (this as PrimitiveDataFrameColumn)._columnContainer.<#=method.MethodName#>(column._columnContainer)); <# } else if (method.MethodType == MethodType.BinaryScalar) {#> diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index b02319e478..85e4ccd79c 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -625,9 +625,17 @@ public void TestBinaryOperationsOnDateTimeColumn() Assert.True(equalsResult[0]); Assert.False(equalsResult[4]); + var equalsToScalarResult = df["DateTime1"].ElementwiseEquals(SampleDateTime); + Assert.True(equalsToScalarResult[0]); + Assert.False(equalsToScalarResult[1]); + var notEqualsResult = dataFrameColumn1.ElementwiseNotEquals(dataFrameColumn2); Assert.False(notEqualsResult[0]); Assert.True(notEqualsResult[4]); + + var notEqualsToScalarResult = df["DateTime1"].ElementwiseNotEquals(SampleDateTime); + Assert.False(notEqualsToScalarResult[0]); + Assert.True(notEqualsToScalarResult[1]); } [Fact]