as_strided op in PyTorch
So you've finally begun building your first network in PyTorch. You're working with tensors from your dataset and want to change their shape or perform some operations on them. The as_strided() function in PyTorch can be very useful for this.
Table of Contents:
- What is as_strided()?
- CNN operations with as_strided()
- Cautions with as_strided()
- Test your knowledge
- Key takeaways
The PyTorch as_strided function
Key Questions:
- What does as_strided do?
- How do you determine the stride of as_strided?
- Break down an example of convolution using as_strided.
- Why is it dangerous to use as_strided?
What does as_strided do?
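as_strided creates a view of an existing tensor with a shape, strides and starting offset that you specify. It reads directly from the tensor's underlying storage, so no data is copied and the result shares memory with the input.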
The syntax is as follows:
torch.as_strided(input, size, stride, storage_offset=None)
Where:
- input = the input tensor
- size = the shape of the output tensor you wish to make (specified by a tuple of ints / int)
- stride = for each output dimension, how many storage elements to step over when that dimension's index increases by 1 (specified by a tuple of ints / int)
- storage_offset = the position in the underlying storage where the output starts (optional; if omitted, the input tensor's own storage offset is used)
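To see how the three parameters interact, here is a minimal sketch on a 1D tensor; the tensor x and the chosen values are purely illustrative. Element k of the output is read from storage position storage_offset + k * stride.

```python
import torch

x = torch.arange(10)  # storage: [0, 1, 2, ..., 9]

# size=(3,): three output elements
# stride=(2,): step two storage positions between consecutive elements
# storage_offset=1: start reading at storage position 1
y = torch.as_strided(x, (3,), (2,), 1)
print(y)  # tensor([1, 3, 5])

# The result is a view: writing to x is visible through y
x[3] = 100
print(y.tolist())  # [1, 100, 5]
```

Because y shares memory with x, changing x[3] also changes y[1]; nothing was copied.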
Basic Example
Say I have the 2x3 matrix A
$
A = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}
$
I wish to create a 2x2 matrix B which is a submatrix of A
$
B = \begin{bmatrix} 1 & 2 \\ 4 & 5 \end{bmatrix}
$
To do this, all we need to do is move along a row and then down to the next row, i.e.:
- Go to element 1, go to the next row to retrieve 4
- Go to element 2, go to the next row to retrieve 5
as_strided works on the tensor's underlying storage, which for a contiguous tensor like A is one flat list in row-major order:
# as_strided sees the storage of A as:
A = [1, 2, 3, 4, 5, 6]
This means that to move down one row, we need a stride equal to the length of one row of the matrix. In this case, it is 3.
Implementing this with as_strided then goes as follows:
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
# Rows of B are 3 storage elements apart; columns are 1 apart
B = torch.as_strided(A, (2, 2), (3, 1))
Breaking down the stride tuple:
You may be wondering how the stride tuple works in this function call. Here's the breakdown:
- Think of as_strided as reading elements from the flat storage, starting at storage_offset (here 0, which holds the element 1).
- The stride tuple has one entry per output dimension. Entry k says how many storage elements to skip when the index along dimension k increases by 1.
- The first element, 3, is the stride along the rows of B: moving from B[0][j] to B[1][j] skips one full row of A, e.g. from 1 to 4.
- The second element, 1, is the stride along the columns of B: moving from B[i][0] to B[i][1] moves to the next element, e.g. from 1 to 2.
- Put together, B[i][j] is read from storage position storage_offset + 3i + j.
Later on, we’ll see more complicated examples of this, but now that you get the gist, we can get into more practical examples.
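Under the hood, element [i, j] of a 2D as_strided view is read from storage position storage_offset + i * stride[0] + j * stride[1]. The sketch below checks this rule on the basic example. The helper manual_strided is hypothetical, written only for illustration, and it assumes the input is contiguous so that flatten() matches the storage order.

```python
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(A.stride())  # (3, 1): next row is 3 elements away, next column is 1

def manual_strided(t, size, stride, offset=0):
    """Hypothetical helper: rebuild a 2D as_strided view by hand."""
    flat = t.flatten()  # row-major storage order of a contiguous tensor
    rows, cols = size
    out = torch.empty(size, dtype=t.dtype)
    for i in range(rows):
        for j in range(cols):
            out[i, j] = flat[offset + i * stride[0] + j * stride[1]]
    return out

B = torch.as_strided(A, (2, 2), (3, 1))
print(B.tolist())  # [[1, 2], [4, 5]]
assert torch.equal(B, manual_strided(A, (2, 2), (3, 1)))

# Swapping the strides walks down the columns first and yields the transpose
print(torch.as_strided(A, (2, 2), (1, 3)).tolist())  # [[1, 4], [2, 5]]
```

The last line is a useful sanity check: the order of the strides matters, and mixing them up silently gives a transposed result rather than an error.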
A more practical example of as_strided:
Say, for example, we wish to compute the trace of a 3x3 matrix, i.e.
$
X = \begin{bmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9
\end{bmatrix}
$
Our tensor representation of this would be:
X = torch.tensor([[1, 2, 3], [4,5,6], [7,8,9]])
print(X)
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
The trace of a square matrix is the sum of the elements on its main diagonal.
Mathematically, this can be represented as:
$
\sum_{i = 1}^3 a_{ii}= a_{11} + a_{22} + a_{33}
$
This can be implemented using a simple for loop over the flattened matrix:
x_size = len(X[0])  # length of one row
trace = 0
for x in range(x_size):
    idx = x * x_size + x  # flat index of the diagonal element X[x][x]
    trace += X.flatten()[idx]
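To convert this into as_strided form, we can differentiate the index expression with respect to the loop variable: the derivative is how far the index moves per iteration, which is exactly the stride. If our for-loop index expression is:
$
\text{index} = x \cdot |x| + x
$
where |x| denotes the row length (x_size in the code), then differentiating with respect to x gives us:
$
\frac{\partial\, \text{index}}{\partial x} = |x| + 1
$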
We can then substitute this stride into our as_strided implementation:
n = X.clone()
x_size = len(X[0])
# x_size elements, each (x_size + 1) storage positions apart: the main diagonal
trace = torch.as_strided(n, (x_size,), (x_size + 1,)).sum()
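A quick sanity check is to compare the strided diagonal with PyTorch's built-in torch.diagonal and torch.trace. This sketch assumes a square, contiguous, row-major matrix; for other layouts the storage positions differ and the stride would need adjusting.

```python
import torch

X = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
n = X.shape[0]  # row length of a square, contiguous matrix

# n elements, each n + 1 storage positions apart: storage indices 0, 4, 8
diag = torch.as_strided(X, (n,), (n + 1,))
print(diag.tolist())  # [1.0, 5.0, 9.0]

# Same elements as the built-in diagonal view, same sum as torch.trace
assert torch.equal(diag, torch.diagonal(X))
assert diag.sum() == torch.trace(X)
```

In practice torch.diagonal gives the same view without any manual stride arithmetic; the strided version is mainly useful for understanding how that view is built.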
Outer Product
The outer product can be described as:
$
\textbf{u} = \begin{pmatrix} u_1 \\ u_2 \\ \vdots \\ u_m \end{pmatrix}, \quad \textbf{v} = \begin{pmatrix} v_1 \\ v_2 \\ \vdots \\ v_n \end{pmatrix}
$
$
\textbf{u} \otimes \textbf{v} = \begin{bmatrix} u_1v_1 & u_1v_2 & \cdots & u_1v_n \\ u_2v_1 & u_2v_2 & \cdots & u_2v_n \\ \vdots & \vdots & \ddots & \vdots \\ u_mv_1 & u_mv_2 & \cdots & u_mv_n \end{bmatrix}
$
Essentially, row i of the result is the vector v scaled by the single element u_i.
To implement this, we can once again use a for loop, i.e.
u = torch.randn(3)
v = torch.randn(3)
outer = torch.empty(3, 3)
for i in range(3):
    for j in range(3):
        outer[i, j] = u[i] * v[j]
Our index equations are then:
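$
\text{idx}_u = i + 0 \cdot j, \quad \text{idx}_v = 0 \cdot i + j
$
$
\frac{\partial\, \text{idx}_u}{\partial i} = 1, \quad \frac{\partial\, \text{idx}_u}{\partial j} = 0, \quad \frac{\partial\, \text{idx}_v}{\partial i} = 0, \quad \frac{\partial\, \text{idx}_v}{\partial j} = 1
$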
Implementing this with as_strided, the derivatives become the strides: (1, 0) for u and (0, 1) for v:
outer = torch.as_strided(u, (3, 3), (1, 0)) * torch.as_strided(v, (3, 3), (0, 1))
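A stride of 0 simply re-reads the same element, which is how each vector gets repeated across the matrix without copying. The sketch below compares the strided version with torch.outer; the sizes m and n are arbitrary illustrative choices, deliberately unequal so a mix-up between the two strides would show up.

```python
import torch

m, n = 3, 4
u = torch.randn(m)
v = torch.randn(n)

# u_rep[i, j] = u[i]: step 1 along rows, 0 along columns
u_rep = torch.as_strided(u, (m, n), (1, 0))
# v_rep[i, j] = v[j]: step 0 along rows, 1 along columns
v_rep = torch.as_strided(v, (m, n), (0, 1))

outer = u_rep * v_rep
assert torch.allclose(outer, torch.outer(u, v))
```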
Matrix Multiplication
A matrix multiplication is then a 3D outer product followed by a summation over the shared dimension, since each output element is C[i, j] = sum over k of a[i, k] * b[k, j]:
a = torch.randn(16, 16)
b = torch.randn(16, 16)
mm_0 = torch.mm(a, b)
mm_1 = (torch.as_strided(a, (16, 16, 16), (16, 1, 0)) *
torch.as_strided(b, (16, 16, 16), (0, 16, 1))).sum(1)
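The same recipe extends to non-square shapes. For an (M, K) matrix a and a (K, N) matrix b, both viewed as (M, K, N): a's element [i, k] sits at storage position i*K + k, so its strides are (K, 1, 0); b's element [k, j] sits at k*N + j, giving (0, N, 1). Summing over the middle axis contracts k. This is a sketch assuming contiguous inputs, and strided_matmul is a hypothetical name; it materializes an M x K x N intermediate, so it is far less memory-efficient than torch.mm.

```python
import torch

def strided_matmul(a, b):
    """Hypothetical helper: matrix product via as_strided (contiguous inputs assumed)."""
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    a = a.contiguous()
    b = b.contiguous()
    # a3[i, k, j] = a[i, k]; a3 ignores j (stride 0)
    a3 = torch.as_strided(a, (M, K, N), (K, 1, 0))
    # b3[i, k, j] = b[k, j]; b3 ignores i (stride 0)
    b3 = torch.as_strided(b, (M, K, N), (0, N, 1))
    # Summing over k gives sum_k a[i, k] * b[k, j]
    return (a3 * b3).sum(dim=1)

a = torch.randn(4, 5)
b = torch.randn(5, 3)
assert torch.allclose(strided_matmul(a, b), torch.mm(a, b), atol=1e-5)
```

With M = K = N = 16 the strides reduce to (16, 1, 0) and (0, 16, 1), the values used in the snippet above.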
Convolution
Ah, convolution, one of the core building blocks of neural networks for deep learning. This, too, can be implemented using as_strided().
As a quick refresher, convolution slides a small window, called the kernel, over the input and aggregates the values under it at each position, producing features for the network to learn from.
An example of this can be seen below:
Let’s try to implement this!
So in the animation above, we have a 5x5 matrix, and we wish to perform convolution using a 3x3 kernel with stride = 1 (i.e. the kernel moves 1 pixel at a time).
A = torch.randn(5, 5)
# View A as a 3x3 grid of 3x3 windows, then sum each window
conv = torch.as_strided(A, (3, 3, 3, 3), (5, 1, 5, 1)).sum(dim=(2, 3))
Breakdown:
- The output shape (3, 3, 3, 3) represents:
  - The first (3, 3): a 3x3 grid of kernel positions, i.e. the 3x3 output of the convolution.
  - The second (3, 3): the 3x3 window of A covered by the kernel at each of those positions.
- The strides (5, 1, 5, 1) represent:
  - The first 5: moving the kernel down by 1 row, which skips one full row of A (5 elements).
  - The second 1: moving the kernel right by 1 column.
  - The third 5: moving down one row within the kernel window.
  - The fourth 1: moving right one column within the kernel window.
- Together, element [i, j, k, l] of the view is A[i + k][j + l], so summing over the last two axes adds up each 3x3 window.
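Summing each window is the same as convolving with a kernel of all ones. To use an arbitrary kernel, multiply each window by the kernel before summing. The sketch below assumes a single-channel 2D input with no padding and stride 1, and checks the result against torch.nn.functional.conv2d, which (like most deep learning libraries) actually computes cross-correlation, i.e. it does not flip the kernel. The helper name strided_conv2d is hypothetical.

```python
import torch
import torch.nn.functional as F

def strided_conv2d(x, kernel):
    """Hypothetical helper: 2D valid cross-correlation, stride 1, via as_strided."""
    x = x.contiguous()
    H, W = x.shape
    kh, kw = kernel.shape
    out_h, out_w = H - kh + 1, W - kw + 1
    # windows[i, j, p, q] = x[i + p, j + q]
    windows = torch.as_strided(x, (out_h, out_w, kh, kw), (W, 1, W, 1))
    # Weight every window by the kernel and sum over the kernel axes
    return (windows * kernel).sum(dim=(2, 3))

A = torch.randn(5, 5)
k = torch.randn(3, 3)

ours = strided_conv2d(A, k)
# conv2d expects (batch, channels, H, W) input and (out_ch, in_ch, kh, kw) weight
ref = F.conv2d(A[None, None], k[None, None])[0, 0]
assert ours.shape == (3, 3)
assert torch.allclose(ours, ref, atol=1e-5)

# An all-ones kernel reproduces the plain window sums
ones_sum = torch.as_strided(A, (3, 3, 3, 3), (5, 1, 5, 1)).sum(dim=(2, 3))
assert torch.allclose(strided_conv2d(A, torch.ones(3, 3)), ones_sum)
```

For a 5x5 input, W is 5, so the strides become (5, 1, 5, 1), matching the breakdown above.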
The cautionary tale of as_strided
If we go back to the documentation for this function, we find this warning:
Prefer using other view functions, like torch.Tensor.expand(), to setting a view’s strides manually with as_strided, as this function’s behavior depends on the implementation of a tensor’s storage. The constructed view of the storage must only refer to elements within the storage or a runtime error will be thrown, and if the view is “overlapped” (with multiple indices referring to the same element in memory) its behavior is undefined.
This function can have several issues:
- It can read elements outside the tensor you pass in, because it indexes the whole underlying storage (for example, when the input is a slice of a larger tensor).
- It can fail to compile when used in combination with other view operations.
- As the warning cautions, a view that reaches past the end of the storage raises a runtime error, and a view whose indices overlap in memory has undefined behavior.
Therefore, consider other view functions such as expand() unless you need a very specific access pattern.
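The sketch below illustrates two of these pitfalls, plus a safer alternative; the tensors are illustrative. First, because storage_offset defaults to the input's own offset and indexing runs over the whole storage, a view of a slice can read past the slice's end without any error. Second, a view that would run past the end of the storage raises a RuntimeError. Finally, the repeated-vector trick from the outer product can be written with expand(), which sets the stride-0 dimensions for you.

```python
import torch

base = torch.arange(10)
s = base[2:5]  # elements [2, 3, 4], stored at storage offset 2
print(s.storage_offset())  # 2

# Asking for 5 elements silently reads base[2:7], beyond the end of s
peek = torch.as_strided(s, (5,), (1,))
print(peek.tolist())  # [2, 3, 4, 5, 6]

# A view that extends past the end of the storage is rejected
try:
    torch.as_strided(base, (11,), (1,))
except RuntimeError as err:
    print("RuntimeError:", err)

# Safer: expand() builds the same repeated views used for the outer product
u = torch.randn(3)
v = torch.randn(3)
outer = u.unsqueeze(1).expand(3, 3) * v.expand(3, 3)
assert torch.allclose(outer, torch.outer(u, v))
```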
Test your Knowledge
As usual, here's a mini quiz to check up on your knowledge!
Question 1:
Which of the following best describes the functionality of as_strided?
Question 2
Consider the following matrices:
A = torch.tensor([[1, 2, 3],
[4, 5, 6]])
B = torch.tensor([[2, 5],
[3, 6]])
Which call of as_strided() will successfully create matrix B?
Main Takeaways
- as_strided is a function used to create views of a tensor's underlying storage
- We can modify the parameters of the function call to get desired outputs (e.g. changing the output size or the strides)
- Using these concepts, we are able to implement various techniques such as outer product, matrix multiplication and even convolution.