Byte Pair Encoding for Natural Language Processing (NLP)


Byte Pair Encoding (BPE) was originally a data compression algorithm and was later adapted for NLP.

One of the important steps in NLP is determining the vocabulary.

There are different ways to model the vocabulary, such as an N-gram model, a closed vocabulary, or a bag of words. However, these methods are either computationally expensive and memory heavy (usually because of a large vocabulary size), do not handle OOV (out-of-vocabulary) words well, or run into issues with words that rarely appear. Often, models simply replace infrequent or unseen words with a special unknown token, or ignore them. These workarounds cause information loss, and the lost information can be very important for certain NLP applications.

For those unfamiliar with the term, OOV words are words that a model did not see during training.

Byte Pair Encoding addresses the vocabulary problem through a bottom-up process: starting from an initial vocabulary of single characters, it builds up subwords.

  • Subword example: Computer -> com pu t er

Benefits

  • The vocabulary size can be set (if feasible given the input data)
    • The vocabulary itself is decided automatically.
    • The vocabulary is generally smaller than with other approaches.
      • This makes models more scalable.
  • Better handling of OOV words
  • Relationships between subwords can be inferred
  • Can suggest neologisms (new, unseen words)
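To illustrate the OOV point, here is a minimal Java sketch (not part of the original code) of segmenting a word that never appeared during training. It assumes training has produced an ordered list of merges, as described in the Algorithm section; the merge list, the class name `ApplyMerges`, and the method `segment` are hypothetical, and no end-of-word token is used.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: segment a word by replaying BPE merges in the order they were learned.
// The merge list in main is hypothetical and only meant for illustration.
class ApplyMerges {
    static List<String> segment(String word, List<String[]> merges) {
        // Start from single characters, as BPE training does
        List<String> symbols = new ArrayList<>();
        for (char c : word.toCharArray()) {
            symbols.add(String.valueOf(c));
        }
        // Replay each merge; a merge only joins two whole adjacent symbols
        for (String[] merge : merges) {
            List<String> next = new ArrayList<>();
            for (int i = 0; i < symbols.size(); i++) {
                boolean matches = i + 1 < symbols.size()
                        && symbols.get(i).equals(merge[0])
                        && symbols.get(i + 1).equals(merge[1]);
                if (matches) {
                    next.add(merge[0] + merge[1]);
                    i++; // the right-hand symbol is now part of the merged subword
                } else {
                    next.add(symbols.get(i));
                }
            }
            symbols = next;
        }
        return symbols;
    }

    public static void main(String[] args) {
        // Hypothetical merges, in learned order
        List<String[]> merges = List.of(
                new String[] {"e", "s"},
                new String[] {"es", "t"},
                new String[] {"l", "o"},
                new String[] {"lo", "w"});
        // "lowest" is not one of the example words, but its pieces are known subwords
        System.out.println(segment("lowest", merges)); // prints [low, est]
    }
}
```

Because "lowest" splits into the known subwords low and est, it can be represented without falling back to an unknown token.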

Applications

  • Used in different models, such as Neural Language Models (NLMs)
    • NLMs are useful for autocompletion, program repair, code readability, etc.
  • Helpful for word segmentation and for creating subword models
  • Variants of the process exist (e.g., Google's WordPiece model used in Google NMT, and the SentencePiece model)

Algorithm

Let's assume all the preprocessing has already been done: the words have been normalized, and each word has a frequency count.

  • input: [ { low : 5 }, { lower : 2 }, { newest : 6 }, { widest : 3 } ]
    • An array of words with their associated frequency count.
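The preprocessing itself is outside the scope of this post, but a minimal stand-in might just lowercase the text and count whitespace-separated words. This is only an illustrative sketch: the class `WordCounts`, its `countWords` method, and the made-up corpus are assumptions, and real normalization (punctuation, Unicode, etc.) is usually more involved.

```java
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical preprocessing: lowercase the text, split it on whitespace, count each word.
class WordCounts {
    static Map<String, Integer> countWords(String text) {
        Map<String, Integer> counts = new TreeMap<>(); // sorted keys keep the output readable
        for (String word : text.toLowerCase(Locale.ROOT).split("\\s+")) {
            if (!word.isEmpty()) { // leading whitespace produces one empty token
                counts.merge(word, 1, Integer::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // A made-up corpus chosen to reproduce the example counts
        String corpus = "low ".repeat(5) + "Lower ".repeat(2) + "newest ".repeat(6) + "WIDEST ".repeat(3);
        System.out.println(countWords(corpus)); // prints {low=5, lower=2, newest=6, widest=3}
    }
}
```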

Often, an end-of-word token is appended to each word to indicate when a word/subword is at the end of a word.

  • Ex: low -> low</w>
  1. Split each word into characters while keeping its frequency count
    • Ex: { low : 5 } -> { l o w : 5 }, { lower : 2 } -> { l o w e r : 2 }
  2. Add each character to the vocabulary with its total frequency count (a running sum over all the input data)
    • Ex: { { l : 7 }, { o : 7 }, { w : 16 } ... }
  3. Find the pair of adjacent subwords with the highest combined frequency
    • Ex: the pair (e, s) has a frequency of 9
      • This is because { n e w e s t : 6 } and { w i d e s t : 3 } give 6 + 3 = 9. There was a tie between (e, s) and (s, t), and we chose (e, s).
  4. Add the pair from step 3 to the vocabulary as a merged subword with its frequency, and update every word that contains the pair so that it uses the merged subword.
    • Vocabulary should have all the letter frequencies and { es : 9 }
    • Updated subwords should be: { n e w es t : 6 }, { w i d es t : 3 }
  5. Repeat steps 3-4 until a chosen number of iterations has been completed or there are no more subword pairs to merge.
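As a check on the numbers in steps 1-3, the sketch below computes the character counts and the pair counts for the example input, without the end-of-word token. The class `BpeFirstStep` and its method names are hypothetical, and pairs are stored as space-separated strings such as "e s" purely for readability.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Sketch of steps 1-3 on the example input (no end-of-word token).
class BpeFirstStep {
    // Steps 1-2: split each word into characters and sum the word frequency per character
    static Map<String, Integer> charFreqs(Map<String, Integer> words) {
        Map<String, Integer> freqs = new TreeMap<>();
        words.forEach((word, freq) -> {
            for (String symbol : word.split("")) {
                freqs.merge(symbol, freq, Integer::sum);
            }
        });
        return freqs;
    }

    // Step 3: sum the word frequency for every adjacent pair of characters
    static Map<String, Integer> pairFreqs(Map<String, Integer> words) {
        Map<String, Integer> freqs = new TreeMap<>();
        words.forEach((word, freq) -> {
            String[] symbols = word.split("");
            for (int i = 0; i < symbols.length - 1; i++) {
                freqs.merge(symbols[i] + " " + symbols[i + 1], freq, Integer::sum);
            }
        });
        return freqs;
    }

    // All pairs sharing the highest frequency (assumes at least one pair exists)
    static List<String> bestPairs(Map<String, Integer> pairFreqs) {
        int best = Collections.max(pairFreqs.values());
        List<String> result = new ArrayList<>();
        pairFreqs.forEach((pair, freq) -> {
            if (freq == best) {
                result.add(pair);
            }
        });
        return result;
    }

    public static void main(String[] args) {
        Map<String, Integer> words = new LinkedHashMap<>();
        words.put("low", 5);
        words.put("lower", 2);
        words.put("newest", 6);
        words.put("widest", 3);
        System.out.println(charFreqs(words)); // prints {d=3, e=17, i=3, l=7, n=6, o=7, r=2, s=9, t=9, w=16}
        System.out.println(bestPairs(pairFreqs(words))); // prints [e s, s t]
    }
}
```

The output reproduces the counts from step 2 (l: 7, o: 7, w: 16) and the tie from step 3 between (e, s) and (s, t), both with frequency 9.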

Code

The code below implements the algorithm described above without the vocabulary, but it can easily be modified to build up a vocabulary as the algorithm runs.

import java.util.*;
import java.util.Map.Entry;

public class BytePairEncodingSubwords {
	private static Map<Collection<String>, Integer> getPairFreq(Map<String, Integer> dictionary) {
		Map<Collection<String>, Integer> pairFreqs = new HashMap<>();
		for (Entry<String, Integer> wordFreq : dictionary.entrySet()) {
			String[] subwords = wordFreq.getKey().split(" ");
			int freq = wordFreq.getValue();
			
			//creating all the possible pairs of subwords while mapping their associated count
			for (int i = 0; i < subwords.length - 1; i++) {
				List<String> pairOfSubwords = Arrays.asList(subwords[i], subwords[i + 1]);
				pairFreqs.compute(pairOfSubwords, (k, v) -> v != null ? v + freq : freq);
			}
		}
		return pairFreqs;
	}

	private static Collection<String> getMax(Map<Collection<String>, Integer> pairFreqs) {
		if (pairFreqs.isEmpty()) {
			return null;
		}
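		//Determine the pair of subwords that has the highest count.
		//Ties (such as (e, s) and (s, t) in the example) go to whichever pair the HashMap iterates over first, so the choice is arbitrary.
		return Collections.max(pairFreqs.entrySet(), (e1, e2) -> e1.getValue().compareTo(e2.getValue())).getKey();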
	}

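	private static Map<String, Integer> merge(Collection<String> replacementPair, Map<String, Integer> dict) {
		Map<String, Integer> updatedDict = new HashMap<>(dict.size());
		String[] replacementPairInfo = replacementPair.toArray(new String[2]);
		String left = replacementPairInfo[0];
		String right = replacementPairInfo[1];
		//Update each word by concatenating every adjacent occurrence of the pair into one subword.
		//Comparing whole subwords (instead of running replaceAll over the string) avoids false matches,
		//e.g. the pair (s, t) matching inside "es t", and problems with regex metacharacters.
		for (Map.Entry<String, Integer> entry : dict.entrySet()) {
			String[] subwords = entry.getKey().split(" ");
			List<String> merged = new ArrayList<>(subwords.length);
			for (int i = 0; i < subwords.length; i++) {
				if (i < subwords.length - 1 && subwords[i].equals(left) && subwords[i + 1].equals(right)) {
					merged.add(left + right);
					i++; //skip the right half, it is now part of the merged subword
				} else {
					merged.add(subwords[i]);
				}
			}
			updatedDict.put(String.join(" ", merged), entry.getValue());
		}
		return updatedDict;
	}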

	private static Map<String, Integer> encode(Map<String, Integer> dict, int iterations) {
		while (iterations > 0) {
			Map<Collection<String>, Integer> pairings = getPairFreq(dict); 
			Collection<String> bestPair = getMax(pairings);
			if (bestPair == null) { //No more pairs of subwords can be formed (every word is a single subword)
				break;
			}
			dict = merge(bestPair, dict); //update the dictionary
			iterations--;
		}
		return dict;
	}

	private static Map<String, Integer> decode(Map<String, Integer> dict) {
		Map<String, Integer> decodedDict = new HashMap<>(dict.size());
		dict.entrySet().forEach(entry -> decodedDict.put(entry.getKey().replaceAll(" ", ""), entry.getValue())); //joins the subwords together that are delimited by spaces
		return decodedDict;
	}

	public static void main(String[] args) {
		Map<String, Integer> dict = getBasicTestVocab();
		System.out.println(dict);
		dict = encode(dict, 10);
		System.out.println(dict);
		dict = decode(dict);
		System.out.println(dict);
	}

	private static Map<String, Integer> getBasicTestVocab() {
		Map<String, Integer> dict = new HashMap<>();
		dict.put("l o w", 5);
		dict.put("l o w e r", 2);
		dict.put("n e w e s t", 6);
		dict.put("w i d e s t", 3);
		return dict;
	}
}

Things to note about this example:

  1. There is no vocabulary in the example, but it would not be hard to pass a vocabulary dictionary around, along with a size-limit check that stops the algorithm. So the vocabulary-building parts of the algorithm (step 2 and the vocabulary update in step 4) are skipped, and the character splitting of step 1 is done by hand in getBasicTestVocab; adding these should be straightforward.
  2. If you change encode(dict, 10) in main to encode(dict, 1) or encode(dict, 2), you can follow the algorithm step by step: which pair of subwords was selected and concatenated, and which subword would have been added to the vocabulary if there were one.
  3. At some point every word becomes a single subword and there are no pairs left to select. The code checks for that case, but the iterations parameter is still needed to decide how many times the algorithm should run.
  4. Lastly, the decode method shows how easy it is to recover the original input: it simply concatenates each word's subwords.
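Following up on note 1, here is one way the vocabulary could be threaded through the algorithm, with a vocabulary size limit as the stopping condition. This is a sketch, not the code above: the class `BpeWithVocabulary`, its method names, and the `maxVocabSize` parameter are assumptions; ties are broken alphabetically rather than arbitrarily; and each merged subword is recorded with its pair count, as in step 4.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Sketch: BPE training that also builds the vocabulary and stops at a size limit.
// maxVocabSize is a hypothetical parameter; words use the same "l o w" format as the code above.
class BpeWithVocabulary {
    // Joins every adjacent (left, right) occurrence in a space-delimited word
    static String mergeWord(String word, String left, String right) {
        String[] symbols = word.split(" ");
        List<String> out = new ArrayList<>();
        for (int i = 0; i < symbols.length; i++) {
            if (i + 1 < symbols.length && symbols[i].equals(left) && symbols[i + 1].equals(right)) {
                out.add(left + right);
                i++; // the right-hand symbol is consumed by the merge
            } else {
                out.add(symbols[i]);
            }
        }
        return String.join(" ", out);
    }

    static Map<String, Integer> learnVocabulary(Map<String, Integer> words, int maxVocabSize) {
        // Step 2: seed the vocabulary with every character and its total frequency
        Map<String, Integer> vocab = new LinkedHashMap<>();
        words.forEach((word, freq) -> {
            for (String symbol : word.split(" ")) {
                vocab.merge(symbol, freq, Integer::sum);
            }
        });

        Map<String, Integer> current = new HashMap<>(words);
        while (vocab.size() < maxVocabSize) {
            // Step 3: count adjacent pairs; the TreeMap makes tie-breaking deterministic
            Map<String, Integer> pairs = new TreeMap<>();
            current.forEach((word, freq) -> {
                String[] symbols = word.split(" ");
                for (int i = 0; i < symbols.length - 1; i++) {
                    pairs.merge(symbols[i] + " " + symbols[i + 1], freq, Integer::sum);
                }
            });
            if (pairs.isEmpty()) {
                break; // every word is already a single subword
            }
            String bestPair = null;
            int bestFreq = 0;
            for (Map.Entry<String, Integer> entry : pairs.entrySet()) {
                if (entry.getValue() > bestFreq) { // strict '>' keeps the alphabetically first pair on ties
                    bestPair = entry.getKey();
                    bestFreq = entry.getValue();
                }
            }
            // Step 4: add the merged subword to the vocabulary and update every word
            String[] pair = bestPair.split(" ");
            vocab.put(pair[0] + pair[1], bestFreq);
            Map<String, Integer> next = new HashMap<>();
            current.forEach((word, freq) -> next.put(mergeWord(word, pair[0], pair[1]), freq));
            current = next;
        }
        return vocab;
    }

    public static void main(String[] args) {
        Map<String, Integer> words = new LinkedHashMap<>();
        words.put("l o w", 5);
        words.put("l o w e r", 2);
        words.put("n e w e s t", 6);
        words.put("w i d e s t", 3);
        // The example has 10 distinct characters, so a limit of 12 allows two merges
        System.out.println(learnVocabulary(words, 12));
        // prints {l=7, o=7, w=16, e=17, r=2, n=6, s=9, t=9, i=3, d=3, es=9, est=9}
    }
}
```

The first merge is (e, s) with frequency 9, matching step 4 of the walkthrough, and the second is (es, t), also with frequency 9.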

References