Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.stat

import org.apache.spark.Logging
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.{Row, Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging {
s"exceed 1e4. Currently $columnSize")
val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
val countsRow = new GenericMutableRow(columnSize + 1)
rows.foreach { row =>
rows.foreach { (row: Row) =>
// row.get(0) is column 1
// row.get(1) is column 2
// row.get(2) is the frequency
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
}
// the value of col1 is the first value, the rest are the counts
Expand All @@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging {
val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
val schema = StructType(StructField(tableName, StringType) +: headerNames)

new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ class DataFrameStatSuite extends SparkFunSuite {
val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
assert(rows(0).get(0).toString === "0")
assert(rows(0).getLong(1) === 2L)
assert(rows(0).get(2) === null)
assert(rows(0).get(2) === 0L)
assert(rows(1).get(0).toString === "1")
assert(rows(1).getLong(1) === 1L)
assert(rows(1).get(2) === null)
assert(rows(1).get(2) === 0L)
assert(rows(2).get(0).toString === "2")
assert(rows(2).getLong(1) === 2L)
assert(rows(2).getLong(2) === 1L)
Expand Down