--- title: Оптимизация умножения матриц description: Рассмотрим алгоритм перемножения матриц с использованием трёх вложенных циклов. Сложность такого алгоритма по определению должна составлять O(n³), но есть... sections: [Перестановки,Вложенные циклы,Сравнение алгоритмов] tags: [java,массивы,многомерные массивы,матрицы,строки,колонки,слои,циклы] canonical_url: /ru/2021/12/09/optimizing-matrix-multiplication.html url_translated: /en/2021/12/10/optimizing-matrix-multiplication.html title_translated: Optimizing matrix multiplication date: 2021.12.09 --- Рассмотрим алгоритм перемножения матриц с использованием трёх вложенных циклов. Сложность такого алгоритма по определению должна составлять `O(n³)`, но есть особенности, связанные со средой выполнения — скорость работы алгоритма зависит от последовательности, в которой выполняются циклы. Сравним различные варианты перестановок вложенных циклов и время выполнения алгоритмов. Возьмём две матрицы: {`L×M`} и {`M×N`} → три цикла → шесть перестановок: `LMN`, `LNM`, `MLN`, `MNL`, `NLM`, `NML`. Быстрее других отрабатывают те алгоритмы, которые пишут данные в результирующую матрицу *построчно слоями*: `LMN` и `MLN`, — разница в процентах к другим алгоритмам значительная и зависит от среды выполнения. *Дальнейшая оптимизация: [Умножение матриц в параллельных потоках]({{ '/ru/2022/02/08/matrix-multiplication-parallel-streams.html' | relative_url }}).* ## Построчный алгоритм {#row-wise-algorithm} Внешний цикл обходит строки первой матрицы `L`, далее идёт цикл по *общей стороне* двух матриц `M` и за ним цикл по колонкам второй матрицы `N`. Запись в результирующую матрицу происходит построчно, а каждая строка заполняется слоями. ```java /** * @param l строки матрицы 'a' * @param m колонки матрицы 'a' * и строки матрицы 'b' * @param n колонки матрицы 'b' * @param a первая матрица 'l×m' * @param b вторая матрица 'm×n' * @return результирующая матрица 'l×n' */ public static int[][] matrixMultiplicationLMN(int l, int m, int n, int[][] a, int[][] b) { // результирующая матрица int[][] c = new int[l][n]; // обходим индексы строк матрицы 'a' for (int i = 0; i < l; i++) // обходим индексы общей стороны двух матриц: // колонок матрицы 'a' и строк матрицы 'b' for (int k = 0; k < m; k++) // обходим индексы колонок матрицы 'b' for (int j = 0; j < n; j++) // сумма произведений элементов i-ой строки // матрицы 'a' и j-ой колонки матрицы 'b' c[i][j] += a[i][k] * b[k][j]; return c; } ``` ## Послойный алгоритм {#layer-wise-algorithm} Внешний цикл обходит *общую сторону* двух матриц `M`, далее идёт цикл по строкам первой матрицы `L` и за ним цикл по колонкам второй матрицы `N`. Запись в результирующую матрицу происходит слоями, а каждый слой заполняется построчно. ```java /** * @param l строки матрицы 'a' * @param m колонки матрицы 'a' * и строки матрицы 'b' * @param n колонки матрицы 'b' * @param a первая матрица 'l×m' * @param b вторая матрица 'm×n' * @return результирующая матрица 'l×n' */ public static int[][] matrixMultiplicationMLN(int l, int m, int n, int[][] a, int[][] b) { // результирующая матрица int[][] c = new int[l][n]; // обходим индексы общей стороны двух матриц: // колонок матрицы 'a' и строк матрицы 'b' for (int k = 0; k < m; k++) // обходим индексы строк матрицы 'a' for (int i = 0; i < l; i++) // обходим индексы колонок матрицы 'b' for (int j = 0; j < n; j++) // сумма произведений элементов i-ой строки // матрицы 'a' и j-ой колонки матрицы 'b' c[i][j] += a[i][k] * b[k][j]; return c; } ``` ### Прочие алгоритмы {#other-algorithms} Обход колонок второй матрицы `N` происходит перед обходом *общей стороны* двух матриц `M` и/или перед обходом строк первой матрицы `L`. {% capture collapsed_md %} ```java public static int[][] matrixMultiplicationLNM(int l, int m, int n, int[][] a, int[][] b) { int[][] c = new int[l][n]; for (int i = 0; i < l; i++) for (int j = 0; j < n; j++) for (int k = 0; k < m; k++) c[i][j] += a[i][k] * b[k][j]; return c; } ``` ```java public static int[][] matrixMultiplicationNLM(int l, int m, int n, int[][] a, int[][] b) { int[][] c = new int[l][n]; for (int j = 0; j < n; j++) for (int i = 0; i < l; i++) for (int k = 0; k < m; k++) c[i][j] += a[i][k] * b[k][j]; return c; } ``` ```java public static int[][] matrixMultiplicationMNL(int l, int m, int n, int[][] a, int[][] b) { int[][] c = new int[l][n]; for (int k = 0; k < m; k++) for (int j = 0; j < n; j++) for (int i = 0; i < l; i++) c[i][j] += a[i][k] * b[k][j]; return c; } ``` ```java public static int[][] matrixMultiplicationNML(int l, int m, int n, int[][] a, int[][] b) { int[][] c = new int[l][n]; for (int j = 0; j < n; j++) for (int k = 0; k < m; k++) for (int i = 0; i < l; i++) c[i][j] += a[i][k] * b[k][j]; return c; } ``` {% endcapture %} {%- include collapsed_block.html summary="Код без комментариев" content=collapsed_md -%} ## Сравнение алгоритмов {#comparing-algorithms} Для проверки возьмём две матрицы `A=[500×700]` и `B=[700×450]`, заполненные случайными числами. Сначала сравниваем между собой корректность реализации алгоритмов — все полученные результаты должны совпадать. Затем выполняем каждый метод по 10 раз и подсчитываем среднее время выполнения. ```java // запускаем программу и выводим результат public static void main(String[] args) throws Exception { // входящие данные int l = 500, m = 700, n = 450, steps = 10; int[][] a = randomMatrix(l, m), b = randomMatrix(m, n); // карта методов для сравнения var methods = new TreeMap>(Map.of( "LMN", () -> matrixMultiplicationLMN(l, m, n, a, b), "LNM", () -> matrixMultiplicationLNM(l, m, n, a, b), "MLN", () -> matrixMultiplicationMLN(l, m, n, a, b), "MNL", () -> matrixMultiplicationMNL(l, m, n, a, b), "NLM", () -> matrixMultiplicationNLM(l, m, n, a, b), "NML", () -> matrixMultiplicationNML(l, m, n, a, b))); int[][] last = null; // обходим карту методов, проверяем корректность результатов, // все полученные результаты должны быть равны друг другу for (var method : methods.entrySet()) { // следующий метод для сравнения var next = methods.higherEntry(method.getKey()); // если текущий метод не последний — сравниваем результаты двух методов if (next != null) System.out.println(method.getKey() + "=" + next.getKey() + ": " // сравниваем результат выполнения текущего метода и следующего за ним + Arrays.deepEquals(method.getValue().call(), next.getValue().call())); // результат выполнения последнего метода else last = method.getValue().call(); } int[][] test = last; // обходим карту методов, замеряем время работы каждого метода for (var method : methods.entrySet()) // параметры: заголовок, количество шагов, исполняемый код benchmark(method.getKey(), steps, () -> { try { // выполняем метод, получаем результат int[][] result = method.getValue().call(); // проверяем корректность результатов на каждом шаге if (!Arrays.deepEquals(result, test)) System.out.print("error"); } catch (Exception e) { e.printStackTrace(); } }); } ``` {% capture collapsed_md %} ```java // вспомогательный метод, возвращает матрицу указанного размера private static int[][] randomMatrix(int row, int col) { int[][] matrix = new int[row][col]; for (int i = 0; i < row; i++) for (int j = 0; j < col; j++) matrix[i][j] = (int) (Math.random() * row * col); return matrix; } ``` ```java // вспомогательный метод для замера времени работы переданного кода private static void benchmark(String title, int steps, Runnable runnable) { long time, avg = 0; System.out.print(title); for (int i = 0; i < steps; i++) { time = System.currentTimeMillis(); runnable.run(); time = System.currentTimeMillis() - time; // время выполнения одного шага System.out.print(" | " + time); avg += time; } // среднее время выполнения System.out.println(" || " + (avg / steps)); } ``` {% endcapture %} {%- include collapsed_block.html summary="Вспомогательные методы" content=collapsed_md -%} Вывод зависит от среды выполнения, время в миллисекундах: ``` LMN=LNM: true LNM=MLN: true MLN=MNL: true MNL=NLM: true NLM=NML: true LMN | 191 | 109 | 105 | 106 | 105 | 106 | 106 | 105 | 123 | 109 || 116 LNM | 417 | 418 | 419 | 416 | 416 | 417 | 418 | 417 | 416 | 417 || 417 MLN | 113 | 115 | 113 | 115 | 114 | 114 | 114 | 115 | 114 | 113 || 114 MNL | 857 | 864 | 857 | 859 | 860 | 863 | 862 | 860 | 858 | 860 || 860 NLM | 404 | 404 | 407 | 404 | 406 | 405 | 405 | 404 | 403 | 404 || 404 NML | 866 | 872 | 867 | 868 | 867 | 868 | 867 | 873 | 869 | 863 || 868 ``` Все описанные выше методы, включая свёрнутые блоки, можно поместить в одном классе. {% capture collapsed_md %} ```java import java.util.Arrays; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.Callable; ``` {% endcapture %} {%- include collapsed_block.html summary="Необходимые импорты" content=collapsed_md -%}