Strassen's method

If you have seen matrices before, then you probably know how to multiply them. If $A = (a_{ij})$ and $B = (b_{ij})$ are square $n \times n$ matrices, then in the product $C = A \cdot B$, we define the entry $c_{ij}$, for $i, j = 1, 2, \ldots, n$, by

$$c_{ij} = \sum_{k=1}^{n} a_{ik} \cdot b_{kj}$$

You might at first think that any matrix multiplication algorithm must take $\Omega(n^3)$ time, since the natural definition of matrix multiplication requires that many multiplications. You would be incorrect, however: we have a way to multiply matrices in less time than that. In this article, we shall see Strassen’s remarkable recursive algorithm for multiplying $n \times n$ matrices. It runs in $\Theta(n^{\log_2 7})$ time, which is asymptotically better than the straightforward approach. [1]
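For reference, the straightforward approach can be sketched directly from the definition of $c_{ij}$. This is a minimal illustration, not code from the book; the class and method names are hypothetical, and the inputs are assumed to be non-empty square matrices of equal size.

```java
import java.util.Arrays;

// Illustrative sketch only: class and method names are hypothetical.
final class NaiveMultiply {
    private NaiveMultiply() {
    }

    // Computes C = A x B for two n x n matrices straight from the definition.
    static int[][] multiply(int[][] a, int[][] b) {
        final int n = a.length;
        final int[][] c = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    c[i][j] += a[i][k] * b[k][j]; // one scalar multiplication per (i, j, k)
        return c;
    }

    public static void main(String[] args) {
        int[][] a = { { 1, 3 }, { 7, 5 } };
        int[][] b = { { 6, 8 }, { 4, 2 } };
        System.out.println(Arrays.deepToString(multiply(a, b))); // [[18, 14], [62, 66]]
    }
}
```

The three nested loops each run $n$ times, so the method performs exactly $n^3$ scalar multiplications.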

 

Strassen’s method

Strassen’s method makes the recursion tree slightly less bushy by performing only 7 recursive multiplications of $n/2 \times n/2$ submatrices instead of 8. The method is not at all obvious; it involves four steps, as described in the book [1]. Let’s move directly to the end-of-section exercises in the book [1]. This article is a supplement to the book, not a substitute for it. [1]

Exercises

4.2-2

Write pseudocode for Strassen’s algorithm.
By following the four steps described in the book [1], we can write the algorithm.

We assume that $n$ is an exact power of 2, so $n = 2^k$ for some integer $k \ge 0$.

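**Input:** square $n \times n$ matrices $A$ and $B$  
**Output:** the product $C = A \cdot B$  
STRASSENS_SQUARE_MATRIX_MULTIPLY($A, B$)  
$n \leftarrow A.rows$  
Let $C$ be a new $n \times n$ matrix  
`if ` $n = 1$  
++ $C_{11} \leftarrow A_{11} \cdot B_{11}$  
`else `  
++ partition the matrix $A$ into $A_{11}, A_{12}, A_{21}, A_{22}$ and $B$ into $B_{11}, B_{12}, B_{21}, B_{22}$  
++ partition the matrix $C$ into $C_{11}, C_{12}, C_{21}$ and $C_{22}$  
++  
++ Let $S_1$ through $S_{10}$ be submatrices such that,  
++ $S_1 \leftarrow B_{12} - B_{22}$  
++ $S_2 \leftarrow A_{11} + A_{12}$  
++ $S_3 \leftarrow A_{21} + A_{22}$  
++ $S_4 \leftarrow B_{21} - B_{11}$  
++ $S_5 \leftarrow A_{11} + A_{22}$  
++ $S_6 \leftarrow B_{11} + B_{22}$  
++ $S_7 \leftarrow A_{12} - A_{22}$  
++ $S_8 \leftarrow B_{21} + B_{22}$  
++ $S_9 \leftarrow A_{11} - A_{21}$  
++ $S_{10} \leftarrow B_{11} + B_{12}$  
++ Let $P_1$ through $P_7$ be $n/2 \times n/2$ matrices  
++ $P_1 \leftarrow$ STRASSENS_SQUARE_MATRIX_MULTIPLY($A_{11}, S_1$)  
++ $P_2 \leftarrow$ STRASSENS_SQUARE_MATRIX_MULTIPLY($S_2, B_{22}$)  
++ $P_3 \leftarrow$ STRASSENS_SQUARE_MATRIX_MULTIPLY($S_3, B_{11}$)  
++ $P_4 \leftarrow$ STRASSENS_SQUARE_MATRIX_MULTIPLY($A_{22}, S_4$)  
++ $P_5 \leftarrow$ STRASSENS_SQUARE_MATRIX_MULTIPLY($S_5, S_6$)  
++ $P_6 \leftarrow$ STRASSENS_SQUARE_MATRIX_MULTIPLY($S_7, S_8$)  
++ $P_7 \leftarrow$ STRASSENS_SQUARE_MATRIX_MULTIPLY($S_9, S_{10}$)  
++ $C_{11} \leftarrow P_5 + P_4 - P_2 + P_6$  
++ $C_{12} \leftarrow P_1 + P_2$  
++ $C_{21} \leftarrow P_3 + P_4$  
++ $C_{22} \leftarrow P_5 + P_1 - P_3 - P_7$  
`return ` $C$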

Now that we have the algorithm, let’s implement it and first use it to multiply two $2 \times 2$ matrices.

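$$\begin{bmatrix} 1 & 3 \\ 7 & 5 \end{bmatrix} \cdot \begin{bmatrix} 6 & 8 \\ 4 & 2 \end{bmatrix} = \begin{bmatrix} 18 & 14 \\ 62 & 66 \end{bmatrix}$$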
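As a sanity check, this product can be worked by hand with Strassen’s formulas. For $n = 2$ every submatrix is a single number, so each of the seven products is one scalar multiplication. The sums and differences $S_1, \ldots, S_{10}$ are written out explicitly here, with signs as in the book [1]:

$$
\begin{aligned}
S_1 &= B_{12} - B_{22} = 6, & S_2 &= A_{11} + A_{12} = 4, & S_3 &= A_{21} + A_{22} = 12, & S_4 &= B_{21} - B_{11} = -2, & S_5 &= A_{11} + A_{22} = 6, \\
S_6 &= B_{11} + B_{22} = 8, & S_7 &= A_{12} - A_{22} = -2, & S_8 &= B_{21} + B_{22} = 6, & S_9 &= A_{11} - A_{21} = -6, & S_{10} &= B_{11} + B_{12} = 14
\end{aligned}
$$

$$
\begin{aligned}
P_1 &= A_{11} S_1 = 6, & P_2 &= S_2 B_{22} = 8, & P_3 &= S_3 B_{11} = 72, & P_4 &= A_{22} S_4 = -10, \\
P_5 &= S_5 S_6 = 48, & P_6 &= S_7 S_8 = -12, & P_7 &= S_9 S_{10} = -84
\end{aligned}
$$

$$
\begin{aligned}
C_{11} &= P_5 + P_4 - P_2 + P_6 = 48 - 10 - 8 - 12 = 18, & C_{12} &= P_1 + P_2 = 6 + 8 = 14, \\
C_{21} &= P_3 + P_4 = 72 - 10 = 62, & C_{22} &= P_5 + P_1 - P_3 - P_7 = 48 + 6 - 72 + 84 = 66
\end{aligned}
$$

Seven multiplications reproduce the four entries that the definition computes with eight.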


Here’s the code:

import java.util.Arrays;
import java.util.function.IntBinaryOperator;

public class Strassens {
    private Strassens() {
        throw new AssertionError();
    }

    public static void main(String[] args) {
        // 2 x 2 example.
        int[][] a = { { 1, 3 }, { 7, 5 } };
        int[][] b = { { 6, 8 }, { 4, 2 } };
        int[][] c = squareMatrixMultiply(a, b);
        System.out.println(Arrays.deepToString(c)); // [[18, 14], [62, 66]]

        // Non-square example: treat d as a row of two 2 x 2 blocks and e as a column
        // of two 2 x 2 blocks, multiply block by block, then add the partial products.
        int[][] d = { { 3, 1, 1, 4 }, { 5, 3, 2, 1 } };
        int[][] e = { { 4, 9 }, { 6, 8 }, { 9, 7 }, { 7, 6 } };
        final int n = 2;
        int[][] m1 = squareMatrixMultiply(d, 0, 0, e, 0, 0, n);
        int[][] m2 = squareMatrixMultiply(d, 0, n, e, n, 0, n);
        int[][] m3 = matrixSum(m1, m2, Integer::sum);
        System.out.println(Arrays.deepToString(m3)); // [[55, 66], [63, 89]]

        // 3 x 3 example: n is not a power of 2, so the operands are zero-padded to 4 x 4.
        final int[][] f = { { 2, 7, 3 }, { 1, 5, 8 }, { 0, 4, 1 } };
        final int[][] g = { { 3, 0, 1 }, { 2, 1, 0 }, { 1, 2, 4 } };
        final int[][] m4 = squareMatrixMultiply(f, g);
        System.out.println(Arrays.deepToString(m4)); // [[23, 13, 14], [21, 21, 33], [9, 6, 4]]
    }

    // Multiplies two n x n matrices, padding with zeros when n is not a power of 2.
    public static int[][] squareMatrixMultiply(int[][] a, int[][] b) {
        int n = a.length;
        if (n == 0 || n != a[0].length || n != b.length || a[0].length != b[0].length)
            throw new IllegalArgumentException(
                    "Not conformable for multiplication, different dimensions or empty matrices provided.");
        if (isExactPowerOf2(n)) {
            return squareMatrixMultiply(a, 0, 0, b, 0, 0, n);
        } else {
            // n is not a power of 2 here, so this is the smallest power of 2 greater than n.
            final int m = Integer.highestOneBit(n) << 1;
            final int[][] paddedA = new int[m][m];
            final int[][] paddedB = new int[m][m];
            matrixCopy(a, paddedA, 0, 0);
            matrixCopy(b, paddedB, 0, 0);
            int[][] paddedC = squareMatrixMultiply(paddedA, 0, 0, paddedB, 0, 0, m);
            // The product is the top-left n x n corner of the padded result.
            final int[][] c = new int[n][n];
            matrixCopy(paddedC, c, 0, 0, n);
            return c;
        }
    }

    // Multiplies the side x side blocks of a and b that start at the given offsets.
    private static int[][] squareMatrixMultiply(int[][] a, int startRowA, int startColA, int[][] b, int startRowB,
            int startColB, int side) {
        if (!isExactPowerOf2(side))
            throw new IllegalArgumentException(String.format("n = %d should be an exact power of 2", side));
        if (side == 1)
            return new int[][] { { a[startRowA][startColA] * b[startRowB][startColB] } };
        final int mid = side / 2;

        // S1 .. S10: sums and differences of the quadrants of A and B.
        final int[][] s1 = matrixSum(b, startRowB, startColB + mid, b, startRowB + mid, startColB + mid, mid,
                (x, y) -> x - y);
        final int[][] s2 = matrixSum(a, startRowA, startColA, a, startRowA, startColA + mid, mid, Integer::sum);
        final int[][] s3 = matrixSum(a, startRowA + mid, startColA, a, startRowA + mid, startColA + mid, mid,
                Integer::sum);
        final int[][] s4 = matrixSum(b, startRowB + mid, startColB, b, startRowB, startColB, mid, (x, y) -> x - y);
        final int[][] s5 = matrixSum(a, startRowA, startColA, a, startRowA + mid, startColA + mid, mid, Integer::sum);
        final int[][] s6 = matrixSum(b, startRowB, startColB, b, startRowB + mid, startColB + mid, mid, Integer::sum);
        final int[][] s7 = matrixSum(a, startRowA, startColA + mid, a, startRowA + mid, startColA + mid, mid,
                (x, y) -> x - y);
        final int[][] s8 = matrixSum(b, startRowB + mid, startColB, b, startRowB + mid, startColB + mid, mid,
                Integer::sum);
        final int[][] s9 = matrixSum(a, startRowA, startColA, a, startRowA + mid, startColA, mid, (x, y) -> x - y);
        final int[][] s10 = matrixSum(b, startRowB, startColB, b, startRowB, startColB + mid, mid, Integer::sum);

        // P1 .. P7: the seven recursive products.
        final int[][] p1 = squareMatrixMultiply(a, startRowA, startColA, s1, 0, 0, mid);
        final int[][] p2 = squareMatrixMultiply(s2, 0, 0, b, startRowB + mid, startColB + mid, mid);
        final int[][] p3 = squareMatrixMultiply(s3, 0, 0, b, startRowB, startColB, mid);
        final int[][] p4 = squareMatrixMultiply(a, startRowA + mid, startColA + mid, s4, 0, 0, mid);
        final int[][] p5 = squareMatrixMultiply(s5, 0, 0, s6, 0, 0, mid);
        final int[][] p6 = squareMatrixMultiply(s7, 0, 0, s8, 0, 0, mid);
        final int[][] p7 = squareMatrixMultiply(s9, 0, 0, s10, 0, 0, mid);

        // C11 = P5 + P4 - P2 + P6, C12 = P1 + P2, C21 = P3 + P4, C22 = P5 + P1 - P3 - P7.
        final int[][] c11 = matrixSum(matrixSum(p4, p5, Integer::sum), matrixSum(p6, p2, (x, y) -> x - y),
                Integer::sum);
        final int[][] c12 = matrixSum(p1, p2, Integer::sum);
        final int[][] c21 = matrixSum(p3, p4, Integer::sum);
        final int[][] c22 = matrixSum(matrixSum(p5, p3, (x, y) -> x - y), matrixSum(p1, p7, (x, y) -> x - y),
                Integer::sum);

        // Assemble the four quadrants into the result.
        final int[][] c = new int[side][side];
        matrixCopy(c11, c, 0, 0);
        matrixCopy(c12, c, 0, mid);
        matrixCopy(c21, c, mid, 0);
        matrixCopy(c22, c, mid, mid);
        return c;
    }

    // Bit test instead of floating-point logarithms, which can misreport large powers of 2.
    private static boolean isExactPowerOf2(int n) {
        return n > 0 && (n & (n - 1)) == 0;
    }

    // Combines two side x side blocks element by element with binOp (sum or difference).
    private static int[][] matrixSum(int[][] a, int startRowA, int startColA, int[][] b, int startRowB, int startColB,
            int side, IntBinaryOperator binOp) {
        final int[][] c = new int[side][side];
        for (int i = 0; i < side; i++)
            for (int j = 0; j < side; j++)
                c[i][j] = binOp.applyAsInt(a[startRowA + i][startColA + j], b[startRowB + i][startColB + j]);
        return c;
    }

    private static int[][] matrixSum(int[][] a, int[][] b, IntBinaryOperator binOp) {
        // Reject the operands if either dimension differs.
        if (a.length != b.length || a[0].length != b[0].length)
            throw new IllegalArgumentException("Not conformable for addition, different orders/dimensions.");
        return matrixSum(a, 0, 0, b, 0, 0, a.length, binOp);
    }

    // Copies the whole (square) source into target, starting at (startRow, startCol).
    private static void matrixCopy(int[][] source, int[][] target, int startRow, int startCol) {
        matrixCopy(source, target, startRow, startCol, source.length);
    }

    // Copies the top-left side x side corner of source into target at (startRow, startCol).
    private static void matrixCopy(int[][] source, int[][] target, int startRow, int startCol, int side) {
        if (side > target.length)
            throw new IllegalArgumentException("Target matrix is too small for the number of elements to be copied!");
        for (int i = 0; i < side; i++)
            for (int j = 0; j < side; j++)
                target[startRow + i][startCol + j] = source[i][j];
    }
}


What about non-square matrices?

Although Strassen’s method is defined only for square matrices, it can also be used to multiply non-square matrices where circumstances warrant. [2] For instance, consider the product

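$$\begin{bmatrix} 3 & 1 & 1 & 4 \\ 5 & 3 & 2 & 1 \end{bmatrix} \cdot \begin{bmatrix} 4 & 9 \\ 6 & 8 \\ 9 & 7 \\ 7 & 6 \end{bmatrix} = \begin{bmatrix} 55 & 66 \\ 63 & 89 \end{bmatrix}$$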

Write the left factor as a row of two square matrices and the right factor as a column of two square matrices, each of dimension $2 \times 2$. Multiply the corresponding blocks with Strassen’s method, then sum the two resulting matrices to get the answer.
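In block form, this decomposition can be written out as follows. The block names $D_1, D_2, E_1, E_2$ are labels introduced here purely for illustration:

$$
\begin{bmatrix} D_1 & D_2 \end{bmatrix} \begin{bmatrix} E_1 \\ E_2 \end{bmatrix} = D_1 E_1 + D_2 E_2,
\qquad
D_1 = \begin{bmatrix} 3 & 1 \\ 5 & 3 \end{bmatrix},\;
D_2 = \begin{bmatrix} 1 & 4 \\ 2 & 1 \end{bmatrix},\;
E_1 = \begin{bmatrix} 4 & 9 \\ 6 & 8 \end{bmatrix},\;
E_2 = \begin{bmatrix} 9 & 7 \\ 7 & 6 \end{bmatrix}
$$

$$
D_1 E_1 + D_2 E_2
= \begin{bmatrix} 18 & 35 \\ 38 & 69 \end{bmatrix} + \begin{bmatrix} 37 & 31 \\ 25 & 20 \end{bmatrix}
= \begin{bmatrix} 55 & 66 \\ 63 & 89 \end{bmatrix}
$$

Each of the two $2 \times 2$ products is a square multiplication, so Strassen’s method applies to it directly.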

Exercises

4.2-3

How would you modify Strassen’s algorithm to multiply $n \times n$ matrices in which $n$ is not an exact power of 2? Show that the resulting algorithm runs in time $\Theta(n^{\log_2 7})$. [1]

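Strassen’s algorithm can be applied to $n \times n$ matrix multiplications where $n$ is not an exact power of 2 by padding the operands with 0’s. Let $m = 2^k$ such that $2^{k-1} < n < 2^k$ (that is, $m = 2^{\lceil \log_2 n \rceil}$). Create $m \times m$ matrices $A'$ and $B'$ by padding $A$ and $B$ respectively with zeros. The matrices $A'$, $B'$ and their product $C' = A' \cdot B'$, computed with Strassen’s algorithm, appear as follows:

$$C' = \begin{bmatrix} C & 0 \\ 0 & 0 \end{bmatrix}, \quad A' = \begin{bmatrix} A & 0 \\ 0 & 0 \end{bmatrix}, \quad B' = \begin{bmatrix} B & 0 \\ 0 & 0 \end{bmatrix}$$

To obtain the product, we simply extract the matrix $C$ from the top-left corner of $C'$.

The runtime for this method is $\Theta(m^{\log_2 7})$. Since $2^{k-1} < n$, it follows that $m < 2n$, and by construction $m > n$. The runtime is therefore $\Omega(n^{\log_2 7})$ and $O((2n)^{\log_2 7}) = O(2^{\log_2 7} \cdot n^{\log_2 7}) = O(7 n^{\log_2 7}) = O(n^{\log_2 7})$, which together give $\Theta(n^{\log_2 7})$. [3]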
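The padding step can be sketched on its own as follows. This is a minimal illustration under stated assumptions: the class and method names are hypothetical, the input is a non-empty square matrix, and $n \le 2^{30}$ so that the padded size fits in an `int`.

```java
import java.util.Arrays;

// Illustrative sketch only: class and method names are hypothetical.
final class ZeroPadding {
    private ZeroPadding() {
    }

    // Smallest power of 2 that is >= n. Assumes 1 <= n <= 2^30 so that m cannot overflow.
    static int nextPowerOfTwo(int n) {
        int m = 1;
        while (m < n)
            m <<= 1;
        return m;
    }

    // Copies the matrix a into the top-left corner of a new m x m zero matrix.
    static int[][] pad(int[][] a, int m) {
        final int[][] padded = new int[m][m]; // int arrays start out filled with 0
        for (int i = 0; i < a.length; i++)
            System.arraycopy(a[i], 0, padded[i], 0, a[i].length);
        return padded;
    }

    // Extracts the top-left n x n corner, that is, recovers C from C'.
    static int[][] topLeft(int[][] c, int n) {
        final int[][] result = new int[n][n];
        for (int i = 0; i < n; i++)
            System.arraycopy(c[i], 0, result[i], 0, n);
        return result;
    }

    public static void main(String[] args) {
        int[][] f = { { 2, 7, 3 }, { 1, 5, 8 }, { 0, 4, 1 } };
        int m = nextPowerOfTwo(f.length); // 4 for n = 3
        int[][] padded = pad(f, m);
        System.out.println(Arrays.deepToString(padded));
        System.out.println(Arrays.deepToString(topLeft(padded, f.length)));
    }
}
```

Because the padded size always satisfies $n \le m < 2n$, padding at most doubles the side length, which is what keeps the asymptotic bound unchanged.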

Finally, let’s use our implementation of Strassen’s algorithm to multiply two 3×3 matrices.

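$$\begin{bmatrix} 2 & 7 & 3 \\ 1 & 5 & 8 \\ 0 & 4 & 1 \end{bmatrix} \cdot \begin{bmatrix} 3 & 0 & 1 \\ 2 & 1 & 0 \\ 1 & 2 & 4 \end{bmatrix} = \begin{bmatrix} 23 & 13 & 14 \\ 21 & 21 & 33 \\ 9 & 6 & 4 \end{bmatrix}$$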

Analysis

Let’s use the master method to analyse the recurrence

$$T(n) = 7T(n/2) + \Theta(n^2),$$

which describes the running time of Strassen’s algorithm. Here we have $a = 7$, $b = 2$, $f(n) = \Theta(n^2)$, and thus $n^{\log_b a} = n^{\log_2 7}$. Recalling that $2.80 < \log_2 7 < 2.81$, we see that $f(n) = O(n^{\log_2 7 - \epsilon})$ for $\epsilon = 0.8$. Case 1 of the master theorem applies, and we have the solution $T(n) = \Theta(n^{\log_2 7})$. [1]

References

[1] https://www.amazon.com/Introduction-Algorithms-3rd-MIT-Press/dp/0262033844
[2] https://math.stackexchange.com/questions/1445064/strassens-algorithm-for-non-square-matrices
[3] https://www.eecis.udel.edu/~saunders/courses/621/03f/modelV.pdf
