@@ -27,6 +27,20 @@ object Test {
2727
2828 val code4 = ' { (arr : Array [Int ], f : Int => Unit ) => ~ foreach4('(arr), ' (f), 4 ) }
2929 println(code4.show)
30+ println()
31+
32+ val liftedArray = Array (1 , 2 , 3 , 4 ).toExpr
33+ println(liftedArray.show)
34+ println()
35+
36+
37+ def printAll (arr : Array [Int ]) = ' {
38+ val arr1 = ~ arr.toExpr
39+ ~ foreach1('(arr1), ' (x => println(x)))
40+ }
41+
42+ println(printAll(Array (1 , 3 , 4 , 5 )).show)
43+
3044 }
3145
3246 def foreach1 (arrRef : Expr [Array [Int ]], f : Expr [Int => Unit ]): Expr [Unit ] = ' {
@@ -81,19 +95,40 @@ object Test {
8195 }
8296 }
8397
98+ def foreach3_2 (arrRef : Expr [Array [Int ]], f : Expr [Int => Unit ]): Expr [Unit ] = ' {
99+ val size = (~ arrRef).length
100+ var i = 0
101+ if (size % 3 != 0 ) throw new Exception (" ..." )// for simplicity of the implementation
102+ while (i < size) {
103+ (~ f)((~ arrRef)(i))
104+ (~ f)((~ arrRef)(i + 1 ))
105+ (~ f)((~ arrRef)(i + 2 ))
106+ i += 3
107+ }
108+ }
109+
84110 def foreach4 (arrRef : Expr [Array [Int ]], f : Expr [Int => Unit ], unrollSize : Int ): Expr [Unit ] = ' {
85111 val size = (~ arrRef).length
86112 var i = 0
87113 if (size % ~ unrollSize.toExpr != 0 ) throw new Exception (" ..." ) // for simplicity of the implementation
88114 while (i < size) {
89- ~ {
90- @ tailrec def loop (j : Int , acc : Expr [Unit ]): Expr [Unit ] =
91- if (j >= 0 ) loop(j - 1 , ' { (~ f)((~ arrRef)(i + ~ j.toExpr)); ~ acc })
92- else acc
93- loop(unrollSize - 1 , '())
94- }
115+ ~ foreachInRange(0 , unrollSize)(j => ' { (~ f)((~ arrRef)(i + ~ j.toExpr)) })
95116 i += ~ unrollSize.toExpr
96117 }
97118 }
98119
120+ implicit object ArrayIntIsLiftable extends Liftable [Array [Int ]] {
121+ override def toExpr (x : Array [Int ]): Expr [Array [Int ]] = ' {
122+ val array = new Array [Int ](~ x.length.toExpr)
123+ ~ foreachInRange(0 , x.length)(i => ' { array(~ i.toExpr) = ~ x(i).toExpr})
124+ array
125+ }
126+ }
127+
128+ def foreachInRange (start : Int , end : Int )(f : Int => Expr [Unit ]): Expr [Unit ] = {
129+ @ tailrec def unroll (i : Int , acc : Expr [Unit ]): Expr [Unit ] =
130+ if (i < end) unroll(i + 1 , ' { ~ acc; ~ f(i) }) else acc
131+ if (start < end) unroll(start + 1 , f(start)) else '()
132+ }
133+
99134}
0 commit comments