@@ -55,56 +55,84 @@ object MVmult {
5555 }
5656
5757 def mvmult_ac (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
58- val n = a.length
59- val m = a(0 ).length
6058 import Lifters ._
6159 ' {
6260 val arr = ~ a.toExpr
6361 ~ {
64- val a2 : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
65- case (Sta (i), Sta (j)) => Sta (a(i)(j))
66- case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
67- case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
68- }))
69- mvmult_abs0(new RingIntPExpr , new VecRStaDyn (new RingIntPExpr ))(a.length, a(0 ).length, a2)
62+ val (n, m, a2) = amat1(a, '(arr))
63+ mvmult_abs0(new RingIntPExpr , new VecRStaDyn (new RingIntPExpr ))(n, m, a2)
7064 }
7165 }
7266 }
7367
7468 def mvmult_opt (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
75- val n = a.length
76- val m = a(0 ).length
7769 import Lifters ._
7870 ' {
7971 val arr = ~ a.toExpr
8072 ~ {
81- val a2 : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
82- case (Sta (i), Sta (j)) => Sta (a(i)(j))
83- case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
84- case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
85- }))
86- mvmult_abs0(new RingIntOPExpr , new VecRStaDyn (new RingIntPExpr ))(a.length, a(0 ).length, a2)
73+ val (n, m, a2) = amat1(a, '(arr))
74+ mvmult_abs0(new RingIntOPExpr , new VecRStaDyn (new RingIntPExpr ))(n, m, a2)
8775 }
8876 }
8977 }
9078
9179 def mvmult_roll (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
92- val n = a.length
93- val m = a(0 ).length
9480 import Lifters ._
9581 ' {
9682 val arr = ~ a.toExpr
9783 ~ {
98- val a2 : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
99- case (Sta (i), Sta (j)) => Sta (a(i)(j))
100- case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
101- case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
102- }))
103- mvmult_abs0(new RingIntOPExpr , new VecRStaOptDynInt (new RingIntPExpr ))(a.length, a(0 ).length, a2)
84+ val (n, m, a2) = amat1(a, '(arr))
85+ mvmult_abs0(new RingIntOPExpr , new VecRStaOptDynInt (new RingIntPExpr ))(n, m, a2)
10486 }
10587 }
10688 }
10789
90+ def mvmult_let1 (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
91+ val (n, m, a2) = amatCopy(a, copy_row1)
92+ mvmult_abs0(new RingIntOPExpr , new VecRStaOptDynInt (new RingIntPExpr ))(n, m, a2)
93+ }
94+
95+ def mvmult_let (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
96+ val (n, m, a2) = amatCopy(a, copy_row_let)
97+ mvmult_abs0(new RingIntOPExpr , new VecRStaOptDynInt (new RingIntPExpr ))(n, m, a2)
98+ }
99+
100+ def amat1 (a : Array [Array [Int ]], aa : Expr [Array [Array [Int ]]]): (Int , Int , Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]]) = {
101+ val n = a.length
102+ val m = a(0 ).length
103+ val vec : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
104+ case (Sta (i), Sta (j)) => Sta (a(i)(j))
105+ case (Sta (i), Dyn (j)) => Dyn ('((~aa)(~i.toExpr)(~j)))
106+ case (i, j) => Dyn (' { (~ aa)(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
107+ }))
108+ (n, m, vec)
109+ }
110+
111+ def amatCopy (a : Array [Array [Int ]], copyRow : Array [Int ] => (Expr [Int ] => Expr [Int ])): (Int , Int , Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]]) = {
112+ val n = a.length
113+ val m = a(0 ).length
114+ val vec : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
115+ case (Sta (i), Sta (j)) => Sta (a(i)(j))
116+ case (Sta (i), Dyn (j)) =>
117+ val defrec = copyRow(a(i))
118+ Dyn (defrec(j))
119+ case (i, j) => ???
120+ }))
121+ (n, m, vec)
122+ }
123+
124+ def copy_row1 : Array [Int ] => (Expr [Int ] => Expr [Int ]) = v => {
125+ import Lifters ._
126+ val arr = v.toExpr
127+ i => ' { (~ arr).apply(~ i) }
128+ }
129+
130+ def copy_row_let : Array [Int ] => (Expr [Int ] => Expr [Int ]) = v => {
131+ import Lifters ._
132+ val arr : Expr [Array [Int ]] = ??? // FIXME used genlet v.toExpr
133+ i => ' { (~ arr).apply(~ i) }
134+ }
135+
108136 private def mvmult_abs0 (ring : Ring [PV [Int ]], vecOp : VecROp [PV [Int ], PV [Int ], Expr [Unit ]])(n : Int , m : Int , a : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
109137 ' {
110138 (vout, v) => {
@@ -120,5 +148,4 @@ object MVmult {
120148 }
121149 }
122150
123-
124151}
0 commit comments