Strassen 矩阵乘法
- 朴素方法
- Strassen 矩阵乘法算法
- 示例
Strassen 矩阵乘法是一种分治法,用于解决矩阵乘法问题。传统的矩阵乘法方法将每一行与每一列相乘以得到乘积矩阵。这种方法的时间复杂度为 O(n3),因为需要两个循环进行乘法运算。Strassen 方法的引入将时间复杂度从 O(n3) 降低到 O(nlog 7)。
朴素方法
首先,我们讨论朴素方法及其复杂度。这里,我们计算 Z = X × Y。使用朴素方法,两个矩阵(X 和 Y)可以相乘,前提是它们的阶分别为 p × q 和 q × r,结果矩阵的阶为 p × r。以下伪代码描述了朴素乘法——
算法:Matrix-Multiplication (X, Y, Z)
for i = 1 to p do
for j = 1 to r do
Z[i,j] := 0
for k = 1 to q do
Z[i,j] := Z[i,j] + X[i,k] × Y[k,j]
复杂度
这里,我们假设整数运算耗时 O(1)。该算法中有三个 for 循环,其中一个嵌套在另一个内部。因此,该算法执行时间为 O(n3)。
Strassen 矩阵乘法算法
在这种情况下,使用 Strassen 矩阵乘法算法,可以略微改善时间消耗。
Strassen 矩阵乘法仅适用于方阵,且 n 是 2 的幂。两个矩阵的阶均为 n × n。
将 X、Y 和 Z 分为四个 (n/2)×(n/2) 子矩阵,如下所示——
$Z = \begin{bmatrix}I & J \\K & L \end{bmatrix}$ $X = \begin{bmatrix}A & B \\C & D \end{bmatrix}$ 和 $Y = \begin{bmatrix}E & F \\G & H \end{bmatrix}$
使用 Strassen 算法计算以下内容——
$$M_{1} \: \colon= (A+C) \times (E+F)$$
$$M_{2} \: \colon= (B+D) \times (G+H)$$
$$M_{3} \: \colon= (A-D) \times (E+H)$$
$$M_{4} \: \colon= A \times (F-H)$$
$$M_{5} \: \colon= (C+D) \times (E)$$
$$M_{6} \: \colon= (A+B) \times (H)$$
$$M_{7} \: \colon= D \times (G-E)$$
然后,
$$I \: \colon= M_{2} + M_{3} - M_{6} - M_{7}$$
$$J \: \colon= M_{4} + M_{6}$$
$$K \: \colon= M_{5} + M_{7}$$
$$L \: \colon= M_{1} - M_{3} - M_{4} - M_{5}$$
分析
$$T(n)=\begin{cases}c & if\:n= 1\\7\:x\:T(\frac{n}{2})+d\:x\:n^2 & otherwise\end{cases} \:where\: c\: and \:d\:are\: constants$$
根据这个递推关系,我们得到 $T(n) = O(n^{log7})$
因此,Strassen 矩阵乘法算法的复杂度为 $O(n^{log7})$。
示例
让我们来看看 Strassen's Matrix Multiplication 在各种编程语言中的实现:C、C++、Java、Python。
#include<stdio.h>
int main(){
int z[2][2];
int i, j;
int m1, m2, m3, m4 , m5, m6, m7;
int x[2][2] = {
{12, 34},
{22, 10}
};
int y[2][2] = {
{3, 4},
{2, 1}
};
printf("第一个矩阵是: ");
for(i = 0; i < 2; i++) {
printf("\n");
for(j = 0; j < 2; j++)
printf("%d\t", x[i][j]);
}
printf("\n第二个矩阵是: ");
for(i = 0; i < 2; i++) {
printf("\n");
for(j = 0; j < 2; j++)
printf("%d\t", y[i][j]);
}
m1= (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
m2= (x[1][0] + x[1][1]) * y[0][0];
m3= x[0][0] * (y[0][1] - y[1][1]);
m4= x[1][1] * (y[1][0] - y[0][0]);
m5= (x[0][0] + x[0][1]) * y[1][1];
m6= (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]);
m7= (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]);
z[0][0] = m1 + m4- m5 + m7;
z[0][1] = m3 + m5;
z[1][0] = m2 + m4;
z[1][1] = m1 - m2 + m3 + m6;
printf("\n使用 Strassen 算法得到的乘积: ");
for(i = 0; i < 2 ; i++) {
printf("\n");
for(j = 0; j < 2; j++)
printf("%d\t", z[i][j]);
}
return 0;
}
输出
The first matrix is: 12 34 22 10 The second matrix is: 3 4 2 1 Product achieved using Strassen's algorithm: 104 82 86 98
#include<iostream>
using namespace std;
int main() {
int z[2][2];
int i, j;
int m1, m2, m3, m4 , m5, m6, m7;
int x[2][2] = {
{12, 34},
{22, 10}
};
int y[2][2] = {
{3, 4},
{2, 1}
};
cout<<"第一个矩阵是: ";
for(i = 0; i < 2; i++) {
cout<<endl;
for(j = 0; j < 2; j++)
cout<<x[i][j]<<" ";
}
cout<<"\n第二个矩阵是: ";
for(i = 0;i < 2; i++){
cout<<endl;
for(j = 0;j < 2; j++)
cout<<y[i][j]<<" ";
}
m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
m2 = (x[1][0] + x[1][1]) * y[0][0];
m3 = x[0][0] * (y[0][1] - y[1][1]);
m4 = x[1][1] * (y[1][0] - y[0][0]);
m5 = (x[0][0] + x[0][1]) * y[1][1];
m6 = (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]);
m7 = (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]);
z[0][0] = m1 + m4- m5 + m7;
z[0][1] = m3 + m5;
z[1][0] = m2 + m4;
z[1][1] = m1 - m2 + m3 + m6;
cout<<"\n使用 Strassen 算法得到的乘积: ";
for(i = 0; i < 2 ; i++) {
cout<<endl;
for(j = 0; j < 2; j++)
cout<<z[i][j]<<" ";
}
return 0;
}
输出
The first matrix is: 12 34 22 10 The second matrix is: 3 4 2 1 Product achieved using Strassen's algorithm: 104 82 86 98
public class Strassens {
public static void main(String[] args) {
int[][] x = {{12, 34}, {22, 10}};
int[][] y = {{3, 4}, {2, 1}};
int z[][] = new int[2][2];
int m1, m2, m3, m4 , m5, m6, m7;
System.out.print("第一个矩阵是: ");
for(int i = 0; i<2; i++) {
System.out.println();//new line
for(int j = 0; j<2; j++) {
System.out.print(x[i][j] + "\t");
}
}
System.out.print("\n第二个矩阵是: ");
for(int i = 0; i<2; i++) {
System.out.println();//new line
for(int j = 0; j<2; j++) {
System.out.print(y[i][j] + "\t");
}
}
m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
m2 = (x[1][0] + x[1][1]) * y[0][0];
m3 = x[0][0] * (y[0][1] - y[1][1]);
m4 = x[1][1] * (y[1][0] - y[0][0]);
m5 = (x[0][0] + x[0][1]) * y[1][1];
m6 = (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]);
m7 = (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]);
z[0][0] = m1 + m4- m5 + m7;
z[0][1] = m3 + m5;
z[1][0] = m2 + m4;
z[1][1] = m1 - m2 + m3 + m6;
System.out.print("\n使用 Strassen 算法得到的乘积: ");
for(int i = 0; i<2; i++) {
System.out.println();//new line
for(int j = 0; j<2; j++) {
System.out.print(z[i][j] + "\t");
}
}
}
}
输出
The first matrix is: 12 34 22 10 The second matrix is: 3 4 2 1 Product achieved using Strassen's algorithm: 104 82 86 98
import numpy as np
x = np.array([[12, 34], [22, 10]])
y = np.array([[3, 4], [2, 1]])
z = np.zeros((2, 2))
m1, m2, m3, m4, m5, m6, m7 = 0, 0, 0, 0, 0, 0, 0
print("第一个矩阵是: ")
for i in range(2):
print()
for j in range(2):
print(x[i][j], end="\t")
print("\n第二个矩阵是: ")
for i in range(2):
print()
for j in range(2):
print(y[i][j], end="\t")
m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1])
m2 = (x[1][0] + x[1][1]) * y[0][0]
m3 = x[0][0] * (y[0][1] - y[1][1])
m4 = x[1][1] * (y[1][0] - y[0][0])
m5 = (x[0][0] + x[0][1]) * y[1][1]
m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1])
m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1])
z[0][0] = m1 + m4 - m5 + m7
z[0][1] = m3 + m5
z[1][0] = m2 + m4
z[1][1] = m1 - m2 + m3 + m6
print("\n使用 Strassen 算法得到的乘积: ")
for i in range(2):
print()
for j in range(2):
print(z[i][j], end="\t")
输出
The first matrix is: 12 34 22 10 The second matrix is: 3 4 2 1 Product achieved using Strassen's algorithm: 104.0 82.0 86.0 98.0