Open-Source Internship opportunity by OpenGenus for programmers. Apply now.
Reading time: 25 minutes | Coding time: 10 minutes
Matrix multiplication is one of the most fundamental operations in machine learning, and optimizing it is the key to many other optimizations. Multiplying two N × N matrices with the straightforward (naive) method takes N^3 scalar multiplications. Over the years, cleverer matrix multiplication algorithms have been found. Volker Strassen published the first of them in 1969; it was the first algorithm to prove that the basic O(n^3) runtime is not optimal.
The basic idea behind Strassen's algorithm is to split each of A and B into four N/2 × N/2 submatrices (8 submatrices in total) and then recursively compute the submatrices of C. This strategy is called Divide and Conquer.
Consider the following matrices A and B:
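$$A = \begin{pmatrix} a & b \\ c & d \end{pmatrix}, \qquad B = \begin{pmatrix} e & f \\ g & h \end{pmatrix}$$
where each of a, …, h is an N/2 × N/2 block.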
With this split, the straightforward divide-and-conquer method makes 8 recursive calls:
a * e
b * g
a * f
b * h
c * e
d * g
c * f
d * h
We then use these results to compute the submatrices of C:
$$C = \begin{pmatrix} ae+bg & af+bh \\ ce+dg & cf+dh \end{pmatrix}$$
The above strategy is the basic O(N^3) strategy.
Using the Master Theorem with T(n) = 8T(n/2) + O(n^2) we still get a runtime of O(n^3).
Strassen's insight was that we don't actually need 8 recursive calls to complete this process. We can finish the call with 7 recursive calls and a little bit of addition and subtraction.
Strassen's 7 calls are as follows:
p1 = a * (f - h)
p2 = (a + b) * h
p3 = (c + d) * e
p4 = d * (g - e)
p5 = (a + d) * (e + h)
p6 = (b - d) * (g + h)
p7 = (a - c) * (e + f)
The quadrants of the result matrix C are then:
$$C = \begin{pmatrix} p_5+p_4-p_2+p_6 & p_1+p_2 \\ p_3+p_4 & p_1+p_5-p_3-p_7 \end{pmatrix}$$
Checking that each of Strassen's quadrants equals the corresponding block of C:
p5+p4-p2+p6 = (a+d)*(e+h) + d*(g-e) - (a+b)*h + (b-d)*(g+h)
= (ae+de+ah+dh) + (dg-de) - (ah+bh) + (bg-dg+bh-dh)
= ae+bg
p1+p2 = a*(f-h) + (a+b)*h
= (af-ah) + (ah+bh)
= af+bh
p3+p4 = (c+d)*e + d*(g-e)
= (ce+de) + (dg-de)
= ce+dg
p1+p5-p3-p7 = a*(f-h) + (a+d)*(e+h) - (c+d)*e - (a-c)*(e+f)
= (af-ah) + (ae+de+ah+dh) -(ce+de) - (ae-ce+af-cf)
= cf+dh
Time complexity, again by the Master Theorem:
T(n) = 7T(n/2) + O(n^2), which solves to O(n^{log_2 7}). Since log_2 7 ≈ 2.8074, this is approximately O(n^2.8074), which is better than O(n^3).
Pseudocode
- Divide each of A and B into 4 sub-matrices of size N/2 x N/2, as shown above (if N is not a power of 2, first pad the matrices with zeros up to the next power of 2).
- Calculate the 7 matrix multiplications recursively.
- Compute the submatrices of C from the 7 products.
- Combine these submatrices into the result matrix C.
Complexity
- Worst case time complexity:
Θ(n^2.8074)
- Best case time complexity:
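Θ(n^2.8074). Strassen's algorithm performs the same operations for every input of a given size, so the best case equals the worst case. It is constant time only for a trivial 1 × 1 input.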
- Space complexity:
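Θ(n^2). This is dominated by the temporary submatrices allocated at each level of the recursion; the recursion depth itself is only Θ(log n).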
Implementations
- C++
C++
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
using namespace std;
/* Returns the smallest power of two that is >= k (1 for any k <= 1).
 * Uses exact integer doubling instead of pow/ceil/log2. The floating-point
 * version is undefined for k <= 0: log2(0) is -infinity, and converting that
 * to int is undefined behaviour.
 * Precondition: k <= 2^30, the largest power of two an int can hold. */
int nextpowerof2(int k)
{
    int power = 1;
    while (power < k)
        power <<= 1;
    return power;
}
/* Prints the top-left m x n block of matrix to standard output: one row per
 * line, entries separated by tab characters. The matrix is only read, so it
 * is taken by const reference. */
void display(const vector<vector<int>> &matrix, int m, int n)
{
    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            if (j != 0)
                cout << '\t';
            cout << matrix[i][j];
        }
        // '\n' rather than endl: there is no need to flush after every row.
        cout << '\n';
    }
}
/* Element-wise sum of the top-left size x size blocks: C = A + B.
 * A and B are only read, so they are taken by const reference; C must already
 * hold at least size x size entries. C may be the same object as A or B, since
 * each element is read before it is overwritten. */
void add(const vector<vector<int>> &A, const vector<vector<int>> &B, vector<vector<int>> &C, int size)
{
    for (int i = 0; i < size; i++)
    {
        for (int j = 0; j < size; j++)
        {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
}
/* Element-wise difference of the top-left size x size blocks: C = A - B.
 * A and B are only read, so they are taken by const reference; C must already
 * hold at least size x size entries. C may be the same object as A or B, since
 * each element is read before it is overwritten. */
void sub(const vector<vector<int>> &A, const vector<vector<int>> &B, vector<vector<int>> &C, int size)
{
    for (int i = 0; i < size; i++)
    {
        for (int j = 0; j < size; j++)
        {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
}
/* Copies the new_size x new_size quadrant of src whose top-left corner is at
 * (row, col) into dst; dst must already hold at least new_size x new_size entries. */
static void strassen_get_quadrant(const vector<vector<int>> &src, vector<vector<int>> &dst,
                                  int row, int col, int new_size)
{
    for (int i = 0; i < new_size; i++)
    {
        for (int j = 0; j < new_size; j++)
        {
            dst[i][j] = src[i + row][j + col];
        }
    }
}

/* Writes the new_size x new_size matrix src into the quadrant of dst whose
 * top-left corner is at (row, col). */
static void strassen_set_quadrant(const vector<vector<int>> &src, vector<vector<int>> &dst,
                                  int row, int col, int new_size)
{
    for (int i = 0; i < new_size; i++)
    {
        for (int j = 0; j < new_size; j++)
        {
            dst[i + row][j + col] = src[i][j];
        }
    }
}

/* Schoolbook O(size^3) product C = A * B of the top-left size x size blocks.
 * The result is built in a temporary first so that C may alias A or B. */
static void strassen_naive_multiply(const vector<vector<int>> &A, const vector<vector<int>> &B,
                                    vector<vector<int>> &C, int size)
{
    vector<vector<int>> product(size, vector<int>(size, 0));
    for (int i = 0; i < size; i++)
    {
        for (int t = 0; t < size; t++)
        {
            for (int j = 0; j < size; j++)
            {
                product[i][j] += A[i][t] * B[t][j];
            }
        }
    }
    strassen_set_quadrant(product, C, 0, 0, size);
}

/* Computes C = A * B for the top-left size x size blocks using Strassen's
 * algorithm: 7 recursive half-size multiplications instead of 8.
 *
 * A and B are only read; C must already hold at least size x size entries.
 * All reads of A and B happen before C is written, so C may alias A or B.
 * For power-of-two sizes (what STRASSEN_algorithm passes), the recursion goes
 * all the way down to 1 x 1. Any odd size > 1 that appears cannot be halved
 * evenly and is multiplied with the schoolbook method instead of silently
 * dropping its last row and column.
 *
 * Note: p1..p7 below follow the common textbook numbering, which differs from
 * the lettered p1..p7 used in the derivation in the article text. */
void STRASSEN_algorithmA(const vector<vector<int>> &A, const vector<vector<int>> &B, vector<vector<int>> &C, int size)
{
    // Empty matrix: nothing to compute. Without this guard, size 0 would
    // recurse forever, since 0 / 2 == 0.
    if (size <= 0)
        return;

    // Base case: a 1 x 1 product is a single scalar multiplication.
    if (size == 1)
    {
        C[0][0] = A[0][0] * B[0][0];
        return;
    }

    // An odd size cannot be split into four equal quadrants.
    if (size % 2 != 0)
    {
        strassen_naive_multiply(A, B, C, size);
        return;
    }

    const int new_size = size / 2;
    const vector<int> z(new_size);
    vector<vector<int>>
        a11(new_size, z), a12(new_size, z), a21(new_size, z), a22(new_size, z),
        b11(new_size, z), b12(new_size, z), b21(new_size, z), b22(new_size, z),
        c11(new_size, z), c12(new_size, z), c21(new_size, z), c22(new_size, z),
        p1(new_size, z), p2(new_size, z), p3(new_size, z), p4(new_size, z),
        p5(new_size, z), p6(new_size, z), p7(new_size, z),
        aResult(new_size, z), bResult(new_size, z);

    // Dividing the matrices into quadrants.
    strassen_get_quadrant(A, a11, 0, 0, new_size);
    strassen_get_quadrant(A, a12, 0, new_size, new_size);
    strassen_get_quadrant(A, a21, new_size, 0, new_size);
    strassen_get_quadrant(A, a22, new_size, new_size, new_size);
    strassen_get_quadrant(B, b11, 0, 0, new_size);
    strassen_get_quadrant(B, b12, 0, new_size, new_size);
    strassen_get_quadrant(B, b21, new_size, 0, new_size);
    strassen_get_quadrant(B, b22, new_size, new_size, new_size);

    // Calculating p1 to p7; aResult / bResult are scratch operands reused
    // between the products.

    // p1 = (a11 + a22) * (b11 + b22)
    add(a11, a22, aResult, new_size);
    add(b11, b22, bResult, new_size);
    STRASSEN_algorithmA(aResult, bResult, p1, new_size);

    // p2 = (a21 + a22) * b11
    add(a21, a22, aResult, new_size);
    STRASSEN_algorithmA(aResult, b11, p2, new_size);

    // p3 = a11 * (b12 - b22)
    sub(b12, b22, bResult, new_size);
    STRASSEN_algorithmA(a11, bResult, p3, new_size);

    // p4 = a22 * (b21 - b11)
    sub(b21, b11, bResult, new_size);
    STRASSEN_algorithmA(a22, bResult, p4, new_size);

    // p5 = (a11 + a12) * b22
    add(a11, a12, aResult, new_size);
    STRASSEN_algorithmA(aResult, b22, p5, new_size);

    // p6 = (a21 - a11) * (b11 + b12)
    sub(a21, a11, aResult, new_size);
    add(b11, b12, bResult, new_size);
    STRASSEN_algorithmA(aResult, bResult, p6, new_size);

    // p7 = (a12 - a22) * (b21 + b22)
    sub(a12, a22, aResult, new_size);
    add(b21, b22, bResult, new_size);
    STRASSEN_algorithmA(aResult, bResult, p7, new_size);

    // Calculating c11, c12, c21 and c22.
    add(p3, p5, c12, new_size); // c12 = p3 + p5
    add(p2, p4, c21, new_size); // c21 = p2 + p4

    add(p1, p4, aResult, new_size);      // p1 + p4
    add(aResult, p7, bResult, new_size); // p1 + p4 + p7
    sub(bResult, p5, c11, new_size);     // c11 = p1 + p4 - p5 + p7

    add(p1, p3, aResult, new_size);      // p1 + p3
    add(aResult, p6, bResult, new_size); // p1 + p3 + p6
    sub(bResult, p2, c22, new_size);     // c22 = p1 - p2 + p3 + p6

    // Grouping the quadrants into C. This happens only after every read of A
    // and B above, which is what makes aliasing C with A or B safe.
    strassen_set_quadrant(c11, C, 0, 0, new_size);
    strassen_set_quadrant(c12, C, 0, new_size, new_size);
    strassen_set_quadrant(c21, C, new_size, 0, new_size);
    strassen_set_quadrant(c22, C, new_size, new_size, new_size);
}
/* Multiplies the m x n matrix A by the a x b matrix B with Strassen's
 * algorithm and prints the m x b product using display().
 *
 * The recursion halves the matrix at every level. Both operands are therefore
 * first copied into s x s matrices, where s = nextpowerof2 of the largest
 * dimension, and padded with zeroes. When n == a, the zero padding does not
 * change the top-left m x b block of the product.
 *
 * Preconditions (n == a is checked by the caller in main): A has at least m
 * rows of at least n entries, and B has at least a rows of at least b entries.
 * Negative dimensions are rejected with a message. If every dimension is 0,
 * the product is empty and nothing is printed. */
void STRASSEN_algorithm(const vector<vector<int>> &A, const vector<vector<int>> &B, int m, int n, int a, int b)
{
    // Negative sizes cannot size a vector; they would also wrap around when
    // compared with unsigned loop indices.
    if (m < 0 || n < 0 || a < 0 || b < 0)
    {
        cout << "invalid matrix dimensions" << endl;
        return;
    }

    const int k = max({m, n, a, b});
    // Every dimension is zero, so there is nothing to multiply or print.
    // Returning here also keeps a meaningless 0 away from nextpowerof2.
    if (k == 0)
        return;

    const int s = nextpowerof2(k);
    vector<vector<int>> Aa(s, vector<int>(s)), Bb(s, vector<int>(s)), Cc(s, vector<int>(s));
    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            Aa[i][j] = A[i][j];
        }
    }
    for (int i = 0; i < a; i++)
    {
        for (int j = 0; j < b; j++)
        {
            Bb[i][j] = B[i][j];
        }
    }

    STRASSEN_algorithmA(Aa, Bb, Cc, s);

    // Keep only the m x b block that belongs to the real product.
    vector<vector<int>> C(m, vector<int>(b));
    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < b; j++)
        {
            C[i][j] = Cc[i][j];
        }
    }
    display(C, m, b);
}
/* Tells whether an (m x n) matrix can be multiplied by an (a x b) matrix,
 * i.e. whether the inner dimensions n and a agree. */
bool check(int n, int a)
{
    return a == n;
}
/* Reads a rows x cols matrix of ints from standard input, row by row, into
 * matrix (which is resized to rows x cols). Returns false as soon as a value
 * cannot be read (end of input or a non-integer token). */
static bool read_matrix(vector<vector<int>> &matrix, int rows, int cols)
{
    matrix.assign(rows, vector<int>(cols));
    for (vector<int> &row : matrix)
    {
        for (int &value : row)
        {
            if (!(cin >> value))
                return false;
        }
    }
    return true;
}

/* Interactive driver: reads the dimensions and entries of two matrices from
 * standard input and prints their product computed with Strassen's algorithm.
 * Returns 1 on malformed input (unreadable or negative dimensions, or missing
 * or non-integer entries), otherwise 0. */
int main()
{
    int m = 0, n = 0, a = 0, b = 0;
    cout << "Matrix Multiplication using Strassen algorithm" << endl;

    cout << "Enter rows and columns of first matrix" << endl;
    // A negative size would become a huge size_t when used to size a vector.
    if (!(cin >> m >> n) || m < 0 || n < 0)
    {
        cout << "invalid matrix dimensions" << endl;
        return 1;
    }
    cout << "enter values into first matrix" << endl;
    vector<vector<int>> A;
    if (!read_matrix(A, m, n))
    {
        cout << "invalid matrix values" << endl;
        return 1;
    }

    cout << "Enter rows and columns of second matrix" << endl;
    if (!(cin >> a >> b) || a < 0 || b < 0)
    {
        cout << "invalid matrix dimensions" << endl;
        return 1;
    }
    cout << "enter values into second matrix" << endl;
    vector<vector<int>> B;
    if (!read_matrix(B, a, b))
    {
        cout << "invalid matrix values" << endl;
        return 1;
    }

    // A (m x n) * B (a x b) is only defined when the inner dimensions agree.
    if (check(n, a))
    {
        STRASSEN_algorithm(A, B, m, n, a, b);
    }
    else
    {
        cout << "matrix multiplication not possible" << endl;
    }
    return 0;
}
Applications
Generally, Strassen’s method is not preferred in practical applications, for the following reasons.
-
The constants hidden in Strassen’s method are high, so for typical matrix sizes the naive method works better.
-
For Sparse matrices, there are better methods especially designed for them.
-
The submatrices in recursion take extra space.
-
Because of the limited precision of computer arithmetic on non-integer values, larger errors accumulate in Strassen’s algorithm than in the naive method.