本例輸入為兩個任意尺寸的矩陣(m × n 與 n × m),輸出為兩個矩陣的乘積。計算任意尺寸矩陣相乘時,使用了 Strassen 算法。程序為自編,經過測試,請放心使用。基本算法是:
1. 對於方陣(正方形矩陣),設其邊長為 m,找到最大的 l,使得 l = 2^k(k 為整數)且 l ≤ m。左上角邊長為 l 的方塊採用 Strassen 算法計算;其餘部分,以及該方塊中 Strassen 未計入的部分(即 A12·B21 項),用蠻力法計算。
2. 對於非方陣,依照行列相應補 0,使其成為方陣。
StrassenMethodTest.java
package matrixalgorithm; import java.util.Scanner; public class StrassenMethodTest { private StrassenMethod strassenMultiply; StrassenMethodTest(){ strassenMultiply = new StrassenMethod(); } //end cons public static void main(String[] args){ Scanner input = new Scanner(System.in); System.out.println( "Input row size of the first matrix: " ); int arow = input.nextInt(); System.out.println( "Input column size of the first matrix: " ); int acol = input.nextInt(); System.out.println( "Input row size of the second matrix: " ); int brow = input.nextInt(); System.out.println( "Input column size of the second matrix: " ); int bcol = input.nextInt(); double [][] A = new double [arow][acol]; double [][] B = new double [brow][bcol]; double [][] C = new double [arow][bcol]; System.out.println( "Input data for matrix A: " ); /*In all of the codes later in this project, r means row while c means column. */ for (int r = 0; r < arow; r++) { for (int c = 0; c < acol; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); A[r][c] = input.nextDouble(); }//end inner loop }//end loop System.out.println("Input data for matrix B: "); for (int r = 0; r < brow; r++) { for (int c = 0; c < bcol; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); B[r][c] = input.nextDouble(); }//end inner loop }//end loop StrassenMethodTest algorithm = new StrassenMethodTest(); C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol); //Display the calculation result: System.out.println("Result from matrix C: "); for (int r = 0; r < arow; r++) { for (int c = 0; c < bcol; c++) { System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]); }//end inner loop }//end outter loop }//end main //Deal with matrices that are not square: public double[][] multiplyRectMatrix(double[][] A, double[][] B, int arow, int acol, int brow, int bcol) { if (arow != bcol) //Invalid multiplicatio return new double[][]{{0}}; double[][] C = new double[arow][bcol]; if (arow < acol) { double[][] newA = new double[acol][acol]; 
double[][] newB = new double[brow][brow]; int n = acol; for (int r = 0; r < acol; r++) for (int c = 0; c < acol; c++) newA[r][c] = 0.0; for (int r = 0; r < brow; r++) for (int c = 0; c < brow; c++) newB[r][c] = 0.0; for (int r = 0; r < arow; r++) for (int c = 0; c < acol; c++) newA[r][c] = A[r][c]; for (int r = 0; r < brow; r++) for (int c = 0; c < bcol; c++) newB[r][c] = B[r][c]; double[][] C2 = multiplySquareMatrix(newA, newB, n); for(int r = 0; r < arow; r++) for(int c = 0; c < bcol; c++) C[r][c] = C2[r][c]; }//end if else if(arow == acol) C = multiplySquareMatrix(A, B, arow); else { int n = arow; double[][] newA = new double[arow][arow]; double[][] newB = new double[bcol][bcol]; for (int r = 0; r < arow; r++) for (int c = 0; c < arow; c++) newA[r][c] = 0.0; for (int r = 0; r < bcol; r++) for (int c = 0; c < bcol; c++) newB[r][c] = 0.0; for (int r = 0; r < arow; r++) for (int c = 0; c < acol; c++) newA[r][c] = A[r][c]; for (int r = 0; r < brow; r++) for (int c = 0; c < bcol; c++) newB[r][c] = B[r][c]; double[][] C2 = multiplySquareMatrix(newA, newB, n); for(int r = 0; r < arow; r++) for(int c = 0; c < bcol; c++) C[r][c] = C2[r][c]; }//end else return C; }//end method //Deal with matrices that are square matrices. 
public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n){ double[][] C2 = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) C2[r][c] = 0; if(n == 1){ C2[0][0] = A2[0][0] * B2[0][0]; return C2; }//end if int exp2k = 2; while(exp2k <= (n / 2) ){ exp2k *= 2; }//end loop if(exp2k == n){ C2 = strassenMultiply.strassenMultiplyMatrix(A2, B2, n); return C2; }//end else //The "biggest" strassen matrix: double[][][] A = new double[6][exp2k][exp2k]; double[][][] B = new double[6][exp2k][exp2k]; double[][][] C = new double[6][exp2k][exp2k]; for(int r = 0; r < exp2k; r++){ for(int c = 0; c < exp2k; c++){ A[0][r][c] = A2[r][c]; B[0][r][c] = B2[r][c]; }//end inner loop }//end outter loop C[0] = strassenMultiply.strassenMultiplyMatrix(A[0], B[0], exp2k); for(int r = 0; r < exp2k; r++) for(int c = 0; c < exp2k; c++) C2[r][c] = C[0][r][c]; int middle = exp2k / 2; for(int r = 0; r < middle; r++){ for(int c = exp2k; c < n; c++){ A[1][r][c - exp2k] = A2[r][c]; B[3][r][c - exp2k] = B2[r][c]; }//end inner loop }//end outter loop for(int r = exp2k; r < n; r++){ for(int c = 0; c < middle; c++){ A[3][r - exp2k][c] = A2[r][c]; B[1][r - exp2k][c] = B2[r][c]; }//end inner loop }//end outter loop for(int r = middle; r < exp2k; r++){ for(int c = exp2k; c < n; c++){ A[2][r - middle][c - exp2k] = A2[r][c]; B[4][r - middle][c - exp2k] = B2[r][c]; }//end inner loop }//end outter loop for(int r = exp2k; r < n; r++){ for(int c = middle; c < n - exp2k + 1; c++){ A[4][r - exp2k][c - middle] = A2[r][c]; B[2][r - exp2k][c - middle] = B2[r][c]; }//end inner loop }//end outter loop for(int i = 1; i <= 4; i++) C[i] = multiplyRectMatrix(A[i], B[i], middle, A[i].length, A[i].length, middle); /* Calculate the final results of grids in the "biggest 2^k square, according to the rules of matrice multiplication. 
*/ for ( int row = 0 ; row < exp2k; row++) { for ( int col = 0 ; col < exp2k; col++) { for ( int k = exp2k; k < n; k++) { C2[row][col] += A2[row][k] * B2[k][col]; } //end loop } //end inner loop } //end outter loop //Use brute force to solve the rest, will be improved later: for ( int col = exp2k; col < n; col++){ for ( int row = 0 ; row < n; row++){ for ( int k = 0 ; k < n; k++) C2[row][col] += A2[row][k] * B2[k][row]; } //end inner loop } //end outter loop for ( int row = exp2k; row < n; row++){ for ( int col = 0 ; col < exp2k; col++){ for ( int k = 0 ; k < n; k++) C2[row][col] += A2[row][k] * B2[k][row]; } //end inner loop } //end outter loop return C2; } //end method } //end class |
StrassenMethod.java
package matrixalgorithm; import java.util.Scanner; public class StrassenMethod { private double [][][][] A = new double [ 2 ][ 2 ][][]; private double [][][][] B = new double [ 2 ][ 2 ][][]; private double [][][][] C = new double [ 2 ][ 2 ][][]; /*//Codes for testing this class: public static void main(String[] args) { Scanner input = new Scanner(System.in); System.out.println("Input size of the matrix: "); int n = input.nextInt(); double[][] A = new double[n][n]; double[][] B = new double[n][n]; double[][] C = new double[n][n]; System.out.println("Input data for matrix A: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); A[r][c] = input.nextDouble(); }//end inner loop }//end loop System.out.println("Input data for matrix B: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); B[r][c] = input.nextDouble(); }//end inner loop }//end loop StrassenMethod algorithm = new StrassenMethod(); C = algorithm.strassenMultiplyMatrix(A, B, n); System.out.println("Result from matrix C: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]); }//end inner loop }//end outter loop }//end main*/ public double [][] strassenMultiplyMatrix( double [][] A2, double B2[][], int n){ double [][] C2 = new double [n][n]; //Initialize the matrix: for ( int rowIndex = 0 ; rowIndex < n; rowIndex++) for ( int colIndex = 0 ; colIndex < n; colIndex++) C2[rowIndex][colIndex] = 0.0 ; if (n == 1 ) C2[ 0 ][ 0 ] = A2[ 0 ][ 0 ] * B2[ 0 ][ 0 ]; //"Slice matrices into 2 * 2 parts: else { double [][][][] A = new double [ 2 ][ 2 ][n / 2 ][n / 2 ]; double [][][][] B = new double [ 2 ][ 2 ][n / 2 ][n / 2 ]; double [][][][] C = new double [ 2 ][ 2 ][n / 2 ][n / 2 ]; for ( int r = 0 ; r < n / 2 ; r++){ for ( int c = 0 ; c < n / 2 ; c++){ A[ 0 ][ 0 ][r][c] = A2[r][c]; A[ 0 ][ 1 ][r][c] = A2[r][n / 2 + c]; A[ 1 ][ 0 ][r][c] = 
A2[n / 2 + r][c]; A[ 1 ][ 1 ][r][c] = A2[n / 2 + r][n / 2 + c]; B[ 0 ][ 0 ][r][c] = B2[r][c]; B[ 0 ][ 1 ][r][c] = B2[r][n / 2 + c]; B[ 1 ][ 0 ][r][c] = B2[n / 2 + r][c]; B[ 1 ][ 1 ][r][c] = B2[n / 2 + r][n / 2 + c]; } //end loop } //end loop n = n / 2 ; double [][][] S = new double [ 10 ][n][n]; S[ 0 ] = minusMatrix(B[ 0 ][ 1 ], B[ 1 ][ 1 ], n); S[ 1 ] = addMatrix(A[ 0 ][ 0 ], A[ 0 ][ 1 ], n); S[ 2 ] = addMatrix(A[ 1 ][ 0 ], A[ 1 ][ 1 ], n); S[ 3 ] = minusMatrix(B[ 1 ][ 0 ], B[ 0 ][ 0 ], n); S[ 4 ] = addMatrix(A[ 0 ][ 0 ], A[ 1 ][ 1 ], n); S[ 5 ] = addMatrix(B[ 0 ][ 0 ], B[ 1 ][ 1 ], n); S[ 6 ] = minusMatrix(A[ 0 ][ 1 ], A[ 1 ][ 1 ], n); S[ 7 ] = addMatrix(B[ 1 ][ 0 ], B[ 1 ][ 1 ], n); S[ 8 ] = minusMatrix(A[ 0 ][ 0 ], A[ 1 ][ 0 ], n); S[ 9 ] = addMatrix(B[ 0 ][ 0 ], B[ 0 ][ 1 ], n); double [][][] P = new double [ 7 ][n][n]; P[ 0 ] = strassenMultiplyMatrix(A[ 0 ][ 0 ], S[ 0 ], n); P[ 1 ] = strassenMultiplyMatrix(S[ 1 ], B[ 1 ][ 1 ], n); P[ 2 ] = strassenMultiplyMatrix(S[ 2 ], B[ 0 ][ 0 ], n); P[ 3 ] = strassenMultiplyMatrix(A[ 1 ][ 1 ], S[ 3 ], n); P[ 4 ] = strassenMultiplyMatrix(S[ 4 ], S[ 5 ], n); P[ 5 ] = strassenMultiplyMatrix(S[ 6 ], S[ 7 ], n); P[ 6 ] = strassenMultiplyMatrix(S[ 8 ], S[ 9 ], n); C[ 0 ][ 0 ] = addMatrix(minusMatrix(addMatrix(P[ 4 ], P[ 3 ], n), P[ 1 ], n), P[ 5 ], n); C[ 0 ][ 1 ] = addMatrix(P[ 0 ], P[ 1 ], n); C[ 1 ][ 0 ] = addMatrix(P[ 2 ], P[ 3 ], n); C[ 1 ][ 1 ] = minusMatrix(minusMatrix(addMatrix(P[ 4 ], P[ 0 ], n), P[ 2 ], n), P[ 6 ], n); n *= 2 ; for ( int r = 0 ; r < n / 2 ; r++){ for ( int c = 0 ; c < n / 2 ; c++){ C2[r][c] = C[ 0 ][ 0 ][r][c]; C2[r][n / 2 + c] = C[ 0 ][ 1 ][r][c]; C2[n / 2 + r][c] = C[ 1 ][ 0 ][r][c]; C2[n / 2 + r][n / 2 + c] = C[ 1 ][ 1 ][r][c]; } //end inner loop } //end outter loop } //end else return C2; } //end method //Add two matrices according to matrix addition. 
private double [][] addMatrix( double [][] A, double [][] B, int n){ double C[][] = new double [n][n]; for ( int r = 0 ; r < n; r++) for ( int c = 0 ; c < n; c++) C[r][c] = A[r][c] + B[r][c]; return C; } //end method //Substract two matrices according to matrix addition. private double [][] minusMatrix( double [][] A, double [][] B, int n){ double C[][] = new double [n][n]; for ( int r = 0 ; r < n; r++) for ( int c = 0 ; c < n; c++) C[r][c] = A[r][c] - B[r][c]; return C; } //end method } //end class |
希望本文所述對大家學習 Java 程序設計有所幫助。