Strassen矩阵乘法
矩阵乘法是线性代数中最常见的运算之一,它在数值计算中有广泛的应用。若A和B是2个n×n的矩阵,则它们的乘积C=AB同样是一个n×n的矩阵。A和B的乘积矩阵C中的元素C[i,j]定义为:
若依此定义来计算A和B的乘积矩阵C,则每计算C的一个元素C[i,j],需要做n个乘法和n-1次加法。因此,求出矩阵C的n2个元素所需的计算时间为0(n3)。
60年代末,Strassen采用了类似于在大整数乘法中用过的分治技术,将计算2个n阶矩阵乘积所需的计算时间改进到O(nlog7)=O(n2.18)。
首先,我们还是需要假设n是2的幂。将矩阵A,B和C中每一矩阵都分块成为4个大小相等的子矩阵,每个子矩阵都是n/2×n/2的方阵。由此可将方程C=AB重写为:
由此可得:
C11=A11B11+A12B21 (2)
C12=A11B12+A12B22 (3)
C21=A21B11+A22B21 (4)
C22=A21B12+A22B22 (5)
如果n=2,则2个2阶方阵的乘积可以直接用(2)-(3)式计算出来,共需8次乘法和4次加法。当子矩阵的阶大于2时,为求2个子矩阵的积,可以继续将子矩阵分块,直到子矩阵的阶降为2。这样,就产生了一个分治降阶的递归算法。依此算法,计算2个n阶方阵的乘积转化为计算8个n/2阶方阵的乘积和4个n/2阶方阵的加法。2个n/2×n/2矩阵的加法显然可以在c*n2/4时间内完成,这里c是一个常数。因此,上述分治法的计算时间耗费T(n)应该满足:
这个递归方程的解仍然是T(n)=O(n3)。因此,该方法并不比用原始定义直接计算更有效。究其原因,乃是由于式(2)-(5)并没有减少矩阵的乘法次数。而矩阵乘法耗费的时间要比矩阵加减法耗费的时间多得多。要想改进矩阵乘法的计算时间复杂性,必须减少子矩阵乘法运算的次数。按照上述分治法的思想可以看出,要想减少乘法运算次数,关键在于计算2个2阶方阵的乘积时,能否用少于8次的乘法运算。Strassen提出了一种新的算法来计算2个2阶方阵的乘积。他的算法只用了7次乘法运算,但增加了加、减法的运算次数。这7次乘法是:
M1=A11(B12-B22)
M2=(A11+A12)B22
M3=(A21+A22)B11
M4=A22(B21-B11)
M5=(A11+A22)(B11+B22)
M6=(A12-A22)(B21+B22)
M7=(A11-A21)(B11+B12)
做了这7次乘法后,再做若干次加、减法就可以得到:
C11=M5+M4-M2+M6
C12=M1+M2
C21=M3+M4
C22=M5+M1-M3-M7
以上计算的正确性很容易验证。例如:
C22=M5+M1-M3-M7
=(A11+A22)(B11+B22)+A11(B12-B22)-(A21+A22)B11-(A11-A21)(B11+B12)
=A11B11+A11B22+A22B11+A22B22+A11B12
-A11B22-A21B11-A22B11-A11B11-A11B12+A21B11+A21B12
=A21B12+A22B22
由(2)式便知其正确性。
至此,我们可以得到完整的Strassen算法如下:
procedure STRASSEN(n,A,B,C);
begin
if n=2 then MATRIX-MULTIPLY(A,B,C)
else begin
将矩阵A和B依(1)式分块;
STRASSEN(n/2,A11,B12-B22,M1);
STRASSEN(n/2,A11+A12,B22,M2);
STRASSEN(n/2,A21+A22,B11,M3);
STRASSEN(n/2,A22,B21-B11,M4);
STRASSEN(n/2,A11+A22,B11+B22,M5);
STRASSEN(n/2,A12-A22,B21+B22,M6);
STRASSEN(n/2,A11-A21,B11+B12,M7);; end; end;
其中MATRIX-MULTIPLY(A,B,C)是按通常的矩阵乘法计算C=AB的子算法。
Strassen矩阵乘积分治算法中,用了7次对于n/2阶矩阵乘积的递归调用和18次n/2阶矩阵的加减运算。由此可知,该算法的所需的计算时间T(n)满足如下的递归方程:
按照解递归方程的套用公式法,其解为T(n)=O(nlog7)≈O(n2.81)。由此可见,Strassen矩阵乘法的计算时间复杂性比普通矩阵乘法有阶的改进。
有人曾列举了计算2个2阶矩阵乘法的36种不同方法。但所有的方法都要做7次乘法。除非能找到一种计算2阶方阵乘积的算法,使乘法的计算次数少于7次,按上述思路才有可能进一步改进矩阵乘积的计算时间的上界。但是Hopcroft和Kerr(197l)已经证明,计算2个2×2矩阵的乘积,7次乘法是必要的。因此,要想进一步改进矩阵乘法的时间复杂性,就不能再寄希望于计算2×2矩阵的乘法次数的减少。或许应当研究3×3或5×5矩阵的更好算法。在Strassen之后又有许多算法改进了矩阵乘法的计算时间复杂性。目前最好的计算时间上界是O(n2.367)。而目前所知道的矩阵乘法的最好下界仍是它的平凡下界Ω(n2)。因此到目前为止还无法确切知道矩阵乘法的时间复杂性。关于这一研究课题还有许多工作可做。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191 |
/**
* Strassen矩阵乘法
* */ import java.util.*;
public class Strassen{
public Strassen(){
A = new int [NUMBER][NUMBER];
B = new int [NUMBER][NUMBER];
C = new int [NUMBER][NUMBER];
}
/**
* 输入矩阵函数
* */
public void input( int a[][]){
Scanner scanner = new Scanner(System.in);
for ( int i = 0 ; i < a.length; i++) {
for ( int j = 0 ; j < a[i].length; j++) {
a[i][j] = scanner.nextInt();
}
}
}
/**
* 输出矩阵
* */
public void output( int [][] resault){
for ( int b[] : resault) {
for ( int temp : b) {
System.out.print(temp + " " );
}
System.out.println();
}
}
/**
* 矩阵乘法,此处只是定义了2*2矩阵的乘法
* */
public void Mul( int [][] first, int [][] second, int [][] resault){
for ( int i = 0 ; i < 2 ; ++i) {
for ( int j = 0 ; j < 2 ; ++j) {
resault[i][j] = 0 ;
for ( int k = 0 ; k < 2 ; ++k) {
resault[i][j] += first[i][k] * second[k][j];
}
}
}
}
/**
* 矩阵的加法运算
* */
public void Add( int [][] first, int [][] second, int [][] resault){
for ( int i = 0 ; i < first.length; i++) {
for ( int j = 0 ; j < first[i].length; j++) {
resault[i][j] = first[i][j] + second[i][j];
}
}
}
/**
* 矩阵的减法运算
* */
public void sub( int [][] first, int [][] second, int [][] resault){
for ( int i = 0 ; i < first.length; i++) {
for ( int j = 0 ; j < first[i].length; j++) {
resault[i][j] = first[i][j] - second[i][j];
}
}
}
/**
* strassen矩阵算法
* */
public void strassen( int [][] A, int [][] B, int [][] C){
//定义一些中间变量
int [][] M1= new int [NUMBER][NUMBER];
int [][] M2= new int [NUMBER][NUMBER];
int [][] M3= new int [NUMBER][NUMBER];
int [][] M4= new int [NUMBER][NUMBER];
int [][] M5= new int [NUMBER][NUMBER];
int [][] M6= new int [NUMBER][NUMBER];
int [][] M7= new int [NUMBER][NUMBER];
int [][] C11= new int [NUMBER][NUMBER];
int [][] C12= new int [NUMBER][NUMBER];
int [][] C21= new int [NUMBER][NUMBER];
int [][] C22= new int [NUMBER][NUMBER];
int [][] A11= new int [NUMBER][NUMBER];
int [][] A12= new int [NUMBER][NUMBER];
int [][] A21= new int [NUMBER][NUMBER];
int [][] A22= new int [NUMBER][NUMBER];
int [][] B11= new int [NUMBER][NUMBER];
int [][] B12= new int [NUMBER][NUMBER];
int [][] B21= new int [NUMBER][NUMBER];
int [][] B22= new int [NUMBER][NUMBER];
int [][] temp= new int [NUMBER][NUMBER];
int [][] temp1= new int [NUMBER][NUMBER];
if (A.length== 2 ){
Mul(A, B, C);
} else {
//首先将矩阵A,B 分为4块
for ( int i = 0 ; i < A.length/ 2 ; i++) {
for ( int j = 0 ; j < A.length/ 2 ; j++) {
A11[i][j]=A[i][j];
A12[i][j]=A[i][j+A.length/ 2 ];
A21[i][j]=A[i+A.length/ 2 ][j];
A22[i][j]=A[i+A.length/ 2 ][j+A.length/ 2 ];
B11[i][j]=B[i][j];
B12[i][j]=B[i][j+A.length/ 2 ];
B21[i][j]=B[i+A.length/ 2 ][j];
B22[i][j]=B[i+A.length/ 2 ][j+A.length/ 2 ];
}
}
//计算M1
sub(B12, B22, temp);
Mul(A11, temp, M1);
//计算M2
Add(A11, A12, temp);
Mul(temp, B22, M2);
//计算M3
Add(A21, A22, temp);
Mul(temp, B11, M3);
//M4
sub(B21, B11, temp);
Mul(A22, temp, M4);
//M5
Add(A11, A22, temp1);
Add(B11, B22, temp);
Mul(temp1, temp, M5);
//M6
sub(A12, A22, temp1);
Add(B21, B22, temp);
Mul(temp1, temp, M6);
//M7
sub(A11, A21, temp1);
Add(B11, B12, temp);
Mul(temp1, temp, M7);
//计算C11
Add(M5, M4, temp1);
sub(temp1, M2, temp);
Add(temp, M6, C11);
//计算C12
Add(M1, M2, C12);
//C21
Add(M3, M4, C21);
//C22
Add(M5, M1, temp1);
sub(temp1, M3, temp);
sub(temp, M7, C22);
//结果送回C中
for ( int i = 0 ; i < C.length/ 2 ; i++) {
for ( int j = 0 ; j < C.length/ 2 ; j++) {
C[i][j]=C11[i][j];
C[i][j+C.length/ 2 ]=C12[i][j];
C[i+C.length/ 2 ][j]=C21[i][j];
C[i+C.length/ 2 ][j+C.length/ 2 ]=C22[i][j];
}
}
}
}
public static void main(String[] args){
Strassen demo= new Strassen();
System.out.println( "输入矩阵A" );
demo.input(A);
System.out.println( "输入矩阵B" );
demo.input(B);
demo.strassen(A, B, C);
demo.output(C);
}
private static int A[][];
private static int B[][];
private static int C[][];
private final static int NUMBER = 4 ; } |
【测试】:
1 1 1 1
1 1 1 1
1 1 1 1
1 1 1 1
-----------
2 2 2 2
2 2 2 2
2 2 2 2
2 2 2 2
--------
8 8 8 8
8 8 8 8
8 8 8 8
8 8 8 8