---
title: Winograd — Strassen algorithm
description: Consider a modification of Strassen's algorithm for square matrix multiplication with a smaller number of summations between blocks than in the ordinary...
sections: [Multithreading,Block matrices,Comparing algorithms]
tags: [java,streams,arrays,multidimensional arrays,matrices,recursion,loops,nested loops]
canonical_url: /en/2022/02/11/winograd-strassen-algorithm.html
url_translated: /ru/2022/02/10/winograd-strassen-algorithm.html
title_translated: Алгоритм Винограда — Штрассена
date: 2022.02.11
lang: en
---

Consider a modification of Strassen's algorithm for square matrix multiplication with a *smaller* number of summations between blocks than in the ordinary algorithm — 15 instead of 18 — and with the same number of multiplications — 7. We will use Java Streams.

Recursive partitioning of matrices into blocks during multiplication makes sense only up to a certain limit; beyond it the partitioning loses its benefit, since Strassen's algorithm does not use the cache of the execution environment. Therefore, for small blocks we will use a parallel version of nested loops, and for large blocks we will perform recursive partitioning in parallel. We determine the boundary between the two algorithms experimentally — we tune it to the cache of the execution environment.

The benefit of Strassen's algorithm becomes more evident on sizable matrices — the difference from the algorithm with nested loops grows with matrix size and depends on the execution environment. Let's compare the running time of the two algorithms.

*Algorithm using three nested loops: [Optimizing matrix multiplication]({{ '/en/2021/12/10/optimizing-matrix-multiplication.html' | relative_url }}).*

{% include heading.html text="Algorithm description" hash="algorithm-description" %}

Matrices must be the same size. We partition each matrix into 4 equally sized blocks.
The blocks must be square, so if they are not, we first pad the matrices with zero rows and columns and only then partition them into blocks. We will later remove the redundant rows and columns from the resulting matrix.

{% include image_svg.html src="/img/block-matrices.svg" style="width:221pt; height:33pt;" alt="{\displaystyle A={\begin{pmatrix}A_{11}&A_{12}\\A_{21}&A_{22}\end{pmatrix}},\quad B={\begin{pmatrix}B_{11}&B_{12}\\B_{21}&B_{22}\end{pmatrix}}.}" %}

Summation of blocks.

{% include image_svg.html src="/img/sums1.svg" style="width:101pt; height:148pt;" alt="{\displaystyle{\begin{aligned}S_{1}&=(A_{21}+A_{22});\\S_{2}&=(S_{1}-A_{11});\\S_{3}&=(A_{11}-A_{21});\\S_{4}&=(A_{12}-S_{2});\\S_{5}&=(B_{12}-B_{11});\\S_{6}&=(B_{22}-S_{5});\\S_{7}&=(B_{22}-B_{12});\\S_{8}&=(S_{6}-B_{21}).\end{aligned}}}" %}

Multiplication of blocks.

{% include image_svg.html src="/img/products.svg" style="width:75pt; height:127pt;" alt="{\displaystyle{\begin{aligned}P_{1}&=S_{2}S_{6};\\P_{2}&=A_{11}B_{11};\\P_{3}&=A_{12}B_{21};\\P_{4}&=S_{3}S_{7};\\P_{5}&=S_{1}S_{5};\\P_{6}&=S_{4}B_{22};\\P_{7}&=A_{22}S_{8}.\end{aligned}}}" %}

Summation of blocks.

{% include image_svg.html src="/img/sums2.svg" style="width:78pt; height:31pt;" alt="{\displaystyle{\begin{aligned}T_{1}&=P_{1}+P_{2};\\T_{2}&=T_{1}+P_{4}.\end{aligned}}}" %}

Blocks of the resulting matrix.

{% include image_svg.html src="/img/sums3.svg" style="width:240pt; height:33pt;" alt="{\displaystyle{\begin{pmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{pmatrix}}={\begin{pmatrix}P_{2}+P_{3}&T_{1}+P_{5}+P_{6}\\T_{2}-P_{7}&T_{2}+P_{5}\end{pmatrix}}.}" %}

{% include heading.html text="Hybrid algorithm" hash="hybrid-algorithm" %}

We partition each matrix `A` and `B` into 4 equally sized blocks and, if necessary, pad the missing parts with zeros. We perform 15 summations and 7 multiplications over the blocks and get 4 blocks of the matrix `C`.
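Since 1×1 blocks are plain numbers, the identities above can be sanity-checked with scalar arithmetic before writing any block code. A minimal sketch (the class and method names here are mine, not part of the article's code), comparing the 7-multiplication scheme with the direct 2×2 product:

```java
public class WinogradStrassenCheck {

    // check the Winograd–Strassen identities on 1×1 blocks:
    // 15 additions/subtractions and 7 multiplications,
    // compared against the direct 2×2 matrix product
    static boolean identitiesHold(int a11, int a12, int a21, int a22,
                                  int b11, int b12, int b21, int b22) {
        // summation of blocks
        int s1 = a21 + a22, s2 = s1 - a11, s3 = a11 - a21, s4 = a12 - s2;
        int s5 = b12 - b11, s6 = b22 - s5, s7 = b22 - b12, s8 = s6 - b21;
        // multiplication of blocks
        int p1 = s2 * s6, p2 = a11 * b11, p3 = a12 * b21, p4 = s3 * s7;
        int p5 = s1 * s5, p6 = s4 * b22, p7 = a22 * s8;
        // summation of blocks
        int t1 = p1 + p2, t2 = t1 + p4;
        // blocks of the resulting matrix
        int c11 = p2 + p3, c12 = t1 + p5 + p6, c21 = t2 - p7, c22 = t2 + p5;
        // direct product for comparison
        return c11 == a11 * b11 + a12 * b21 && c12 == a11 * b12 + a12 * b22
            && c21 == a21 * b11 + a22 * b21 && c22 == a21 * b12 + a22 * b22;
    }

    public static void main(String[] args) {
        System.out.println(identitiesHold(1, 2, 3, 4, 5, 6, 7, 8)); // prints "true"
    }
}
```

The identities are linear in each block, so a scalar check like this carries over directly to the block version.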
We remove the redundant zeros, if any were added, and return the resulting matrix. We run the recursive partitioning of large blocks in parallel, and for small blocks we call the algorithm with nested loops.

```java
/**
 * @param n   matrix size
 * @param brd minimum matrix size
 * @param a   first matrix 'n×n'
 * @param b   second matrix 'n×n'
 * @return resulting matrix 'n×n'
 */
public static int[][] multiplyMatrices(int n, int brd, int[][] a, int[][] b) {
    // multiply small blocks using the algorithm with nested loops
    if (n < brd) return simpleMultiplication(n, a, b);
    // midpoint of the matrix, rounded up — blocks should
    // be square; if necessary, add zero rows and columns
    int m = n - n / 2;
    // blocks of the first matrix
    int[][] a11 = getQuadrant(m, n, a, true, true);
    int[][] a12 = getQuadrant(m, n, a, true, false);
    int[][] a21 = getQuadrant(m, n, a, false, true);
    int[][] a22 = getQuadrant(m, n, a, false, false);
    // blocks of the second matrix
    int[][] b11 = getQuadrant(m, n, b, true, true);
    int[][] b12 = getQuadrant(m, n, b, true, false);
    int[][] b21 = getQuadrant(m, n, b, false, true);
    int[][] b22 = getQuadrant(m, n, b, false, false);
    // summation of blocks
    int[][] s1 = sumMatrices(m, a21, a22, true);
    int[][] s2 = sumMatrices(m, s1, a11, false);
    int[][] s3 = sumMatrices(m, a11, a21, false);
    int[][] s4 = sumMatrices(m, a12, s2, false);
    int[][] s5 = sumMatrices(m, b12, b11, false);
    int[][] s6 = sumMatrices(m, b22, s5, false);
    int[][] s7 = sumMatrices(m, b22, b12, false);
    int[][] s8 = sumMatrices(m, s6, b21, false);
    int[][][] p = new int[7][][];
    // multiplication of blocks in parallel streams
    IntStream.range(0, 7).parallel().forEach(i -> {
        switch (i) { // recursive calls
            case 0: p[i] = multiplyMatrices(m, brd, s2, s6); break;
            case 1: p[i] = multiplyMatrices(m, brd, a11, b11); break;
            case 2: p[i] = multiplyMatrices(m, brd, a12, b21); break;
            case 3: p[i] = multiplyMatrices(m, brd, s3, s7); break;
            case 4: p[i] = multiplyMatrices(m, brd, s1, s5); break;
            case 5: p[i] = multiplyMatrices(m, brd, s4, b22); break;
            case 6: p[i] = multiplyMatrices(m, brd, a22, s8); break;
        }
    });
    // summation of blocks
    int[][] t1 = sumMatrices(m, p[0], p[1], true);
    int[][] t2 = sumMatrices(m, t1, p[3], true);
    // blocks of the resulting matrix
    int[][] c11 = sumMatrices(m, p[1], p[2], true);
    int[][] c12 = sumMatrices(m, t1, sumMatrices(m, p[4], p[5], true), true);
    int[][] c21 = sumMatrices(m, t2, p[6], false);
    int[][] c22 = sumMatrices(m, t2, p[4], true);
    // assemble a matrix from blocks,
    // remove zero rows and columns, if added
    return putQuadrants(m, n, c11, c12, c21, c22);
}
```

{% capture collapsed_md %}
```java
// helper method for matrix summation
private static int[][] sumMatrices(int n, int[][] a, int[][] b, boolean sign) {
    int[][] c = new int[n][n];
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            c[i][j] = sign ? a[i][j] + b[i][j] : a[i][j] - b[i][j];
    return c;
}
```
```java
// helper method, gets a block of a matrix
private static int[][] getQuadrant(int m, int n, int[][] x,
                                   boolean first, boolean second) {
    int[][] q = new int[m][m];
    if (first)
        for (int i = 0; i < m; i++)
            if (second) System.arraycopy(x[i], 0, q[i], 0, m);     // x11
            else System.arraycopy(x[i], m, q[i], 0, n - m);        // x12
    else
        for (int i = m; i < n; i++)
            if (second) System.arraycopy(x[i], 0, q[i - m], 0, m); // x21
            else System.arraycopy(x[i], m, q[i - m], 0, n - m);    // x22
    return q;
}
```
```java
// helper method, assembles a matrix from blocks
private static int[][] putQuadrants(int m, int n, int[][] x11, int[][] x12,
                                    int[][] x21, int[][] x22) {
    int[][] x = new int[n][n];
    for (int i = 0; i < n; i++)
        if (i < m) {
            System.arraycopy(x11[i], 0, x[i], 0, m);
            System.arraycopy(x12[i], 0, x[i], m, n - m);
        } else {
            System.arraycopy(x21[i - m], 0, x[i], 0, m);
            System.arraycopy(x22[i - m], 0, x[i], m, n - m);
        }
    return x;
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Helper methods" content=collapsed_md -%}

{% include heading.html text="Nested loops" hash="nested-loops" %}

To supplement the previous
algorithm and to compare with it, we take the *optimized* variant of nested loops, which uses the cache of the execution environment better than the alternatives — the rows of the resulting matrix are processed independently of each other, in parallel streams. We use this algorithm for small matrices directly; large matrices we partition into small blocks and then apply the same algorithm to the blocks.

```java
/**
 * @param n matrix size
 * @param a first matrix 'n×n'
 * @param b second matrix 'n×n'
 * @return resulting matrix 'n×n'
 */
public static int[][] simpleMultiplication(int n, int[][] a, int[][] b) {
    // the resulting matrix
    int[][] c = new int[n][n];
    // traverse the rows of matrix 'a' in parallel
    IntStream.range(0, n).parallel().forEach(i -> {
        // traverse the indexes of the common side of the two matrices:
        // the columns of matrix 'a' and the rows of matrix 'b'
        for (int k = 0; k < n; k++)
            // traverse the indexes of the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // the sum of the products of the elements of the i-th
                // row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    });
    return c;
}
```

{% include heading.html text="Testing" hash="testing" %}

To check, we take two square matrices `A=[1000×1000]` and `B=[1000×1000]` filled with random numbers, and a minimum block size of `[200×200]` elements. First, we verify the correctness of the two implementations — the matrix products must match. Then we execute each method 10 times and calculate the average execution time.
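A brief aside before the measurements: `simpleMultiplication` iterates `k` before `j`, so the innermost loop walks the rows of `b` and `c` sequentially instead of striding down a column of `b`. A self-contained sketch (the helper names are mine) confirming that both loop orders compute the same product, which is what makes the reordering safe:

```java
import java.util.Arrays;

public class LoopOrderDemo {

    // textbook i-j-k order: the inner loop strides down a column of 'b',
    // jumping a full row per step — poor cache locality
    static int[][] multiplyIjk(int n, int[][] a, int[][] b) {
        int[][] c = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    c[i][j] += a[i][k] * b[k][j];
        return c;
    }

    // i-k-j order used in the article: the inner loop reads a row of 'b'
    // and writes a row of 'c' sequentially — cache-friendly accesses
    static int[][] multiplyIkj(int n, int[][] a, int[][] b) {
        int[][] c = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++)
                for (int j = 0; j < n; j++)
                    c[i][j] += a[i][k] * b[k][j];
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}}, b = {{5, 6}, {7, 8}};
        // both orders compute the same product
        System.out.println(Arrays.deepEquals(
                multiplyIjk(2, a, b), multiplyIkj(2, a, b))); // prints "true"
    }
}
```

On large matrices the i-k-j variant is typically noticeably faster, though the exact gap depends on the cache of the execution environment.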
```java
// start the program and output the result
public static void main(String[] args) {
    // incoming data
    int n = 1000, brd = 200, steps = 10;
    int[][] a = randomMatrix(n, n), b = randomMatrix(n, n);
    // matrix products
    int[][] c1 = multiplyMatrices(n, brd, a, b);
    int[][] c2 = simpleMultiplication(n, a, b);
    // check the correctness of the results
    System.out.println("The results match: " + Arrays.deepEquals(c1, c2));
    // measure the execution time of the two methods
    benchmark("Hybrid algorithm", steps, () -> {
        int[][] c = multiplyMatrices(n, brd, a, b);
        if (!Arrays.deepEquals(c, c1)) System.out.print("error");
    });
    benchmark("Nested loops    ", steps, () -> {
        int[][] c = simpleMultiplication(n, a, b);
        if (!Arrays.deepEquals(c, c2)) System.out.print("error");
    });
}
```

{% capture collapsed_md %}
```java
// helper method, returns a matrix of the specified size
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}
```
```java
// helper method for measuring the execution time of the passed code
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // execution time of one step
        System.out.print(" | " + time);
        avg += time;
    }
    // average execution time
    System.out.println(" || " + (avg / steps));
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Helper methods" content=collapsed_md -%}

Output depends on the execution environment, time in milliseconds:

```
The results match: true
Hybrid algorithm | 196 | 177 | 156 | 205 | 154 | 165 | 133 | 118 | 132 | 134 || 157
Nested loops     | 165 | 164 | 168 | 167 | 168 | 168 | 170 | 179 | 173 | 168 || 169
```

{% include heading.html text="Comparing algorithms" hash="comparing-algorithms" %}

On an eight-core Linux x64 computer, we execute the above test 100 times instead of 10, keeping the minimum block size at `[brd=200]` elements and changing only `n` — the size of both matrices `A=[n×n]` and `B=[n×n]`. The summary table of results, time in milliseconds:

```
n                | 900 | 1000 | 1100 | 1200 | 1300 | 1400 | 1500 | 1600 | 1700 |
-----------------|-----|------|------|------|------|------|------|------|------|
Hybrid algorithm |  96 |  125 |  169 |  204 |  260 |  313 |  384 |  482 |  581 |
Nested loops     | 119 |  162 |  235 |  281 |  361 |  497 |  651 |  793 |  971 |
```

Results: the benefit of the Strassen algorithm becomes more evident on large matrices, when the size of the matrix itself is several times larger than the size of the minimal block, and it depends on the execution environment. All the methods described above, including the collapsed blocks, can be placed in one class.

{% capture collapsed_md %}
```java
import java.util.Arrays;
import java.util.stream.IntStream;
```
{% endcapture %}
{%- include collapsed_block.html summary="Required imports" content=collapsed_md -%}
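As a closing illustration of why the block size matters, the sizes visited by the recursion follow the halving rule `m = n - n / 2` (rounding up) until a block falls below the threshold `brd`. A small sketch (the class and method names are mine) that prints the chain of sizes for the matrices in the table above:

```java
import java.util.ArrayList;
import java.util.List;

public class RecursionDepth {

    // reproduce the size chain of the hybrid algorithm: each level
    // halves the size, rounding up (m = n - n / 2), until the block
    // is smaller than the threshold 'brd' and nested loops take over
    static List<Integer> sizeChain(int n, int brd) {
        List<Integer> sizes = new ArrayList<>();
        while (n >= brd) {
            sizes.add(n);
            n = n - n / 2; // midpoint, rounded up
        }
        sizes.add(n); // final block handled by simpleMultiplication
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(sizeChain(1700, 200)); // prints "[1700, 850, 425, 213, 107]"
        System.out.println(sizeChain(1000, 200)); // prints "[1000, 500, 250, 125]"
    }
}
```

For `n=1700` and `brd=200` the chain is 1700 → 850 → 425 → 213 → 107, i.e. four levels of Strassen recursion before the nested-loop algorithm takes over, while for `n=900` there are only three — consistent with the gap in the table growing with `n`.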