---
title: Optimizing matrix multiplication
description: Consider an algorithm for multiplying matrices using three nested loops. The complexity of such an algorithm by definition should be O(n³), but there are...
sections: [Permutations,Nested loops,Comparing algorithms]
tags: [java,arrays,multidimensional arrays,matrices,rows,columns,layers,loops]
canonical_url: /en/2021/12/10/optimizing-matrix-multiplication.html
url_translated: /ru/2021/12/09/optimizing-matrix-multiplication.html
title_translated: Оптимизация умножения матриц
date: 2021.12.10
lang: en
---

Consider an algorithm for multiplying matrices using three nested loops. The complexity of such an algorithm is `O(n³)` by definition, but there are peculiarities related to the execution environment: the speed of the algorithm depends on the order in which the loops are nested. Let's compare different permutations of the nested loops and the execution time of the resulting algorithms.

Let's take two matrices, {`L×M`} and {`M×N`} → three loops → six permutations: `LMN`, `LNM`, `MLN`, `MNL`, `NLM`, `NML`.

The algorithms that work faster than the others are those that write data to the resulting matrix *row-wise in layers*: `LMN` and `MLN`. The percentage difference from the other algorithms is substantial and depends on the execution environment.

*Further optimization: [Matrix multiplication in parallel streams]({{ '/en/2022/02/09/matrix-multiplication-parallel-streams.html' | relative_url }}).*

## Row-wise algorithm {#row-wise-algorithm}

The outer loop iterates over the rows of the first matrix `L`, the middle loop iterates over the *common side* of the two matrices `M`, and the inner loop iterates over the columns of the second matrix `N`. Writing to the resulting matrix occurs row-wise, and each row is filled in layers.

```java
/**
 * @param l rows of matrix 'a'
 * @param m columns of matrix 'a'
 *          and rows of matrix 'b'
 * @param n columns of matrix 'b'
 * @param a first matrix 'l×m'
 * @param b second matrix 'm×n'
 * @return resulting matrix 'l×n'
 */
public static int[][] matrixMultiplicationLMN(int l, int m, int n,
                                              int[][] a, int[][] b) {
    // resulting matrix
    int[][] c = new int[l][n];
    // iterate over the indices of the rows of matrix 'a'
    for (int i = 0; i < l; i++)
        // iterate over the indices of the common side of the two matrices:
        // the columns of matrix 'a' and the rows of matrix 'b'
        for (int k = 0; k < m; k++)
            // iterate over the indices of the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // the sum of the products of the elements of the i-th
                // row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```

## Layer-wise algorithm {#layer-wise-algorithm}

The outer loop iterates over the *common side* of the two matrices `M`, the middle loop iterates over the rows of the first matrix `L`, and the inner loop iterates over the columns of the second matrix `N`. Writing to the resulting matrix occurs layer-wise, and each layer is filled row-wise.
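
To make the notion of a *layer* concrete: for a fixed `k`, the two inner loops of the layer-wise algorithm add the outer product of column `k` of matrix `a` and row `k` of matrix `b` to the whole `l×n` result, so the product is the sum of `m` such layers. The sketch below computes each layer separately and then sums them. The class and method names (`LayerSketch`, `layer`, `sumOfLayers`) are hypothetical, and the extra allocation per layer makes it slower than updating `c` in place, so it is meant only as an illustration of the idea, not as an alternative implementation.

```java
import java.util.Arrays;

public class LayerSketch {
    // hypothetical helper: layer 'k' is the outer product of column 'k'
    // of matrix 'a' and row 'k' of matrix 'b', an 'l×n' matrix
    public static int[][] layer(int k, int[][] a, int[][] b) {
        int l = a.length, n = b[0].length;
        int[][] result = new int[l][n];
        for (int i = 0; i < l; i++)
            for (int j = 0; j < n; j++)
                result[i][j] = a[i][k] * b[k][j];
        return result;
    }

    // hypothetical helper: adds up all 'm' layers, which gives the product 'a×b'
    public static int[][] sumOfLayers(int[][] a, int[][] b) {
        int l = a.length, m = b.length, n = b[0].length;
        int[][] c = new int[l][n];
        for (int k = 0; k < m; k++) {
            int[][] current = layer(k, a, b);
            for (int i = 0; i < l; i++)
                for (int j = 0; j < n; j++)
                    c[i][j] += current[i][j];
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};      // 2×3
        int[][] b = {{7, 8}, {9, 10}, {11, 12}}; // 3×2
        // first layer: column 0 of 'a' times row 0 of 'b'
        System.out.println(Arrays.deepToString(layer(0, a, b)));    // [[7, 8], [28, 32]]
        // sum of the three layers
        System.out.println(Arrays.deepToString(sumOfLayers(a, b))); // [[58, 64], [139, 154]]
    }
}
```

The layer-wise method keeps a single result matrix and adds each layer into it directly, which avoids these temporary arrays.
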

```java
/**
 * @param l rows of matrix 'a'
 * @param m columns of matrix 'a'
 *          and rows of matrix 'b'
 * @param n columns of matrix 'b'
 * @param a first matrix 'l×m'
 * @param b second matrix 'm×n'
 * @return resulting matrix 'l×n'
 */
public static int[][] matrixMultiplicationMLN(int l, int m, int n,
                                              int[][] a, int[][] b) {
    // resulting matrix
    int[][] c = new int[l][n];
    // iterate over the indices of the common side of the two matrices:
    // the columns of matrix 'a' and the rows of matrix 'b'
    for (int k = 0; k < m; k++)
        // iterate over the indices of the rows of matrix 'a'
        for (int i = 0; i < l; i++)
            // iterate over the indices of the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // the sum of the products of the elements of the i-th
                // row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```

### Other algorithms {#other-algorithms}

In these algorithms, the loop over the columns of the second matrix `N` is placed outside the loop over the *common side* of the two matrices `M` and/or outside the loop over the rows of the first matrix `L`. In other words, `N` is not the innermost loop.
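
A likely explanation for the gap, stated as an assumption about typical JVMs and CPU caches rather than something this benchmark measures directly: a Java `int[][]` is an array of separate row arrays, so consecutive memory is touched only when walking along a row. In `LMN` and `MLN` the innermost loop runs over `j`, so it reads `b[k]` and writes `c[i]` element by element while `a[i][k]` stays fixed. In `LNM` and `NLM` the innermost loop runs over `k` and reads `b[k][j]` from a different row of `b` on every iteration; in `MNL` and `NML` it runs over `i` and jumps between rows of both `a` and `c`. The sketch below makes the row-wise access of `LMN` explicit by hoisting the row references out of the innermost loop. The class name `RowHoisting` is hypothetical, and whether the JIT compiler already does this hoisting by itself is not established here.

```java
import java.util.Arrays;

public class RowHoisting {
    // LMN order with the row references hoisted out of the innermost loop
    public static int[][] multiplyLMN(int[][] a, int[][] b) {
        int l = a.length, m = b.length, n = b[0].length;
        int[][] c = new int[l][n];
        for (int i = 0; i < l; i++) {
            int[] ai = a[i]; // row 'i' of matrix 'a'
            int[] ci = c[i]; // row 'i' of the resulting matrix
            for (int k = 0; k < m; k++) {
                int aik = ai[k]; // fixed for the whole innermost loop
                int[] bk = b[k]; // row 'k' of matrix 'b'
                // the innermost loop walks along two rows: 'bk' and 'ci'
                for (int j = 0; j < n; j++)
                    ci[j] += aik * bk[j];
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}};
        int[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(multiplyLMN(a, b))); // [[19, 22], [43, 50]]
    }
}
```
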

{% capture collapsed_md %}
```java
public static int[][] matrixMultiplicationLNM(int l, int m, int n,
                                              int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int i = 0; i < l; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationNLM(int l, int m, int n,
                                              int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int i = 0; i < l; i++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationMNL(int l, int m, int n,
                                              int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int k = 0; k < m; k++)
        for (int j = 0; j < n; j++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationNML(int l, int m, int n,
                                              int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int k = 0; k < m; k++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Code without comments" content=collapsed_md -%}

## Comparing algorithms {#comparing-algorithms}

To check, we take two matrices `A=[500×700]` and `B=[700×450]` filled with random numbers. First, we check that the algorithms are implemented correctly: all the results obtained must match. Then we execute each method 10 times and calculate the average execution time.
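
Matching results show that the six methods agree with one another, not that any of them computes the product correctly. A minimal extra check, sketched below with a hypothetical class `SmallCaseCheck`, multiplies a `2×3` by a `3×4` matrix whose product is easy to verify by hand. Because `l`, `m` and `n` are all different here, a loop bound taken from the wrong dimension would either throw `ArrayIndexOutOfBoundsException` or produce a wrong result. Only the `LMN` and `NML` orders are repeated, to keep the example short.

```java
import java.util.Arrays;

public class SmallCaseCheck {
    // LMN order, the same loops as 'matrixMultiplicationLMN'
    public static int[][] multiplyLMN(int l, int m, int n, int[][] a, int[][] b) {
        int[][] c = new int[l][n];
        for (int i = 0; i < l; i++)
            for (int k = 0; k < m; k++)
                for (int j = 0; j < n; j++)
                    c[i][j] += a[i][k] * b[k][j];
        return c;
    }

    // NML order, the same loops as 'matrixMultiplicationNML'
    public static int[][] multiplyNML(int l, int m, int n, int[][] a, int[][] b) {
        int[][] c = new int[l][n];
        for (int j = 0; j < n; j++)
            for (int k = 0; k < m; k++)
                for (int i = 0; i < l; i++)
                    c[i][j] += a[i][k] * b[k][j];
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 0, 2}, {0, 1, 3}};                     // 2×3
        int[][] b = {{1, 2, 0, 1}, {0, 1, 1, 0}, {2, 0, 1, 1}}; // 3×4
        // computed by hand: row 'i' of the result is the sum of the rows
        // of 'b' weighted by the elements of row 'i' of 'a'
        int[][] expected = {{5, 2, 2, 3}, {6, 1, 4, 3}};
        System.out.println(Arrays.deepEquals(multiplyLMN(2, 3, 4, a, b), expected)); // true
        System.out.println(Arrays.deepEquals(multiplyNML(2, 3, 4, a, b), expected)); // true
    }
}
```
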

```java
// start the program and output the result
public static void main(String[] args) throws Exception {
    // incoming data
    int l = 500, m = 700, n = 450, steps = 10;
    int[][] a = randomMatrix(l, m), b = randomMatrix(m, n);
    // map of methods for comparison
    var methods = new TreeMap<String, Callable<int[][]>>(Map.of(
            "LMN", () -> matrixMultiplicationLMN(l, m, n, a, b),
            "LNM", () -> matrixMultiplicationLNM(l, m, n, a, b),
            "MLN", () -> matrixMultiplicationMLN(l, m, n, a, b),
            "MNL", () -> matrixMultiplicationMNL(l, m, n, a, b),
            "NLM", () -> matrixMultiplicationNLM(l, m, n, a, b),
            "NML", () -> matrixMultiplicationNML(l, m, n, a, b)));
    int[][] last = null;
    // iterate over the methods map and check the correctness of the returned
    // results: all results obtained must be equal to each other
    for (var method : methods.entrySet()) {
        // next method for comparison
        var next = methods.higherEntry(method.getKey());
        // if the current method is not the last, compare the results of the two methods
        if (next != null)
            System.out.println(method.getKey() + "=" + next.getKey() + ": "
                    // compare the results of the current method and the next one
                    + Arrays.deepEquals(method.getValue().call(), next.getValue().call()));
        // the result of the last method
        else
            last = method.getValue().call();
    }
    int[][] test = last;
    // iterate over the methods map and measure the execution time of each method
    for (var method : methods.entrySet())
        // parameters: title, number of steps, runnable code
        benchmark(method.getKey(), steps, () -> {
            try {
                // execute the method, get the result
                int[][] result = method.getValue().call();
                // check the correctness of the results at each step
                if (!Arrays.deepEquals(result, test))
                    System.out.print("error");
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
}
```

{% capture collapsed_md %}
```java
// helper method, returns a matrix of the specified size
// filled with random numbers
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}
```
```java
// helper method for measuring the execution time of the passed code
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // execution time of one step
        System.out.print(" | " + time);
        avg += time;
    }
    // average execution time
    System.out.println(" || " + (avg / steps));
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Helper methods" content=collapsed_md -%}

The output depends on the execution environment; times are in milliseconds:

```
LMN=LNM: true
LNM=MLN: true
MLN=MNL: true
MNL=NLM: true
NLM=NML: true
LMN | 191 | 109 | 105 | 106 | 105 | 106 | 106 | 105 | 123 | 109 || 116
LNM | 417 | 418 | 419 | 416 | 416 | 417 | 418 | 417 | 416 | 417 || 417
MLN | 113 | 115 | 113 | 115 | 114 | 114 | 114 | 115 | 114 | 113 || 114
MNL | 857 | 864 | 857 | 859 | 860 | 863 | 862 | 860 | 858 | 860 || 860
NLM | 404 | 404 | 407 | 404 | 406 | 405 | 405 | 404 | 403 | 404 || 404
NML | 866 | 872 | 867 | 868 | 867 | 868 | 867 | 873 | 869 | 863 || 868
```

All the methods described above, including those in the collapsed blocks, can be placed in one class.

{% capture collapsed_md %}
```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
```
{% endcapture %}
{%- include collapsed_block.html summary="Required imports" content=collapsed_md -%}
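
In the output above, the first `LMN` run (191 ms) is noticeably slower than the following ones (105 to 123 ms), which is consistent with JIT warm-up, although the article does not measure this directly. One way to reduce that effect is to run a few untimed iterations before timing, and to use `System.nanoTime`, which is intended for measuring elapsed time, instead of `System.currentTimeMillis`. The sketch below is a hypothetical variant of the `benchmark` helper, not the article's method; the names `WarmupBenchmark` and `averageMillis` are illustrative. For rigorous measurements, a dedicated harness such as JMH is the usual choice.

```java
import java.util.function.Supplier;

public class WarmupBenchmark {
    // keeps the latest result reachable, so the JIT compiler is less
    // likely to discard the measured work as unused
    private static volatile Object sink;

    // hypothetical helper: runs 'warmup' untimed iterations, then returns
    // the average time of 'steps' timed iterations in milliseconds
    public static double averageMillis(int warmup, int steps, Supplier<?> task) {
        for (int i = 0; i < warmup; i++)
            sink = task.get();
        long total = 0;
        for (int i = 0; i < steps; i++) {
            long start = System.nanoTime();
            sink = task.get();
            total += System.nanoTime() - start;
        }
        return total / (double) steps / 1_000_000.0;
    }

    public static void main(String[] args) {
        int size = 200;
        int[][] a = new int[size][size], b = new int[size][size];
        for (int i = 0; i < size; i++)
            for (int j = 0; j < size; j++) {
                a[i][j] = i + j;
                b[i][j] = i - j;
            }
        // usage: time the LMN loop order after 5 warm-up runs
        double ms = averageMillis(5, 10, () -> {
            int[][] c = new int[size][size];
            for (int i = 0; i < size; i++)
                for (int k = 0; k < size; k++)
                    for (int j = 0; j < size; j++)
                        c[i][j] += a[i][k] * b[k][j];
            return c;
        });
        System.out.printf("LMN, %dx%d: %.2f ms%n", size, size, ms);
    }
}
```

The printed time depends on the machine and JVM, so no expected value is given.
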