Skip to content

Commit ad6386d

Browse files
committed
Use fma following feedback from Adam Pocock <[email protected]>
1 parent 0d41a03 commit ad6386d

File tree

3 files changed

+49
-43
lines changed

3 files changed

+49
-43
lines changed

blas/src/main/java/dev/ludovic/blas/VectorizedBLAS.java

Lines changed: 41 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -42,11 +42,11 @@ public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int inc
4242
if (incx == 1 && incy == 1 && n <= x.length && n <= y.length) {
4343
if (alpha != 0.) {
4444
int i = 0;
45+
DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha);
4546
for (; i < DMAX.loopBound(n); i += DMAX.length()) {
4647
DoubleVector vx = DoubleVector.fromArray(DMAX, x, i);
4748
DoubleVector vy = DoubleVector.fromArray(DMAX, y, i);
48-
vx.lanewise(VectorOperators.MUL, alpha).add(vy)
49-
.intoArray(y, i);
49+
vx.fma(valpha, vy).intoArray(y, i);
5050
}
5151
for (; i < n; i += 1) {
5252
y[i] += alpha * x[i];
@@ -65,11 +65,13 @@ public float sdot(int n, float[] x, int incx, float[] y, int incy) {
6565
if (incx == 1 && incy == 1) {
6666
float sum = 0.0f;
6767
int i = 0;
68+
FloatVector vsum = FloatVector.zero(FMAX);
6869
for (; i < FMAX.loopBound(n); i += FMAX.length()) {
6970
FloatVector vx = FloatVector.fromArray(FMAX, x, i);
7071
FloatVector vy = FloatVector.fromArray(FMAX, y, i);
71-
sum += vx.mul(vy).reduceLanes(VectorOperators.ADD);
72+
vsum = vx.fma(vy, vsum);
7273
}
74+
sum += vsum.reduceLanes(VectorOperators.ADD);
7375
for (; i < n; i += 1) {
7476
sum += x[i] * y[i];
7577
}
@@ -87,11 +89,13 @@ public double ddot(int n, double[] x, int incx, double[] y, int incy) {
8789
if (incx == 1 && incy == 1) {
8890
double sum = 0.;
8991
int i = 0;
92+
DoubleVector vsum = DoubleVector.zero(DMAX);
9093
for (; i < DMAX.loopBound(n); i += DMAX.length()) {
9194
DoubleVector vx = DoubleVector.fromArray(DMAX, x, i);
9295
DoubleVector vy = DoubleVector.fromArray(DMAX, y, i);
93-
sum += vx.mul(vy).reduceLanes(VectorOperators.ADD);
96+
vsum = vx.fma(vy, vsum);
9497
}
98+
sum += vsum.reduceLanes(VectorOperators.ADD);
9599
for (; i < n; i += 1) {
96100
sum += x[i] * y[i];
97101
}
@@ -109,10 +113,10 @@ public void dscal(int n, double alpha, double[] x, int incx) {
109113
if (incx == 1) {
110114
if (alpha != 1.) {
111115
int i = 0;
116+
DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha);
112117
for (; i < DMAX.loopBound(n); i += DMAX.length()) {
113118
DoubleVector vx = DoubleVector.fromArray(DMAX, x, i);
114-
vx.lanewise(VectorOperators.MUL, alpha)
115-
.intoArray(x, i);
119+
vx.mul(valpha).intoArray(x, i);
116120
}
117121
for (; i < n; i += 1) {
118122
x[i] *= alpha;
@@ -134,19 +138,21 @@ public void dspmv(String uplo, int n, double alpha, double[] a, double[] x, int
134138
dscal(n, beta, y, 1);
135139
// y += alpha * A * x
136140
if (alpha != 0.) {
137-
for (int col = 0; col < n; col += 1) {
138-
int row = 0;
139-
for (; row < DMAX.loopBound(col + 1); row += DMAX.length()) {
140-
DoubleVector vx = DoubleVector.fromArray(DMAX, x, row);
141-
DoubleVector va = DoubleVector.fromArray(DMAX, a, row + col * (col + 1) / 2);
142-
y[col] += alpha * vx.mul(va).reduceLanes(VectorOperators.ADD);
143-
}
144-
for (; row < col + 1; row += 1) {
145-
y[col] += alpha * x[row] * a[row + col * (col + 1) / 2];
141+
for (int row = 0; row < n; row += 1) {
142+
int col = 0;
143+
DoubleVector valphaxrow = DoubleVector.broadcast(DMAX, alpha * x[row]);
144+
for (; col < DMAX.loopBound(row); col += DMAX.length()) {
145+
DoubleVector vx = DoubleVector.fromArray(DMAX, x, col);
146+
DoubleVector vy = DoubleVector.fromArray(DMAX, y, col);
147+
DoubleVector va = DoubleVector.fromArray(DMAX, a, col + row * (row + 1) / 2);
148+
y[row] += alpha * vx.mul(va).reduceLanes(VectorOperators.ADD);
149+
valphaxrow.fma(va, vy).intoArray(y, col);
146150
}
147-
for (; row < n; row += 1) {
148-
y[col] += alpha * x[row] * a[row * (row + 1) / 2 + col];
151+
for (; col < row; col += 1) {
152+
y[row] += alpha * x[col] * a[col + row * (row + 1) / 2];
153+
y[col] += alpha * x[row] * a[col + row * (row + 1) / 2];
149154
}
155+
y[row] += alpha * x[col] * a[col + row * (row + 1) / 2];
150156
}
151157
}
152158
} else {
@@ -161,16 +167,16 @@ public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[
161167
// uplo, n, alpha, x, incx, a)
162168
if (uplo.equals("U") && incx == 1) {
163169
if (alpha != 0.) {
164-
for (int col = 0; col < n; col += 1) {
165-
int row = 0;
166-
for (; row < DMAX.loopBound(col + 1); row += DMAX.length()) {
167-
DoubleVector vx = DoubleVector.fromArray(DMAX, x, row);
168-
DoubleVector va = DoubleVector.fromArray(DMAX, a, row + col * (col + 1) / 2);
169-
vx.lanewise(VectorOperators.MUL, alpha * x[col]).add(va)
170-
.intoArray(a, row + col * (col + 1) / 2);
170+
for (int row = 0; row < n; row += 1) {
171+
int col = 0;
172+
DoubleVector valphax = DoubleVector.broadcast(DMAX, alpha * x[row]);
173+
for (; col < DMAX.loopBound(row + 1); col += DMAX.length()) {
174+
DoubleVector vx = DoubleVector.fromArray(DMAX, x, col);
175+
DoubleVector va = DoubleVector.fromArray(DMAX, a, col + row * (row + 1) / 2);
176+
vx.fma(valphax, va).intoArray(a, col + row * (row + 1) / 2);
171177
}
172-
for (; row < col + 1; row += 1) {
173-
a[row + col * (col + 1) / 2] += alpha * x[col] * x[row];
178+
for (; col < row + 1; col += 1) {
179+
a[col + row * (row + 1) / 2] += alpha * x[row] * x[col];
174180
}
175181
}
176182
}
@@ -186,16 +192,16 @@ public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[
186192
// uplo, n, alpha, x, incx, a, lda)
187193
if (uplo.equals("U") && incx == 1) {
188194
if (alpha != 0.) {
189-
for (int col = 0; col < n; col += 1) {
190-
int row = 0;
191-
for (; row < DMAX.loopBound(col + 1); row += DMAX.length()) {
192-
DoubleVector vx = DoubleVector.fromArray(DMAX, x, row);
193-
DoubleVector va = DoubleVector.fromArray(DMAX, a, row + col * n);
194-
vx.lanewise(VectorOperators.MUL, alpha * x[col]).add(va)
195-
.intoArray(a, row + col * n);
195+
for (int row = 0; row < n; row += 1) {
196+
int col = 0;
197+
DoubleVector valphax = DoubleVector.broadcast(DMAX, alpha * x[row]);
198+
for (; col < DMAX.loopBound(row + 1); col += DMAX.length()) {
199+
DoubleVector vx = DoubleVector.fromArray(DMAX, x, col);
200+
DoubleVector va = DoubleVector.fromArray(DMAX, a, col + row * n);
201+
vx.fma(valphax, va).intoArray(a, col + row * n);
196202
}
197-
for (; row < col + 1; row += 1) {
198-
a[row + col * n] += alpha * x[col] * x[row];
203+
for (; col < row + 1; col += 1) {
204+
a[col + row * n] += alpha * x[row] * x[col];
199205
}
200206
}
201207
}

blas/src/test/java/dev/ludovic/blas/DdotTest.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,12 @@ public class DdotTest extends BLASTest {
2727

2828
@Test
2929
void testSanity() {
30-
int n = 3;
30+
int n = 9;
3131
double[] x = new double[] {
32-
1.0, 0.0, -2.0 };
32+
1.0, 0.0, -2.0, 3.0, 1.0, 0.0, -2.0, 3.0, 3.0 };
3333
double[] y = new double[] {
34-
2.0, 1.0, 0.0 };
34+
2.0, 1.0, 0.0, 0.0, 2.0, 1.0, 0.0, 0.0, 0.0 };
3535

36-
assertEquals(2.0, blas.ddot(n, x, 1, y, 1));
36+
assertEquals(4.0, blas.ddot(n, x, 1, y, 1));
3737
}
3838
}

blas/src/test/java/dev/ludovic/blas/SdotTest.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,12 @@ public class SdotTest extends BLASTest {
2727

2828
@Test
2929
void testSanity() {
30-
int n = 3;
30+
int n = 9;
3131
float[] x = new float[] {
32-
1.0f, 0.0f, -2.0f };
32+
1.0f, 0.0f, -2.0f, 1.0f, 0.0f, -2.0f, 1.0f, 0.0f, -2.0f };
3333
float[] y = new float[] {
34-
2.0f, 1.0f, 0.0f };
34+
2.0f, 1.0f, 0.0f, 2.0f, 1.0f, 0.0f, 2.0f, 1.0f, 0.0f };
3535

36-
assertEquals(2.0f, blas.sdot(n, x, 1, y, 1));
36+
assertEquals(6.0f, blas.sdot(n, x, 1, y, 1));
3737
}
3838
}

0 commit comments

Comments
 (0)