Skip to content

Commit 30ed14f

Browse files
committed
Reuse function in Java UDF to support correctly expression equality comparison
1 parent d749c06 commit 30ed14f

File tree

2 files changed

+59
-22
lines changed

2 files changed

+59
-22
lines changed

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -488,219 +488,241 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
488488
* @since 1.3.0
489489
*/
490490
def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
491+
val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
491492
functionRegistry.registerFunction(
492493
name,
493-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e))
494+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
494495
}
495496

496497
/**
497498
* Register a user-defined function with 2 arguments.
498499
* @since 1.3.0
499500
*/
500501
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
502+
val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
501503
functionRegistry.registerFunction(
502504
name,
503-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e))
505+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
504506
}
505507

506508
/**
507509
* Register a user-defined function with 3 arguments.
508510
* @since 1.3.0
509511
*/
510512
def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
513+
val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
511514
functionRegistry.registerFunction(
512515
name,
513-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e))
516+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
514517
}
515518

516519
/**
517520
* Register a user-defined function with 4 arguments.
518521
* @since 1.3.0
519522
*/
520523
def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
524+
val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
521525
functionRegistry.registerFunction(
522526
name,
523-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e))
527+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
524528
}
525529

526530
/**
527531
* Register a user-defined function with 5 arguments.
528532
* @since 1.3.0
529533
*/
530534
def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
535+
val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
531536
functionRegistry.registerFunction(
532537
name,
533-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
538+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
534539
}
535540

536541
/**
537542
* Register a user-defined function with 6 arguments.
538543
* @since 1.3.0
539544
*/
540545
def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
546+
val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
541547
functionRegistry.registerFunction(
542548
name,
543-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
549+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
544550
}
545551

546552
/**
547553
* Register a user-defined function with 7 arguments.
548554
* @since 1.3.0
549555
*/
550556
def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
557+
val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
551558
functionRegistry.registerFunction(
552559
name,
553-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
560+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
554561
}
555562

556563
/**
557564
* Register a user-defined function with 8 arguments.
558565
* @since 1.3.0
559566
*/
560567
def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
568+
val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
561569
functionRegistry.registerFunction(
562570
name,
563-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
571+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
564572
}
565573

566574
/**
567575
* Register a user-defined function with 9 arguments.
568576
* @since 1.3.0
569577
*/
570578
def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
579+
val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
571580
functionRegistry.registerFunction(
572581
name,
573-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
582+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
574583
}
575584

576585
/**
577586
* Register a user-defined function with 10 arguments.
578587
* @since 1.3.0
579588
*/
580589
def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
590+
val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
581591
functionRegistry.registerFunction(
582592
name,
583-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
593+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
584594
}
585595

586596
/**
587597
* Register a user-defined function with 11 arguments.
588598
* @since 1.3.0
589599
*/
590600
def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
601+
val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
591602
functionRegistry.registerFunction(
592603
name,
593-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
604+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
594605
}
595606

596607
/**
597608
* Register a user-defined function with 12 arguments.
598609
* @since 1.3.0
599610
*/
600611
def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
612+
val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
601613
functionRegistry.registerFunction(
602614
name,
603-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
615+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
604616
}
605617

606618
/**
607619
* Register a user-defined function with 13 arguments.
608620
* @since 1.3.0
609621
*/
610622
def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
623+
val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
611624
functionRegistry.registerFunction(
612625
name,
613-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
626+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
614627
}
615628

616629
/**
617630
* Register a user-defined function with 14 arguments.
618631
* @since 1.3.0
619632
*/
620633
def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
634+
val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
621635
functionRegistry.registerFunction(
622636
name,
623-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
637+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
624638
}
625639

626640
/**
627641
* Register a user-defined function with 15 arguments.
628642
* @since 1.3.0
629643
*/
630644
def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
645+
val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
631646
functionRegistry.registerFunction(
632647
name,
633-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
648+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
634649
}
635650

636651
/**
637652
* Register a user-defined function with 16 arguments.
638653
* @since 1.3.0
639654
*/
640655
def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
656+
val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
641657
functionRegistry.registerFunction(
642658
name,
643-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
659+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
644660
}
645661

646662
/**
647663
* Register a user-defined function with 17 arguments.
648664
* @since 1.3.0
649665
*/
650666
def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
667+
val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
651668
functionRegistry.registerFunction(
652669
name,
653-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
670+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
654671
}
655672

656673
/**
657674
* Register a user-defined function with 18 arguments.
658675
* @since 1.3.0
659676
*/
660677
def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
678+
val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
661679
functionRegistry.registerFunction(
662680
name,
663-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
681+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
664682
}
665683

666684
/**
667685
* Register a user-defined function with 19 arguments.
668686
* @since 1.3.0
669687
*/
670688
def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
689+
val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
671690
functionRegistry.registerFunction(
672691
name,
673-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
692+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
674693
}
675694

676695
/**
677696
* Register a user-defined function with 20 arguments.
678697
* @since 1.3.0
679698
*/
680699
def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
700+
val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
681701
functionRegistry.registerFunction(
682702
name,
683-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
703+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
684704
}
685705

686706
/**
687707
* Register a user-defined function with 21 arguments.
688708
* @since 1.3.0
689709
*/
690710
def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
711+
val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
691712
functionRegistry.registerFunction(
692713
name,
693-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
714+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
694715
}
695716

696717
/**
697718
* Register a user-defined function with 22 arguments.
698719
* @since 1.3.0
699720
*/
700721
def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
722+
val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
701723
functionRegistry.registerFunction(
702724
name,
703-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
725+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
704726
}
705727

706728
// scalastyle:on line.size.limit

sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,19 @@ public void udf3Test() {
108108
result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
109109
Assert.assertEquals(9, result.getInt(0));
110110
}
111+
112+
@SuppressWarnings("unchecked")
113+
@Test
114+
public void udf4Test() {
115+
spark.udf().register("inc", new UDF1<Long, Long>() {
116+
@Override
117+
public Long call(Long i) {
118+
return i + 1;
119+
}
120+
}, DataTypes.LongType);
121+
122+
spark.range(10).toDF("x").createOrReplaceTempView("tmp");
123+
Row result = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").head();
124+
Assert.assertEquals(7, result.getLong(0));
125+
}
111126
}

0 commit comments

Comments
 (0)