Concept Whitening in Machine Learning


Neural networks are one of the most integral parts of machine learning. They can be defined as networks of artificial neurons, or nodes. They are capable of performing some of the hardest tasks, but their huge number of parameters makes them extremely hard to understand.

Many attempts have been made to get a glimpse of their internal functioning from their outputs, but these could not provide accurate results. In a paper by Zhi Chen, Yijie Bei and Cynthia Rudin of Duke University, the authors describe a new technique called "Concept Whitening".

This technique can be used to understand the inner functioning of a network. The difference in its approach is that, rather than searching for answers among the millions of trained parameters after training, it makes the network itself interpretable.

The major application of this technique is in image recognition.
Image recognition is an integral application of computer vision. When an image has to be recognized, the neural network works to identify the class or category in which the image fits.

Many images are used to train the model. The features of the training images are encoded as numerical values and stored in the latent space of the model. A neural network has many layers that work together to identify and categorize the image, and each layer performs a specific function.

Sometimes, a model learns the wrong features. For example, if we train it on a set of images of traffic lights with lots of people on the road, it might predict that an image with only people and no traffic lights belongs to the "Traffic Lights" category.

The complexity of neural networks makes them hard to understand, and therefore extremely hard to troubleshoot.

The Concept of "Concept Whitening"

The main logic behind "Concept Whitening" is to develop neural networks whose latent space is aligned with the concepts that are relevant to the task they have been trained for. This can reduce the errors made by neural networks.
The method aligns the axes of the latent space with the correct concepts.
Usually, a deep neural network is trained on a single dataset, which can lead to wrong assumptions such as the one in the "Traffic Lights" example above.
"Concept Whitening", however, introduces a second dataset containing the most relevant features that must be present to categorize the images. For example, in the case of traffic lights it will introduce "Red Light", "Yellow Light" and "Green Light" as the relevant features, and not the people on the road.
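
The alignment described above can be pictured as two steps: first whiten the latent activations (center them and remove correlations between axes), then rotate them so that individual axes point along individual concepts. The NumPy sketch below is only an illustration under stated assumptions, not the authors' implementation: the data is random toy data, the two concepts and their sample indices are hypothetical, and a simple QR orthonormalization stands in for the rotation that the real method learns during training.

```python
import numpy as np

def zca_whiten(z, eps=1e-5):
    """Center z (n_samples x n_dims) and map it to (near) identity covariance."""
    centered = z - z.mean(axis=0)
    cov = centered.T @ centered / len(z)
    eigvals, eigvecs = np.linalg.eigh(cov)
    # Symmetric (ZCA) whitening matrix cov^(-1/2); eps guards tiny eigenvalues.
    whitener = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return centered @ whitener

def concept_rotation(concept_means, seed=0):
    """Orthogonal matrix whose first columns follow the concept mean directions.

    Simplification: QR stands in for the learned rotation, so only the
    first concept is matched exactly by its axis.
    """
    k, d = concept_means.shape
    filler = np.random.default_rng(seed).normal(size=(d, d - k))
    q, r = np.linalg.qr(np.concatenate([concept_means.T, filler], axis=1))
    return q * np.sign(np.diag(r))  # flip columns so they point along the inputs

rng = np.random.default_rng(0)
mixing = np.array([[2.0, 1.0, 0.0, 0.0],
                   [0.0, 1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0, 0.5],
                   [0.0, 0.0, 0.0, 1.0]])
latent = rng.normal(size=(500, 4)) @ mixing       # correlated toy activations
# Hypothetical concept dataset: rows 0-49 show "red light", rows 50-99 "green light".
latent[:50] += np.array([3.0, 0.0, 0.0, 0.0])
latent[50:100] += np.array([0.0, 0.0, 3.0, 0.0])

whitened = zca_whiten(latent)
concept_means = np.stack([whitened[:50].mean(axis=0),
                          whitened[50:100].mean(axis=0)])
rotation = concept_rotation(concept_means)
aligned = whitened @ rotation                     # axis 0 ~ "red", axis 1 ~ "green"

print(aligned[:50].mean(axis=0).round(3))         # "red" samples load on axis 0 only
```

Because the rotation is orthogonal, the aligned activations stay whitened; only the meaning of each axis changes.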

In general, "Concept Whitening" is all about adding an additional concept layer to our neural networks.

We can even choose the representative samples manually: users can create a concept dataset from images they have selected as relevant. This whole process gives a much clearer picture of how the layers of the neural network function. As a result, the model goes through two learning cycles in parallel.

The neural network is trained for the general task of recognition, and "Concept Whitening" adjusts specific neurons or nodes in each layer to align them with the classes included in the concept dataset.
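
One way to picture the two learning cycles is as an interleaved schedule: most steps update the network on the main recognition task, and every so often a step uses the concept dataset to adjust the alignment. The sketch below only generates such a schedule. The function name and the `concept_every` interval are hypothetical, chosen for illustration, and reading "in parallel" as interleaving is itself an assumption.

```python
def training_schedule(num_batches, concept_every=20):
    """Return the ordered training steps.

    'main'    : update the network on the recognition task
    'concept' : adjust the concept alignment using the concept dataset
    """
    steps = []
    for batch in range(1, num_batches + 1):
        steps.append("main")
        if batch % concept_every == 0:
            steps.append("concept")
    return steps

schedule = training_schedule(num_batches=6, concept_every=3)
print(schedule)
# ['main', 'main', 'main', 'concept', 'main', 'main', 'main', 'concept']
```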

After this development, the technique was tested extensively on recognizing and categorizing various types of images. As expected, the neural networks made few mistakes.

For example, traffic lights were recognized correctly by the model even when they were small and in a corner of the image.

Concept whitening can also be applied to already existing models. The technique was developed with great care.

Many neural networks rely on batch normalization. It is very popular because it normalizes the activations of a layer over each mini-batch, which helps the model train faster and can reduce over-fitting. Concept whitening goes a step further: it additionally adjusts the neurons so that they line up with the concepts, which makes the model easier to interpret.
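
To see why an extra adjustment is needed on top of batch normalization, the small NumPy sketch below standardizes a toy mini-batch the way batch normalization does (leaving out its learnable scale and shift, a simplification) and then measures how correlated the two features still are. The data and the `batch_normalize` helper are made up for illustration.

```python
import numpy as np

def batch_normalize(x, eps=1e-5):
    """Standardize each feature (column) over the mini-batch (rows)."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

rng = np.random.default_rng(0)
shared = rng.normal(size=(256, 1))
# Two toy features that both depend on the same hidden factor.
batch = np.hstack([shared + 0.3 * rng.normal(size=(256, 1)),
                   shared + 0.3 * rng.normal(size=(256, 1))])

normalized = batch_normalize(batch)
correlation = np.corrcoef(normalized, rowvar=False)[0, 1]

print(normalized.mean(axis=0).round(3), normalized.std(axis=0).round(3))
print(round(correlation, 2))  # stays far from 0: the features are still entangled
```

Each feature now has zero mean and unit variance, but the two features remain strongly correlated. Whitening removes that correlation as well, which is what allows each axis to be tied to a single concept.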

"Concept Whitening" works on four major calculations for activating concept value:

  • Mean of all feature map values
  • Max of all feature map values
  • Mean of all positive feature map values
  • Mean of down-sampled feature map obtained by max pooling
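
Each of the four calculations listed above turns a whole feature map into a single number that says how strongly a concept is activated. A minimal NumPy sketch follows, assuming a single 2-D feature map and a 2x2 non-overlapping max pooling window; the pooling size and the function name are assumptions, not values taken from the paper.

```python
import numpy as np

def concept_activations(feature_map, pool=2):
    """Reduce one 2-D feature map to a scalar in four different ways."""
    h, w = feature_map.shape
    positives = feature_map[feature_map > 0]
    # Non-overlapping max pooling: crop to a multiple of `pool`, then block-max.
    cropped = feature_map[:h - h % pool, :w - w % pool]
    pooled = cropped.reshape(h // pool, pool, w // pool, pool).max(axis=(1, 3))
    return {
        "mean": feature_map.mean(),
        "max": feature_map.max(),
        "positive_mean": positives.mean() if positives.size else 0.0,
        "pooled_mean": pooled.mean(),
    }

fmap = np.array([[ 1.0, -1.0,  0.0,  2.0],
                 [ 3.0,  0.0, -2.0,  1.0],
                 [-1.0, -1.0,  4.0,  0.0],
                 [ 0.0,  2.0,  0.0, -4.0]])
result = concept_activations(fmap)
print(result)
```

For this toy map the plain mean is 0.25 and the max is 4.0, while the mean of the positive values is about 2.17 and the mean after max pooling is 2.75, so the choice of calculation clearly changes how "active" a concept looks.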

In the research paper, a case study was done on skin lesion diagnosis. The results were better in terms of interpretability after concept whitening was applied, since new concepts were uncovered in the latent space.
The accuracy also improved slightly, since the whitening acted as a regularizer.
The separability of concepts improved significantly as well.

Arguments Against the "Concept Whitening" Technique

There are still some objections raised against this technique. Some of them are:

  • Since neural networks are inspired by the structure of the human brain, many people argue that the best way to interpret their inner functioning is to observe the behavior of the models, just as research on human beings studies their behavior. In contrast, Concept Whitening interprets the inner functioning by adding an additional dataset.

  • Just like human brains, deep learning models should learn in a fair and evolutionary way rather than being fed a specific dataset.
    Interestingly, in all case studies the results of the "concept whitening" approach were up to the mark, without any compromise in the accuracy of the models.

  • Many argue that the focus should remain on the outcomes, not on the processing that takes place inside the neural network.

Conclusion

In all the case studies where "Concept Whitening" has been applied, it showed satisfactory results without any compromise in accuracy. It gave insights into the internal functioning of neural networks, which increased interpretability and could potentially help in correcting and enhancing models.
There is still room to make this research even more insightful. Even so, it has the potential to improve artificial intelligence models, including those used for medical purposes such as predictions based on X-rays.