Skip to content

Commit 8be11b0

Browse files
author
Davies Liu
committed
fix bugs
1 parent 3338c89 commit 8be11b0

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class CodegenContext {
160160
* The map from a variable name to it's next ID.
161161
*/
162162
private val freshNameIds = new mutable.HashMap[String, Int]
163+
freshNameIds += INPUT_ROW -> 1
163164

164165
/**
165166
* A prefix used to generate fresh name.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -173,22 +173,26 @@ case class GetArrayStructFields(
173173
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
174174
val arrayClass = classOf[GenericArrayData].getName
175175
nullSafeCodeGen(ctx, ev, eval => {
176+
val n = ctx.freshName("n")
177+
val values = ctx.freshName("values")
178+
val j = ctx.freshName("j")
179+
val row = ctx.freshName("row")
176180
s"""
177-
final int n = $eval.numElements();
178-
final Object[] values = new Object[n];
179-
for (int j = 0; j < n; j++) {
180-
if ($eval.isNullAt(j)) {
181-
values[j] = null;
181+
final int $n = $eval.numElements();
182+
final Object[] $values = new Object[$n];
183+
for (int $j = 0; $j < $n; $j++) {
184+
if ($eval.isNullAt($j)) {
185+
$values[$j] = null;
182186
} else {
183-
final InternalRow row = $eval.getStruct(j, $numFields);
184-
if (row.isNullAt($ordinal)) {
185-
values[j] = null;
187+
final InternalRow $row = $eval.getStruct($j, $numFields);
188+
if ($row.isNullAt($ordinal)) {
189+
$values[$j] = null;
186190
} else {
187-
values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
191+
$values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
188192
}
189193
}
190194
}
191-
${ev.value} = new $arrayClass(values);
195+
${ev.value} = new $arrayClass($values);
192196
"""
193197
})
194198
}
@@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
227231

228232
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
229233
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
234+
val index = ctx.freshName("index")
230235
s"""
231-
final int index = (int) $eval2;
232-
if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) {
236+
final int $index = (int) $eval2;
237+
if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) {
233238
${ev.isNull} = true;
234239
} else {
235-
${ev.value} = ${ctx.getValue(eval1, dataType, "index")};
240+
${ev.value} = ${ctx.getValue(eval1, dataType, index)};
236241
}
237242
"""
238243
})

0 commit comments

Comments
 (0)