Tensor Operations: Flatten and Squeeze


Table of Contents
I. What are Tensors?
II. Importance of Flatten and Squeeze
III. Flatten operation
IV. Squeeze operation
V. Conclusion

In this article at OpenGenus, we explore two fundamental operations that modify the structure of tensors in Deep Learning models: Flatten and Squeeze.

What are Tensors?

Tensors are multi-dimensional arrays that store the data used in Deep Learning models. They generalise vectors and matrices to higher dimensions: a scalar is a 0D tensor, a vector is a 1D tensor, and a matrix is a 2D tensor. Tensors can therefore represent a wide range of data, from scalars, vectors, and matrices to more complex data such as tables, images, video frames, audio, and other high-dimensional data.

In Deep Learning, data is converted to tensors, which processing units such as GPUs or TPUs can manipulate efficiently before feeding them to learning models. For example, an image can be represented as a 3D tensor whose dimensions correspond to the height, width, and colour channels of the image. Similarly, a paragraph of text can be represented as a 2D tensor whose dimensions correspond to the number of sentences in the paragraph and the length of each sentence (for example in tokens, padded to a common length).
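As a small illustration, assuming PyTorch (the library used in this article's code examples), the snippet below builds tensors of different ranks. The image follows PyTorch's channels-first (channels, height, width) convention, the text tensor uses made-up token IDs for two sentences padded with 0 to length 5, and all sizes are arbitrary.

```python
import torch

# Rank 0: a scalar
scalar = torch.tensor(3.14)

# Rank 1: a vector
vector = torch.tensor([1.0, 2.0, 3.0])

# Rank 3: an RGB image in PyTorch's (channels, height, width) layout;
# 224 x 224 is an arbitrary example size
image = torch.rand(3, 224, 224)

# Rank 2: two sentences as made-up token IDs, padded with 0 to length 5
sentences = torch.tensor([[12, 7, 45, 3, 0],
                          [8, 19, 2, 0, 0]])

for name, t in [("scalar", scalar), ("vector", vector),
                ("image", image), ("sentences", sentences)]:
    print(f"{name}: ndim={t.ndim}, shape={tuple(t.shape)}")
```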

Tensors provide flexible and efficient ways to represent and manipulate data. TensorFlow and PyTorch, the most widely used Python deep learning libraries, provide tensor operations such as multiplication, convolution, flatten, and squeeze.

Flatten and Squeeze, why are they important?

Flatten and Squeeze are two important operations used to manipulate tensors in Deep Learning. They come into play when we want to reshape tensors. Reshaping operations are extremely important because each layer in a neural network only accepts inputs with a specific number of dimensions.

For example, a 4D tensor of shape (batch_size, height, width, channels) cannot be fed directly into a fully connected layer that expects 2D input of shape (batch_size, features). We first need to reshape the tensor to (batch_size, height * width * channels), a 2D tensor that can be used as the input to the fully connected layer.
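A minimal sketch of this step in PyTorch, keeping the (batch_size, height, width, channels) layout from the text. The sizes (a batch of 8 images of 28 x 28 pixels with 3 channels) and the 10 output features are hypothetical. nn.Flatten flattens every dimension except the first (its default is start_dim=1), so the batch dimension is preserved.

```python
import torch
from torch import nn

batch_size, height, width, channels = 8, 28, 28, 3  # hypothetical sizes
x = torch.randn(batch_size, height, width, channels)

model = nn.Sequential(
    nn.Flatten(),                               # (8, 28, 28, 3) -> (8, 2352)
    nn.Linear(height * width * channels, 10),   # 10 output features (arbitrary)
)

flat = nn.Flatten()(x)
out = model(x)
print(flat.shape)  # torch.Size([8, 2352])
print(out.shape)   # torch.Size([8, 10])
```

PyTorch's own convolutional layers produce (batch_size, channels, height, width) tensors; nn.Flatten handles either layout the same way, and only the order of the flattened features differs.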

Similarly, we might need to convert 2D text data of shape (batch_size, sequence_length) into a 3D tensor of shape (batch_size, sequence_length, embedding_size) before it can be fed into a convolutional layer. Strictly speaking this is not a pure reshape: the extra dimension is usually produced by an embedding layer, which replaces each token ID with a vector of embedding_size values.
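A minimal sketch of that text pipeline, assuming PyTorch. The batch size, sequence length, vocabulary size, embedding size, and number of convolution filters are all hypothetical. nn.Embedding produces the (batch_size, sequence_length, embedding_size) tensor; because PyTorch's nn.Conv1d expects (batch_size, channels, length), the last two dimensions are swapped with permute before the convolution.

```python
import torch
from torch import nn

batch_size, sequence_length = 4, 12
vocab_size, embedding_size = 1000, 32   # hypothetical sizes

# 2D input: one row of token IDs per sequence
tokens = torch.randint(0, vocab_size, (batch_size, sequence_length))

embedding = nn.Embedding(vocab_size, embedding_size)
embedded = embedding(tokens)            # (4, 12, 32)

# Conv1d treats the embedding size as the channel dimension
conv = nn.Conv1d(in_channels=embedding_size, out_channels=16, kernel_size=3)
features = conv(embedded.permute(0, 2, 1))  # (4, 32, 12) -> (4, 16, 10)

print(embedded.shape)   # torch.Size([4, 12, 32])
print(features.shape)   # torch.Size([4, 16, 10])
```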

The Flatten Operation

The Flatten operation converts a multi-dimensional tensor into a one-dimensional tensor. It takes all the elements of the tensor and arranges them along a single dimension in row-major order (the last dimension varies fastest). All the elements of the original tensor are still present in the flattened tensor; only their arrangement changes.

  • When is it used?
    The Flatten operation is often used when we want to feed the data from a convolutional layer to a fully connected layer in the neural network.

Mathematically, the Flatten operation can be thought of as reshaping a tensor of shape (d1, d2, ..., dn) into a tensor of shape (d1 * d2 * ... * dn,).
For example, applying the Flatten operation to a tensor of shape (2, 3, 4) results in a tensor of shape (24,), where 24 is the product of the dimensions 2 * 3 * 4.
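The worked example can be checked directly. In the small sketch below, which assumes PyTorch, a (2, 3, 4) tensor is filled with the values 0 to 23 via torch.arange. This shows both the new size (2 * 3 * 4 = 24) and the order in which elements land in the flattened tensor.

```python
import math
import torch

x = torch.arange(24).reshape(2, 3, 4)
flat = torch.flatten(x)

print(flat.shape)            # torch.Size([24])
print(math.prod(x.shape))    # 24
print(flat[:6].tolist())     # [0, 1, 2, 3, 4, 5]

# Row-major order: element (i, j, k) ends up at index i*3*4 + j*4 + k
i, j, k = 1, 2, 3
print(flat[i * 3 * 4 + j * 4 + k].item() == x[i, j, k].item())  # True
```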

  • Code Implementation
  1. Using the view function:
import torch

x = torch.randn(2, 3, 4)
print(f'Original shape: {x.shape}')
x = x.view(-1, 4) # -1 is inferred as 24 / 4 = 6, i.e. (2*3, 4)
print(f'Flattened shape: {x.shape}')

This method creates a tensor x of shape (2, 3, 4) and then uses the view method with the arguments (-1, 4). The -1 tells PyTorch to infer the size of that dimension from the total number of elements in the input tensor (24 / 4 = 6). This gives us a tensor of shape (6, 4), whose first dimension is the product of the first two dimensions of the original tensor. Calling view(-1) on the original tensor would flatten it completely to shape (24,).
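One caveat about view: it never copies data, so it only works when the requested shape is compatible with the tensor's memory layout. The short sketch below, assuming PyTorch, uses a transposed tensor, which is not contiguous. Calling view(-1) on it raises a RuntimeError, while reshape(-1), or calling contiguous() first, succeeds.

```python
import torch

x = torch.randn(2, 3, 4)
t = x.transpose(0, 2)        # shape (4, 3, 2), shares memory with x
print(t.is_contiguous())     # False

try:
    t.view(-1)
    view_failed = False
except RuntimeError:
    view_failed = True       # view cannot express this layout without copying
print(view_failed)           # True

flat = t.reshape(-1)                   # copies only if needed
flat2 = t.contiguous().view(-1)        # explicit copy, then view
print(flat.shape, torch.equal(flat, flat2))  # torch.Size([24]) True
```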

  2. Using the flatten function:
import torch

x = torch.randn(2, 3, 4)
print(f'Original shape: {x.shape}')
x = x.flatten()
print(f'Flattened shape: {x.shape}')

This method creates a tensor x of shape (2, 3, 4) and then uses the flatten method to reshape it into a 1D tensor. This gives us a tensor of shape (24,), where 24 is the product of the dimensions of the original tensor.
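torch.flatten also accepts start_dim and end_dim arguments that restrict flattening to a range of dimensions. This is how the batch dimension is usually preserved. The short sketch below, assuming PyTorch, reuses the (2, 3, 4) shape from above.

```python
import torch

x = torch.randn(2, 3, 4)

full = torch.flatten(x)                                      # all dims -> (24,)
keep_first = torch.flatten(x, start_dim=1)                   # dims 1..2 -> (2, 12)
merge_first_two = torch.flatten(x, start_dim=0, end_dim=1)   # dims 0..1 -> (6, 4)

print(full.shape)             # torch.Size([24])
print(keep_first.shape)       # torch.Size([2, 12])
print(merge_first_two.shape)  # torch.Size([6, 4])
```

The last call produces the same (6, 4) shape as the view example.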

The Squeeze Operation

The Squeeze operation, on the other hand, is used to remove dimensions of size 1 from a tensor. This operation can be useful when you have a tensor with unnecessary dimensions that you want to get rid of.

  • When is it used?
    The Squeeze operation is used when we want to remove dimensions of size 1 from a tensor, either all of them or one specific dimension.

Mathematically, the Squeeze operation can be thought of as removing all dimensions i from a tensor of shape (d1, d2, ..., dn) where di = 1. For example, if we have a tensor x of shape (1, 3, 1, 5), applying the Squeeze operation would result in a tensor of shape (3, 5).
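The rule above amounts to filtering the shape. The small sketch below writes that filter out and compares it with PyTorch's result. squeezed_shape is a hypothetical helper written for illustration, not a library function.

```python
import torch

def squeezed_shape(shape):
    """Hypothetical helper: keep every dimension d_i with d_i != 1."""
    return tuple(d for d in shape if d != 1)

x = torch.randn(1, 3, 1, 5)
print(squeezed_shape(x.shape))    # (3, 5)
print(tuple(x.squeeze().shape))   # (3, 5)
```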

  • Code Implementation
  1. Using the squeeze function:
import torch

x = torch.randn(1, 3, 1, 5)
print(f'Original shape: {x.shape}')
x = x.squeeze()
print(f'Squeezed shape: {x.shape}')

This method creates a tensor x of shape (1, 3, 1, 5) and then uses the squeeze method to remove the dimensions of size 1. This gives us a tensor of the shape (3, 5).
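squeeze also accepts a dim argument that removes only that dimension, and it leaves the tensor unchanged if that dimension's size is not 1. This matters in practice because squeeze() with no argument also removes a batch dimension that happens to have size 1. The sketch below assumes PyTorch. The (1, 10, 1) tensor, standing in for a batch of one example with 10 values and one channel, is an arbitrary example.

```python
import torch

x = torch.randn(1, 3, 1, 5)
print(x.squeeze(0).shape)   # torch.Size([3, 1, 5])     only dim 0 removed
print(x.squeeze(2).shape)   # torch.Size([1, 3, 5])     only dim 2 removed
print(x.squeeze(1).shape)   # torch.Size([1, 3, 1, 5])  dim 1 has size 3, unchanged

# Pitfall: a batch containing a single example
out = torch.randn(1, 10, 1)
print(out.squeeze().shape)    # torch.Size([10])     batch dimension lost too
print(out.squeeze(-1).shape)  # torch.Size([1, 10])  only the trailing 1 removed
```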

  • The inverse operation, unsqueeze, inserts a dimension of size 1 at a given position, so a squeezed tensor can be restored to its original shape.
  2. Using the unsqueeze function:
import torch

x = torch.randn(3, 5) # squeezed tensor from previous example
x = x.unsqueeze(0) # ~(1, 3, 5)
x = x.unsqueeze(2) # ~(1, 3, 1, 5)
print(f'Unsqueezed shape: {x.shape}')

This method takes the tensor x of shape (3, 5) and uses the unsqueeze method to add a dimension of size 1 at position 0 and then at position 2. This gives us back a tensor of shape (1, 3, 1, 5).
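The sketch below, assuming PyTorch, checks that squeeze and unsqueeze are inverses here. It squeezes the (1, 3, 1, 5) tensor, restores it, and compares the values. Both methods return views that share storage with their input, so no data is copied. Indexing with None is an equivalent shorthand for inserting a size-1 dimension.

```python
import torch

original = torch.randn(1, 3, 1, 5)
squeezed = original.squeeze()                  # (3, 5)
restored = squeezed.unsqueeze(0).unsqueeze(2)  # (1, 3, 1, 5)

print(restored.shape)                    # torch.Size([1, 3, 1, 5])
print(torch.equal(restored, original))   # True

# Views share memory with the original tensor
print(restored.data_ptr() == original.data_ptr())  # True

# Indexing with None inserts a size-1 dimension, like unsqueeze
print(squeezed[None].shape)              # torch.Size([1, 3, 5])
print(squeezed[:, None].shape)           # torch.Size([3, 1, 5])
```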

Conclusion

Flatten and Squeeze are essential operations for building and training Deep Learning models. They let us manipulate the shape of tensors so that they can be fed into the different layers of a model. See the PyTorch documentation for torch.flatten and torch.squeeze for details of the two methods.
