@@ -66,41 +66,59 @@ object MVmult {
6666 case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
6767 case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
6868 }))
69- mvmult_abs (a.length, a(0 ).length, a2)
69+ mvmult_abs0( new RingIntPExpr , new VecRStaDyn ( new RingIntPExpr )) (a.length, a(0 ).length, a2)
7070 }
7171 }
7272 }
7373
74- private def mvmult_abs (n : Int , m : Int , a : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
74+ def mvmult_opt (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
75+ val n = a.length
76+ val m = a(0 ).length
77+ import Lifters ._
78+ ' {
79+ val arr = ~ a.toExpr
80+ ~ {
81+ val a2 : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
82+ case (Sta (i), Sta (j)) => Sta (a(i)(j))
83+ case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
84+ case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
85+ }))
86+ mvmult_abs0(new RingIntOPExpr , new VecRStaDyn (new RingIntPExpr ))(a.length, a(0 ).length, a2)
87+ }
88+ }
89+ }
90+
91+ def mvmult_roll (a : Array [Array [Int ]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
92+ val n = a.length
93+ val m = a(0 ).length
94+ import Lifters ._
95+ ' {
96+ val arr = ~ a.toExpr
97+ ~ {
98+ val a2 : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]] = Vec (Sta (n), i => Vec (Sta (m), j => (i, j) match {
99+ case (Sta (i), Sta (j)) => Sta (a(i)(j))
100+ case (Sta (i), Dyn (j)) => Dyn ('(arr(~i.toExpr)(~j)))
101+ case (i, j) => Dyn ( ' { arr(~ (Dyns .dyni(i)))(~ (Dyns .dyni(j))) })
102+ }))
103+ mvmult_abs0(new RingIntOPExpr , new VecRStaOptDynInt (new RingIntPExpr ))(a.length, a(0 ).length, a2)
104+ }
105+ }
106+ }
107+
108+ private def mvmult_abs0 (ring : Ring [PV [Int ]], vecOp : VecROp [PV [Int ], PV [Int ], Expr [Unit ]])(n : Int , m : Int , a : Vec [PV [Int ], Vec [PV [Int ], PV [Int ]]]): Expr [(Array [Int ], Array [Int ]) => Unit ] = {
75109 ' {
76110 (vout, v) => {
77111 if (~ n.toExpr != vout.length) throw new IndexOutOfBoundsException (~ n.toString.toExpr)
78112 if (~ m.toExpr != v.length) throw new IndexOutOfBoundsException (~ m.toString.toExpr)
79113 ~ {
80114 val vout_ : OVec [PV [Int ], PV [Int ], Expr [Unit ]] = OVec (Sta (n), (i, x) => '(vout(~Dyns.dyni(i)) = ~Dyns.dyn(x)))
81115 val v_ : Vec [PV [Int ], PV [Int ]] = Vec (Sta (m), i => Dyn ('(v(~Dyns.dyni(i)))))
82- val MV = new MVmult [PV [Int ], PV [Int ], Expr [Unit ]](new RingIntPExpr , new VecRStaDyn ( new RingIntPExpr ) )
116+ val MV = new MVmult [PV [Int ], PV [Int ], Expr [Unit ]](ring, vecOp )
83117 MV .mvmult(vout_, a, v_)
84118 }
85119 }
86120 }
87121 }
88122
89123
90-
91- // let mvmult_abs : _ →
92- // amat → (float array → float array → unit) code =
93- // fun mvmult → fun {n;m;a} →
94- // .<fun vout v →
95- // assert (n = Array.length vout && m = Array.length v);
96- // .~(let vout = OVec (Sta n, fun i v → .<vout.(.~(dyni i)) ← .~(dynf v)>.) in
97- // let v = Vec (Sta m, fun j → Dyn .<v.(.~(dyni j))>.) in
98- // mvmult vout a v)
99- // >.
100- // val mvmult_abs :
101- // ((int pv, float pv, unit code) Vector.ovec →
102- // (int pv, (int pv, float pv) Vector.vec) Vector.vec →
103- // (int pv, float pv) Vector.vec → unit code) →
104- // amat → (float array → float array → unit) code = <fun>
105-
106124}
0 commit comments