Calculate mean and std of Image Dataset
In this article, we have explained how to calculate the mean and standard deviation (std) of an image dataset. These values can then be used to normalize the images in the dataset for effective training of Neural Networks.
The challenge is to compute the mean and std in batches, because loading the entire image dataset into memory at once has significant memory overhead. We present the Python code using PyTorch.
Table of contents:
- Why we need mean and std of Image Dataset?
- Calculate mean and std of Image Dataset
- Code to calculate mean and std of dataset in Pytorch/ Python
- Concluding Note
The following table summarizes the idea of calculating the mean and std of an image dataset:
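Point | Mean and STD of Image Dataset |
---|---|
Why needed? | The image dataset needs to be normalized to improve the training accuracy of Neural Networks. |
Formula | $\mu_c = \frac{1}{N}\sum_i x_{i,c}$ and $\sigma_c = \sqrt{\frac{1}{N}\sum_i x_{i,c}^2 - \mu_c^2}$ for each channel $c$, where $N$ = number of images × H × W |
Challenge | Computing the mean and std batch by batch, since the entire dataset cannot be loaded into memory at once. |
Any accuracy loss in batch computation? | No. The mean and std values are exact even when computed in batches (up to floating-point rounding). |
Solution | Accumulate the per-channel sum and sum of squares over all batches, then compute the mean and std from these totals. |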
Why we need mean and std of Image Dataset?
When working with an image dataset, we need to normalize its images before training a Neural Network on it. This matters for two core reasons:
- It helps the trained Neural Network give consistent results for new test images.
- It helps in Transfer Learning, since a pretrained network expects its inputs to be normalized the same way as the data it was trained on.
Image datasets such as ImageNet and the EDD dataset are used to train Neural Networks such as ResNet50, GoogLeNet, MobileNetV1, RefineDet and many more.
To normalize an image, we need the mean and standard deviation of the entire dataset. We will see how to calculate the mean and standard deviation (std) in the next section of this article at OpenGenus.
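Concretely, normalizing an image subtracts the per-channel mean and divides by the per-channel std, $x' = (x - \mu_c)/\sigma_c$. The following is a minimal PyTorch sketch, assuming an image tensor of shape (C, H, W) whose values are already scaled to [0, 1]. The helper `normalize_image` is a hypothetical name, and for the demonstration the image is normalized with its own statistics rather than dataset-wide ones:

```python
import torch

def normalize_image(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Normalize a (C, H, W) image channel by channel (hypothetical helper).

    mean and std hold one value per channel; reshaping them to (C, 1, 1)
    lets broadcasting apply each value to every pixel of its channel.
    """
    return (image - mean.view(-1, 1, 1)) / std.view(-1, 1, 1)

# Demonstration: a random RGB "image" normalized with its own channel statistics
torch.manual_seed(0)
image = torch.rand(3, 32, 32)
mean = image.mean(dim=(1, 2))
std = image.std(dim=(1, 2), unbiased=False)
normalized = normalize_image(image, mean, std)

print(normalized.mean(dim=(1, 2)))                 # close to 0 for every channel
print(normalized.std(dim=(1, 2), unbiased=False))  # close to 1 for every channel
```

With dataset-wide mean and std, as in the rest of this article, each individual normalized image is only approximately zero-mean and unit-std; the dataset as a whole is exactly so.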
Calculate mean and std of Image Dataset
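Following is the formula to calculate the mean (average) and standard deviation (std) of each channel $c$:
$$\mu_c = \frac{1}{N}\sum_{i=1}^{N} x_{i,c}, \qquad \sigma_c = \sqrt{\frac{1}{N}\sum_{i=1}^{N} x_{i,c}^2 - \mu_c^2}$$
where $x_{i,c}$ is the $i$-th pixel value in channel $c$ and $N$ is the number of pixel values per channel in the whole dataset, that is, the number of images multiplied by the image height and width.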
A batch of images is a 4-dimensional tensor of shape (B, C, H, W) where:
- B is the batch size, that is, the number of images in the batch.
- C is the number of channels in each image, which is 3 for RGB images.
- H is the height of each image in pixels.
- W is the width of each image in pixels.
The mean and std are each 1-dimensional, with C values: one per channel. For RGB images, the mean is a set of 3 values, and so is the std.
So, the mean and std are calculated separately for each channel, over all images and all pixel positions.
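This shape bookkeeping can be checked directly in PyTorch. In the sketch below, a random batch stands in for real images; reducing over dimensions 0, 2 and 3 (batch, height, width) leaves one value per channel:

```python
import torch

# Random stand-in for a batch of B=4 RGB images of size 8x8: shape (B, C, H, W)
torch.manual_seed(0)
batch = torch.rand(4, 3, 8, 8)

# Reduce over batch (0), height (2) and width (3); keep the channel dimension (1)
channel_mean = batch.mean(dim=(0, 2, 3))
channel_std = batch.std(dim=(0, 2, 3), unbiased=False)

print(channel_mean.shape)  # torch.Size([3])
print(channel_std.shape)   # torch.Size([3])

# Same result as flattening each channel into one vector of B*H*W values
per_channel = batch.permute(1, 0, 2, 3).reshape(3, -1)
print(torch.allclose(channel_mean, per_channel.mean(dim=1)))  # True
```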
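The challenge is that we cannot load the entire dataset into memory to calculate these parameters. Instead, we load small batches of images one at a time. This makes computing the std non-trivial: each batch's std is measured around that batch's own mean, so the per-batch std values cannot simply be averaged.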
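To see why, consider two small hypothetical batches of single-channel pixel values. Averaging their stds ignores how far apart the two batch means are, so it underestimates the spread of the dataset as a whole:

```python
from statistics import mean, pstdev

# Two hypothetical batches of single-channel pixel values
batch_a = [0.0, 0.2]
batch_b = [0.8, 1.0]

# Wrong: averaging the per-batch (population) stds
naive_std = mean([pstdev(batch_a), pstdev(batch_b)])

# Right: the population std of all pixel values taken together
true_std = pstdev(batch_a + batch_b)

print(round(naive_std, 4))  # 0.1
print(round(true_std, 4))   # 0.4123
```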
The steps are as follows:
- Define two 3-element variables, total_sum and total_sum_square.
- Initialize total_sum = (0, 0, 0) and total_sum_square = (0, 0, 0).
- count = number of images in the dataset × H × W, which is the number of pixel values per channel.
- Load a batch of B images from the dataset.
- Add the per-channel sum of all pixel values in the batch to total_sum.
- Add the per-channel sum of the squared pixel values in the batch to total_sum_square.
- Repeat the previous three steps until all images in the dataset have been processed.
- mean = total_sum / count
- std = sqrt(total_sum_square / count - mean * mean)
With this, the mean and std are calculated in a single pass over the dataset, one batch at a time.
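The batched procedure can be sketched with NumPy. This sketch assumes the dataset arrives as an iterable of float arrays of shape (B, C, H, W); a random array split into uneven batches stands in for real images, and `batched_mean_std` is a hypothetical name. Note that it divides by the number of pixel values per channel (images × H × W), and it checks the batched result against computing the mean and std on the full array at once:

```python
import numpy as np

def batched_mean_std(batches, num_channels=3):
    """Per-channel mean and std from an iterable of (B, C, H, W) batches."""
    total_sum = np.zeros(num_channels, dtype=np.float64)
    total_sum_square = np.zeros(num_channels, dtype=np.float64)
    count = 0  # number of pixel values seen per channel
    for batch in batches:
        # Sum over batch, height and width; keep the channel axis
        total_sum += batch.sum(axis=(0, 2, 3))
        total_sum_square += (batch ** 2).sum(axis=(0, 2, 3))
        b, _, h, w = batch.shape
        count += b * h * w
    mean = total_sum / count
    var = total_sum_square / count - mean ** 2
    return mean, np.sqrt(var)

# Random stand-in for a dataset of 10 RGB images of size 16x16
rng = np.random.default_rng(0)
dataset = rng.random((10, 3, 16, 16))

# Process it in uneven batches of 4, 4 and 2 images
batches = [dataset[0:4], dataset[4:8], dataset[8:10]]
mean, std = batched_mean_std(batches)

# Compare with the statistics of the whole dataset computed in one go
print(np.allclose(mean, dataset.mean(axis=(0, 2, 3))))  # True
print(np.allclose(std, dataset.std(axis=(0, 2, 3))))    # True
```

Counting pixels from each batch's shape keeps the count consistent with the data actually seen, without relying on a separately known dataset length or image size. Accumulating in float64 limits rounding error in the sum of squares, to which the "mean of squares minus square of mean" formula is sensitive.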
Code to calculate mean and std of dataset in Pytorch/ Python
In this section, we present Python code using PyTorch to calculate the mean and standard deviation (std) of a loaded dataset.
In the following Python code snippet, we create 2 tensors, total_sum and total_sum_square, then loop through all batches of images and add up the required values. Note that we sum over dimensions [0, 2, 3] (batch, height and width) and skip dimension 1 because it is the channel dimension, so each sum holds one value per channel.
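import torch
from tqdm import tqdm

# image_loader: DataLoader over the image dataset yielding float batches (B, C, H, W)
total_sum = torch.tensor([0.0, 0.0, 0.0])
total_sum_square = torch.tensor([0.0, 0.0, 0.0])
for inputs in tqdm(image_loader):
    # sum over batch, height and width (dims 0, 2, 3); keep channel dim 1
    total_sum += inputs.sum(dim = [0, 2, 3])
    total_sum_square += (inputs ** 2).sum(dim = [0, 2, 3])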
In the following Python code snippet, we use the 2 tensors to calculate the mean and standard deviation (std) of the image dataset:
# len(df) = number of images in the dataset (df);
# every image is assumed to be resized to image_size x image_size
count = len(df) * image_size * image_size
# per-channel mean and std
total_mean = total_sum / count
total_var = (total_sum_square / count) - (total_mean ** 2)
total_std = torch.sqrt(total_var)
# output
print('mean: ' + str(total_mean))
print('std: ' + str(total_std))
Output:
mean: tensor([0.4417, 0.5110, 0.3178])
std: tensor([0.2330, 0.2358, 0.2247])
Once the mean and std are calculated, we can normalize the entire dataset using the Normalize() transform from Albumentations. The following Python code snippet demonstrates the process:
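import os
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2

# "/path", data_path and image_size are placeholders defined elsewhere in the project
df = pd.read_csv(os.path.join("/path", "train.csv"))
# A.Normalize multiplies mean and std by max_pixel_value (255 by default),
# so values computed on [0, 1]-scaled pixels can be passed directly
augmentations = A.Compose([A.Resize(height = image_size,
                                    width = image_size),
                           A.Normalize(mean = total_mean.tolist(),
                                       std = total_std.tolist()),
                           ToTensorV2()])
normalized_image_dataset = EDDdata(data = df,
                                   directory = data_path + 'train_images/',
                                   transform = augmentations)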
The EDDdata class is defined as follows:
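import os
import cv2
from torch.utils.data import Dataset

class EDDdata(Dataset):
    def __init__(self,
                 data,
                 directory,
                 transform = None):
        self.data = data            # DataFrame with an 'image_id' column
        self.directory = directory  # folder containing the image files
        self.transform = transform  # optional Albumentations pipeline

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # import
        path = os.path.join(self.directory, self.data.iloc[idx]['image_id'])
        # cv2.imread loads images in BGR order; convert to RGB explicitly
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # augmentations
        if self.transform is not None:
            image = self.transform(image = image)['image']
        return image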
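One detail worth checking in the normalization step is the scale of the values. The mean and std printed earlier lie between 0 and 1, which suggests they were computed on pixels scaled to [0, 1]. With its default `max_pixel_value=255.0`, Albumentations' `Normalize` multiplies the mean and std by 255 before applying them to the raw 0-255 image, computing $(x - 255\mu)/(255\sigma)$, which equals $(x/255 - \mu)/\sigma$. Below is a NumPy sketch of that arithmetic (not Albumentations' own code); `normalize_like_albumentations` is a hypothetical name and a random 8-bit image stands in for a real one:

```python
import numpy as np

def normalize_like_albumentations(image, mean, std, max_pixel_value=255.0):
    """Per-channel (x - mean * max) / (std * max) on an (H, W, C) uint8 image."""
    mean = np.asarray(mean, dtype=np.float32) * max_pixel_value
    std = np.asarray(std, dtype=np.float32) * max_pixel_value
    return (image.astype(np.float32) - mean) / std

# Mean and std values printed earlier in this article
mean = [0.4417, 0.5110, 0.3178]
std = [0.2330, 0.2358, 0.2247]

# Random 8-bit RGB image in (H, W, C) layout, as cv2 loads it
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)

out = normalize_like_albumentations(image, mean, std)

# Equivalent to scaling to [0, 1] first and then normalizing with the raw values
scaled = image.astype(np.float32) / 255.0
print(np.allclose(out, (scaled - np.array(mean)) / np.array(std), atol=1e-5))  # True
```

If the mean and std had instead been computed on raw 0-255 pixel values, they would need to be divided by 255 first, or passed with `max_pixel_value=1.0`.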
This completes the entire process: the Python code above uses PyTorch to calculate the mean and std of the entire image dataset and then uses these two values to normalize the entire image dataset.
Concluding Note
With this article at OpenGenus, you should now have a complete idea of how to calculate the mean (average) and standard deviation (std) of an image dataset and how to use them to normalize images before training a Neural Network.
Loading an entire dataset into memory is often impractical, but many statistics, such as the mean and std, can be computed by processing small batches of images one at a time until the entire dataset is covered.