---
title: Optimizing matrix multiplication
description: Consider an algorithm for multiplying matrices using three nested loops. The complexity of such an algorithm by definition should be O(n³), but there are...
sections: [Permutations,Nested loops,Comparing algorithms]
tags: [java,arrays,multidimensional arrays,matrices,rows,columns,layers,loops]
canonical_url: /en/2021/12/10/optimizing-matrix-multiplication.html
url_translated: /ru/2021/12/09/optimizing-matrix-multiplication.html
title_translated: Оптимизация умножения матриц
date: 2021.12.10
lang: en
---

Consider an algorithm for multiplying matrices using three nested loops. By definition, the complexity of such an algorithm is O(n³), but the execution environment adds its own particularities: the speed of the algorithm depends on the order in which the loops are nested.

Let's compare the execution time of the different permutations of the nested loops. Take two matrices, {L×M} and {M×N}: three loops give six permutations, LMN, LNM, MLN, MNL, NLM and NML.
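Whatever the order of the loops, every element of the resulting matrix is the same sum of products (indexes counted from zero, as in the code), so each permutation performs exactly L·M·N multiply-add operations and only the order of memory accesses changes. For the sizes used in the comparison below, the count works out as follows:

```latex
c_{ij} = \sum_{k=0}^{M-1} a_{ik}\, b_{kj}, \qquad 0 \le i < L, \quad 0 \le j < N

L \cdot M \cdot N = 500 \cdot 700 \cdot 450 = 157\,500\,000
```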

The fastest algorithms are those that write to the resulting matrix row-wise in layers, LMN and MLN; in both of them the innermost loop runs across the columns of the second matrix N. The difference in speed from the other algorithms is substantial, and its exact size depends on the execution environment.

Further optimization: [Matrix multiplication in parallel streams]({{ '/en/2022/02/09/matrix-multiplication-parallel-streams.html' | relative_url }}).

## Row-wise algorithm

The outer loop runs across the rows of the first matrix L, the middle loop across the common side of the two matrices M, and the inner loop across the columns of the second matrix N. The resulting matrix is written row-wise, and each row is filled in layers.

```java
/**
 * @param l rows of matrix 'a'
 * @param m columns of matrix 'a'
 *          and rows of matrix 'b'
 * @param n columns of matrix 'b'
 * @param a first matrix 'l×m'
 * @param b second matrix 'm×n'
 * @return resulting matrix 'l×n'
 */
public static int[][] matrixMultiplicationLMN(int l, int m, int n, int[][] a, int[][] b) {
    // resulting matrix
    int[][] c = new int[l][n];
    // iterate over the indexes of the rows of matrix 'a'
    for (int i = 0; i < l; i++)
        // iterate over the indexes of the common side of two matrices:
        // the columns of matrix 'a' and the rows of matrix 'b'
        for (int k = 0; k < m; k++)
            // iterate over the indexes of the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // accumulate the sum of the products of the elements of the
                // i-th row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
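The phrase "each row is filled in layers" can be made explicit: in the LMN order, row `i` of the result is the sum over `k` of row `k` of matrix 'b' scaled by the single element `a[i][k]`. The sketch below is a hypothetical variant, `matrixMultiplicationLMNRows` (the class `RowLayersDemo` is also illustrative), that only hoists those rows into local variables. It is not part of the original comparison, and whether it runs any faster depends on how well the JIT compiler already optimizes the original loop.

```java
import java.util.Arrays;

class RowLayersDemo {
    /**
     * Hypothetical variant of the LMN order with the rows hoisted into
     * local variables: row 'i' of 'c' accumulates row 'k' of 'b'
     * scaled by the single element a[i][k].
     */
    static int[][] matrixMultiplicationLMNRows(int l, int m, int n, int[][] a, int[][] b) {
        int[][] c = new int[l][n];
        for (int i = 0; i < l; i++) {
            int[] ci = c[i];            // row of the result being filled
            for (int k = 0; k < m; k++) {
                int aik = a[i][k];      // scale factor of this layer
                int[] bk = b[k];        // row 'k' of matrix 'b'
                for (int j = 0; j < n; j++)
                    ci[j] += aik * bk[j];
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}};
        int[][] b = {{5, 6}, {7, 8}};
        // row 0 = 1*[5, 6] + 2*[7, 8], row 1 = 3*[5, 6] + 4*[7, 8]
        System.out.println(Arrays.deepToString(matrixMultiplicationLMNRows(2, 2, 2, a, b)));
        // prints [[19, 22], [43, 50]]
    }
}
```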

## Layer-wise algorithm

The outer loop runs across the common side of the two matrices M, the middle loop across the rows of the first matrix L, and the inner loop across the columns of the second matrix N. The resulting matrix is written layer-wise, and each layer is filled row-wise.

```java
/**
 * @param l rows of matrix 'a'
 * @param m columns of matrix 'a'
 *          and rows of matrix 'b'
 * @param n columns of matrix 'b'
 * @param a first matrix 'l×m'
 * @param b second matrix 'm×n'
 * @return resulting matrix 'l×n'
 */
public static int[][] matrixMultiplicationMLN(int l, int m, int n, int[][] a, int[][] b) {
    // resulting matrix
    int[][] c = new int[l][n];
    // iterate over the indexes of the common side of two matrices:
    // the columns of matrix 'a' and the rows of matrix 'b'
    for (int k = 0; k < m; k++)
        // iterate over the indexes of the rows of matrix 'a'
        for (int i = 0; i < l; i++)
            // iterate over the indexes of the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // accumulate the sum of the products of the elements of the
                // i-th row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
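To see what "layer-wise" means, the hypothetical helper below, `partialResults` in an illustrative class `LayerByLayerDemo`, runs the MLN loop order on two small matrices and keeps a copy of the result after each pass of the outer loop. Every pass over `k` adds one complete layer to the whole resulting matrix: the product of column `k` of matrix 'a' and row `k` of matrix 'b'.

```java
import java.util.Arrays;

class LayerByLayerDemo {
    /**
     * Hypothetical helper: runs the MLN loop order and returns a copy
     * of the resulting matrix after each layer 'k' has been added.
     */
    static int[][][] partialResults(int l, int m, int n, int[][] a, int[][] b) {
        int[][] c = new int[l][n];
        int[][][] snapshots = new int[m][][];
        // same loop order as matrixMultiplicationMLN
        for (int k = 0; k < m; k++) {
            for (int i = 0; i < l; i++)
                for (int j = 0; j < n; j++)
                    c[i][j] += a[i][k] * b[k][j];
            // copy the result after adding layer 'k'
            snapshots[k] = new int[l][];
            for (int i = 0; i < l; i++)
                snapshots[k][i] = c[i].clone();
        }
        return snapshots;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}};
        int[][] b = {{5, 6}, {7, 8}};
        int[][][] snapshots = partialResults(2, 2, 2, a, b);
        for (int k = 0; k < snapshots.length; k++)
            System.out.println("after layer " + k + ": " + Arrays.deepToString(snapshots[k]));
        // after layer 0: [[5, 6], [15, 18]]
        // after layer 1: [[19, 22], [43, 50]]
    }
}
```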

## Other algorithms

In the remaining four permutations, the loop across the columns of the second matrix N is not the innermost one: it runs outside the loop across the common side of the two matrices M and/or outside the loop across the rows of the first matrix L.

{% capture collapsed_md %}

```java
public static int[][] matrixMultiplicationLNM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int i = 0; i < l; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

public static int[][] matrixMultiplicationNLM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int i = 0; i < l; i++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

public static int[][] matrixMultiplicationMNL(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int k = 0; k < m; k++)
        for (int j = 0; j < n; j++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

public static int[][] matrixMultiplicationNML(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int k = 0; k < m; k++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```

{% endcapture %} {%- include collapsed_block.html summary="Code without comments" content=collapsed_md -%}
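All six permutations differ only in how the three indexes are assigned to the three loops. The sketch below makes that explicit with a hypothetical method, `multiply` in an illustrative class `LoopOrderDemo`, that takes the loop order as a string such as "MLN" and maps each loop level to `i`, `k` or `j`. It is meant for checking that every order computes the same product, not for benchmarking: the extra indirection through the index array would distort the timings.

```java
import java.util.Arrays;

class LoopOrderDemo {
    /**
     * Hypothetical generic multiplication: 'order' is a permutation of
     * the letters L, M, N, listed from the outer loop to the inner loop.
     * L drives index 'i', M drives index 'k', N drives index 'j'.
     */
    static int[][] multiply(String order, int l, int m, int n, int[][] a, int[][] b) {
        if (order.length() != 3)
            throw new IllegalArgumentException("not a permutation of LMN: " + order);
        int[] sizes = {l, m, n};   // loop bounds for L, M, N
        int[] slot = new int[3];   // slot[p]: which index loop level 'p' drives
        for (int p = 0; p < 3; p++)
            slot[p] = "LMN".indexOf(order.charAt(p));
        if (slot[0] < 0 || slot[1] < 0 || slot[2] < 0
                || slot[0] == slot[1] || slot[0] == slot[2] || slot[1] == slot[2])
            throw new IllegalArgumentException("not a permutation of LMN: " + order);
        int[][] c = new int[l][n];
        int[] idx = new int[3];    // idx[0] = i, idx[1] = k, idx[2] = j
        for (int x = 0; x < sizes[slot[0]]; x++)
            for (int y = 0; y < sizes[slot[1]]; y++)
                for (int z = 0; z < sizes[slot[2]]; z++) {
                    idx[slot[0]] = x;
                    idx[slot[1]] = y;
                    idx[slot[2]] = z;
                    int i = idx[0], k = idx[1], j = idx[2];
                    c[i][j] += a[i][k] * b[k][j];
                }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};       // 2×3
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};  // 3×2
        for (String order : new String[]{"LMN", "LNM", "MLN", "MNL", "NLM", "NML"})
            System.out.println(order + ": " + Arrays.deepToString(multiply(order, 2, 3, 2, a, b)));
        // every line ends with [[58, 64], [139, 154]]
    }
}
```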

## Comparing algorithms

For the test, we take two matrices A=[500×700] and B=[700×450] filled with random numbers. First, we check that the algorithms are implemented correctly: all the results must match. Then we execute each method 10 times and calculate the average execution time.

```java
// start the program and output the result
public static void main(String[] args) throws Exception {
    // incoming data
    int l = 500, m = 700, n = 450, steps = 10;
    int[][] a = randomMatrix(l, m), b = randomMatrix(m, n);
    // map of methods for comparison
    var methods = new TreeMap<String, Callable<int[][]>>(Map.of(
            "LMN", () -> matrixMultiplicationLMN(l, m, n, a, b),
            "LNM", () -> matrixMultiplicationLNM(l, m, n, a, b),
            "MLN", () -> matrixMultiplicationMLN(l, m, n, a, b),
            "MNL", () -> matrixMultiplicationMNL(l, m, n, a, b),
            "NLM", () -> matrixMultiplicationNLM(l, m, n, a, b),
            "NML", () -> matrixMultiplicationNML(l, m, n, a, b)));
    int[][] last = null;
    // iterate over the methods map and check the correctness of the
    // returned results: all results obtained must be equal to each other
    for (var method : methods.entrySet()) {
        // next method for comparison
        var next = methods.higherEntry(method.getKey());
        // if the current method is not the last, compare the results of two methods
        if (next != null) System.out.println(method.getKey() + "=" + next.getKey() + ": "
                // compare the result of executing the current method and the next one
                + Arrays.deepEquals(method.getValue().call(), next.getValue().call()));
        // the result of the last method
        else last = method.getValue().call();
    }
    int[][] test = last;
    // iterate over the methods map and measure the execution time of each method
    for (var method : methods.entrySet())
        // parameters: title, number of steps, runnable code
        benchmark(method.getKey(), steps, () -> {
            try { // execute the method, get the result
                int[][] result = method.getValue().call();
                // check the correctness of the results at each step
                if (!Arrays.deepEquals(result, test)) System.out.print("error");
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
}
```

{% capture collapsed_md %}

```java
// helper method, returns a matrix of the specified size filled with
// random numbers from 0 to row*col; products of such values can exceed
// the 'int' range, but the overflow wraps around identically in every
// method, so the results of the methods remain comparable
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}

// helper method for measuring the execution time of the passed code
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // execution time of one step
        System.out.print(" | " + time);
        avg += time;
    }
    // average execution time
    System.out.println(" || " + (avg / steps));
}
```

{% endcapture %} {%- include collapsed_block.html summary="Helper methods" content=collapsed_md -%}

The output depends on the execution environment; times are in milliseconds, and the last column is the average:

```
LMN=LNM: true
LNM=MLN: true
MLN=MNL: true
MNL=NLM: true
NLM=NML: true
LMN | 191 | 109 | 105 | 106 | 105 | 106 | 106 | 105 | 123 | 109 || 116
LNM | 417 | 418 | 419 | 416 | 416 | 417 | 418 | 417 | 416 | 417 || 417
MLN | 113 | 115 | 113 | 115 | 114 | 114 | 114 | 115 | 114 | 113 || 114
MNL | 857 | 864 | 857 | 859 | 860 | 863 | 862 | 860 | 858 | 860 || 860
NLM | 404 | 404 | 407 | 404 | 406 | 405 | 405 | 404 | 403 | 404 || 404
NML | 866 | 872 | 867 | 868 | 867 | 868 | 867 | 873 | 869 | 863 || 868
```
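The three tiers in this output line up with the innermost loop. A Java `int[][]` is an array of row arrays, so only consecutive elements of one row sit next to each other in memory. If the innermost loop runs over N (index `j`), both `b[k][j]` and `c[i][j]` move along a row; if it runs over M (index `k`), `a[i][k]` moves along a row but `b[k][j]` jumps from row to row; if it runs over L (index `i`), both `a[i][k]` and `c[i][j]` jump between rows. The hypothetical helper `rowJumps` below just counts these row-jumping accesses for each order. It is a deliberately simplified model that ignores caches, prefetching and JIT optimizations, but it ranks the permutations in the same three groups as the measurements above.

```java
class InnerLoopModel {
    // for each access: the loop that selects the row, the loop that selects the column
    private static final char[][] ACCESSES = {
            {'L', 'M'}, // a[i][k]
            {'M', 'N'}, // b[k][j]
            {'L', 'N'}  // c[i][j]
    };

    /**
     * Hypothetical simplified model: counts the accesses that move to
     * another row on every step of the innermost loop.
     * 'order' lists the loops from the outer one to the inner one.
     */
    static int rowJumps(String order) {
        char inner = order.charAt(order.length() - 1);
        int jumps = 0;
        for (char[] access : ACCESSES)
            if (access[0] == inner) jumps++; // the innermost loop changes the row
        return jumps;
    }

    public static void main(String[] args) {
        for (String order : new String[]{"LMN", "LNM", "MLN", "MNL", "NLM", "NML"})
            System.out.println(order + ": " + rowJumps(order));
        // LMN: 0, LNM: 1, MLN: 0, MNL: 2, NLM: 1, NML: 2 (one per line)
    }
}
```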

All the methods of this comparison, including those in the collapsed blocks, can be placed in a single class.

{% capture collapsed_md %}

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
```

{% endcapture %} {%- include collapsed_block.html summary="Required imports" content=collapsed_md -%}
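The helper `benchmark` measures time with `System.currentTimeMillis()`, which is wall-clock time with coarse granularity. For elapsed intervals, `System.nanoTime()` is the better fit, and a few untimed warm-up runs give the JIT compiler a chance to compile the code before measurement starts (in the output above, the first LMN run, 191 ms, is noticeably slower than the following ones, most likely for this reason). The following is a hypothetical alternative, `benchmarkNanos` in an illustrative class `BenchmarkNanosDemo`; the number of warm-up runs is an arbitrary assumption, and for rigorous measurements a dedicated harness such as JMH is the usual choice.

```java
class BenchmarkNanosDemo {
    /**
     * Hypothetical alternative to 'benchmark': runs the code 'warmup'
     * times without measuring, then 'steps' measured times, prints the
     * time of each step and returns the average time in milliseconds.
     */
    static double benchmarkNanos(String title, int warmup, int steps, Runnable runnable) {
        // untimed runs, give the JIT compiler a chance to compile the code
        for (int i = 0; i < warmup; i++)
            runnable.run();
        long total = 0;
        System.out.print(title);
        for (int i = 0; i < steps; i++) {
            long start = System.nanoTime();
            runnable.run();
            long elapsed = System.nanoTime() - start;
            // execution time of one step, in whole milliseconds
            System.out.print(" | " + elapsed / 1_000_000);
            total += elapsed;
        }
        // average execution time, in milliseconds
        double avg = total / 1e6 / steps;
        System.out.printf(" || %.1f%n", avg);
        return avg;
    }

    public static void main(String[] args) {
        // illustrative workload: a sum of squares
        benchmarkNanos("squares", 3, 5, () -> {
            long sum = 0;
            for (int i = 0; i < 10_000_000; i++) sum += (long) i * i;
            // use the result so the loop is less likely to be optimized away
            if (sum == 42) System.out.println(sum);
        });
    }
}
```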