Fusion Tree


Reading time: 35 minutes | Coding time: 15 minutes

A fusion tree is a tree data structure that implements an associative array over integer keys drawn from a universe of known size. Like van Emde Boas trees, fusion trees solve the predecessor and successor problem. A van Emde Boas tree is useful when the universe is small; a fusion tree is preferred when the universe is large, and it needs only linear space.

A fusion tree is essentially a B-tree of degree w^c, where w is the word size and c is some constant smaller than 1. The height of the tree is therefore O(log_w n), where n is the number of elements stored in the tree.

An important operation used by fusion trees is sketch which is used to compress w bit keys:

[Figure: fusion-sketch-1, a binary tree of the keys 17, 21 and 23 with branching levels b0 and b1]

Consider the tree above, which represents the numbers 17, 21 and 23, and let it be the fusion tree. The sketch operation keeps only the bits at branching positions, labeled b0 and b1 in the figure. Here bit positions 1 and 2 (0-indexed from the least significant bit) are the branching levels, so sketch packs the bits at these positions into a single integer, with the more significant branching bit first.
For example (branching bits in brackets),
sketch(17) = sketch(10001) = 1 0 [0] [0] 1 = 00
sketch(21) = sketch(10101) = 1 0 [1] [0] 1 = 10
sketch(23) = sketch(10111) = 1 0 [1] [1] 1 = 11
In the general case we get r sketch bits: b_0, b_1, ..., b_{r-1}.
Let c be 1/10 here, giving a branching factor of k = w^(1/10), i.e. each node has w^(1/10) - 1 keys and w^(1/10) children. This is where the sketch operation comes into play. Since it is not possible to compare a query against all O(k) full keys of a node in constant time, the sketch operation compresses the keys so that all of them can be compared at once, allowing constant time predecessor queries within a node. Since there are k - 1 keys in a node, there are at most k - 2 branching positions, so r is always less than k.

So the number of bits needed to store the sketches of all the keys in a node is O(w^(2/10)): fewer than k keys with fewer than k bits each. Let the keys be x_0 < x_1 < ... < x_{k-2}. Then sketch(x_0) < sketch(x_1) < ... < sketch(x_{k-2}), since sketch keeps every bit that distinguishes two keys.

This sketch is called the exact (or perfect) sketch.
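The exact sketch of the example keys can be checked with a few lines of Python. This is only an illustration: the helper names distinguishing_bits and exact_sketch are made up here, bit positions are counted from the least significant bit, the keys are assumed to be distinct, and the loop over bits is not the constant time operation a fusion tree needs.

```python
def distinguishing_bits(keys):
    """Bit positions (0 = least significant) at which the keys branch."""
    keys = sorted(keys)
    bits = set()
    for a, b in zip(keys, keys[1:]):
        # highest bit in which two neighbouring keys differ
        bits.add((a ^ b).bit_length() - 1)
    return sorted(bits, reverse=True)


def exact_sketch(x, bits):
    """Pack the bits of x at the given positions, most significant first."""
    s = 0
    for b in bits:
        s = (s << 1) | ((x >> b) & 1)
    return s


keys = [17, 21, 23]
bits = distinguishing_bits(keys)
sketches = [exact_sketch(k, bits) for k in keys]
print(bits, [format(s, "02b") for s in sketches])
```

The sketches come out in the same order as the keys, which is the property the rest of the structure relies on.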

Algorithm

Sketch operation

  • Sketch(x): In practice, an approximate sketch is used instead of the exact sketch, because the exact sketch cannot be computed in constant time. A perfect sketch requires the distinguishing bits to be packed next to each other, which is a computationally expensive operation. The approximate sketch does not pack the distinguishing bits tightly; it leaves a fixed pattern of zeroes between every two distinguishing bits. This is done by multiplication with a predetermined constant. The result is O(r^4) bits long, where r is the number of distinguishing bits.
  1. Let Mask be the mask that filters out the distinguishing bits and m be a predetermined constant.
  2. x' = x AND Mask. After the bitwise AND, only the distinguishing bits of x can be set in x'.
  3. result = x' * m. Keep only the positions that now hold the distinguishing bits and right shift result until it is r^4 bits long. result is the approximate sketch.
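The three steps can be tried out as follows. This is a minimal sketch under assumptions: the function name approx_sketch is made up, the bit positions b_bits = [1, 2] and the constant positions m_bits = [0, 2] are hand-picked for the keys 17, 21 and 23, and the final mask plays the role of mask_bm in the implementation further down. Python integers are unbounded, so no word overflow occurs here.

```python
def approx_sketch(x, b_bits, m_bits):
    """Approximate sketch of x: mask, multiply, mask again, shift down."""
    mask = sum(1 << b for b in b_bits)      # step 1: distinguishing bits
    m = sum(1 << t for t in m_bits)         # step 1: the constant m
    # bit b_i ends up at position b_i + m_i after the multiplication
    out_mask = sum(1 << (b + t) for b, t in zip(b_bits, m_bits))
    x_masked = x & mask                     # step 2
    result = (x_masked * m) & out_mask      # step 3: multiply, drop garbage
    return result >> (b_bits[0] + m_bits[0])


b_bits = [1, 2]    # assumed distinguishing bit positions
m_bits = [0, 2]    # assumed set bits of m, i.e. m = 0b101
values = [approx_sketch(k, b_bits, m_bits) for k in (17, 21, 23)]
print(values)
```

The two distinguishing bits are no longer adjacent in the result, but the results are still ordered like the keys.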

To calculate the constant m, an inductive definition is used. Let b_i, 0 <= i <= r-1, be the distinguishing bit positions in increasing order, and let m be the binary number whose set bits are at positions m_0, ..., m_{r-1}. Multiplying x' by m moves a copy of the bit at b_i to position b_i + m_j for every j; the copy at b_i + m_i is the one kept in the sketch. The word formed after multiplication must have the following properties:

  1. b_i + m_j are distinct for all pairs (i, j). This makes sure that no two distinguishing bits overlap after multiplication.
  2. b_i + m_i is strictly increasing with respect to i. This means that the order of the distinguishing bits, and thus the order of the original words, is preserved.
  3. (b_{r-1} + m_{r-1}) - (b_0 + m_0) <= r^4, i.e. the sketch is r^4 bits long.
    The basic strategy to make m is:
    For every t, 0 <= t <= r-1, select the smallest value for m_t that satisfies m_t != b_i - b_j + m_k for every 0 <= i, j <= r-1 and 0 <= k <= t-1.
    This is equivalent to b_j + m_t != b_i + m_k, thus satisfying condition 1. Selecting the smallest integers first makes m_t increase with t, which satisfies condition 2.
    Since fewer than r^3 values are excluded for each m_t (at most r choices each for i, j and k), every m_t can be chosen below r^3; this bound is what condition 3 relies on.
    So m will be a binary number with the bits at positions m_0, ..., m_{r-1} set to 1, and every 1 is padded with at most r^3 zeroes.
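The greedy strategy can be written compactly with a set of excluded values. This is a sketch, not the only way to build m: the function name choose_m_bits is made up, and b_bits is assumed to be sorted in increasing order.

```python
def choose_m_bits(b_bits):
    """Greedily pick m_0, m_1, ... so that all sums b_i + m_j are distinct."""
    m_bits = []
    for _ in b_bits:
        # values that would make some b_j + m_t collide with an earlier b_i + m_k
        forbidden = {bi - bj + mk
                     for bi in b_bits for bj in b_bits for mk in m_bits}
        mt = 0
        while mt in forbidden:
            mt += 1
        m_bits.append(mt)
    return m_bits


b_bits = [0, 3, 4]                      # example distinguishing positions
m_bits = choose_m_bits(b_bits)
m = sum(1 << t for t in m_bits)         # the multiplication constant
sums = [b + t for b in b_bits for t in m_bits]
print(m_bits, bin(m), len(sums) == len(set(sums)))
```

For the first value the set of excluded values is empty, so m_0 is always 0; every later value has to avoid the shifts already taken.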

Parallel comparison

  • Parallel comparison: Parallel comparison is used to find the position of a value within the set of keys in a node in constant time. A node has at most r + 1 keys. The sketches of all those keys can be concatenated to form a single word. This word is used to compare the query value q with all the keys at once rather than comparing with them one by one.
1. Let X0, X1, ... ,Xr be the keys.
2. Then the sketch of a node is      1sketch(X0)1sketch(X1)... 1sketch(Xr).
3. Expand sketch(q) to sketch(q)' = 0sketch(q)0sketch(q)... 0sketch(q).
4. After subtracting the value at step 3 from the value at step 2, we get a number
    that looks like: 0________0__ ... ___0________1________1________1. Call it diff.
5. There are 0's where sketch(Xi) is less than sketch(q) and 1's where
    sketch(Xi) is greater than or equal to sketch(q). The dashes represent garbage
    bits where the sketch values used to be. It is always a series of 0's followed
    by a series of 1's, since the sketch operation maintains the order of keys. The
    position where the series of 0's turns into 1's is the position where sketch(q)
    should be placed.
    So sketch(Xi) < sketch(q) <= sketch(Xi + 1) when the bit of diff for
    sketch(Xi) is 0 and the bit for sketch(Xi + 1) is 1.

Note that parallel comparison gives the index at which sketch(q) fits among the sketches of the keys, not the position of q among the keys themselves.
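The subtraction trick can be demonstrated on small numbers. The following is a sketch under assumptions: the names pack_node and rank_by_sketch are made up, the first sketch is stored in the least significant field (the text above writes X0 first), and counting the flag bits with bin(...).count is a stand-in for the constant time word operation a real implementation would use.

```python
def pack_node(sketches, width):
    """Concatenate the sketches, each prefixed with a 1 flag bit."""
    word = 0
    for i, s in enumerate(sketches):
        word |= ((1 << width) | s) << (i * (width + 1))
    return word


def rank_by_sketch(sketches, width, q):
    """Number of key sketches that are smaller than the sketch q."""
    n = len(sketches)
    field = width + 1
    # a 1 at the bottom of every field; q * repeat copies q into each field
    repeat = sum(1 << (i * field) for i in range(n))
    flags = repeat << width             # positions of the flag bits
    diff = pack_node(sketches, width) - q * repeat
    # a flag stays 1 exactly when that sketch is >= q
    return n - bin(diff & flags).count("1")


node = [0, 8, 9]    # sorted sketches, each fits in 4 bits
ranks = [rank_by_sketch(node, 4, q) for q in (0, 5, 8, 9, 15)]
print(ranks)
```

Because every field starts with a 1 flag and q is smaller than 2^width, no subtraction can borrow from a neighbouring field, which is why a single subtraction compares all sketches independently.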

Predecessor and Successor

  • Predecessor and Successor: This is also called de-sketching. The algorithm below finds the predecessor or successor of q in a node of the fusion tree. It is applied recursively to the child nodes until a leaf node is reached.
1. Compute sketch(q), which gives the approximate sketch of q.
2. Use parallel comparison to get Xi and Xi + 1 such that sketch(Xi) <= sketch(q) <= sketch(Xi + 1).
3. Find y = the longest common prefix of q with either Xi or Xi + 1.
4. For successor:
    Append the suffix 1000...000 to y to fill the non-common bits and store it in a
    variable e. Then e = y100...00.
    For example: let Xi = 0 1 1 0 0 1 0 1 1
                 and q  = 0 1 1 0 0 0 0 0 0
                 then y = 0 1 1 0 0
                 and e  = 0 1 1 0 0 1 0 0 0
   For predecessor:
    Append the suffix 0111...11 to y to fill the non-common bits and store it in a
    variable e. Then e = y011...11
5. Use parallel comparison to get Xj and Xj + 1 such that sketch(Xj) <= sketch(e) <= sketch(Xj + 1).
6. Predecessor of q = Xj.
   Successor of q = Xj + 1

After sketching, the order is maintained only for the keys. There is no guarantee that if sketch(Xi) < sketch(q) < sketch(Xi + 1) then Xi < q < Xi + 1. That is why steps 3 to 5 are used.
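Steps 3 and 4 amount to a little bit manipulation. The snippet below is a sketch of those two steps only, with a made-up function name desketch_target; it reproduces the worked example from step 4 and leaves out the second parallel comparison.

```python
def desketch_target(x, q, want_successor):
    """Build e from the common prefix of x and q (steps 3 and 4)."""
    diff = x ^ q
    if diff == 0:
        return q                        # x and q are identical
    i = diff.bit_length() - 1           # first bit where x and q differ
    prefix = (q >> (i + 1)) << (i + 1)  # y followed by zeroes
    if want_successor:
        return prefix | (1 << i)        # e = y 1 0 0 ... 0
    return prefix | ((1 << i) - 1)      # e = y 0 1 1 ... 1


x = 0b011001011
q = 0b011000000
e_succ = desketch_target(x, q, True)
e_pred = desketch_target(x, q, False)
print(format(e_succ, "09b"), format(e_pred, "09b"))
```

The value e agrees with q on the common prefix y and is then as small (or as large) as possible, so its sketch lands next to the keys that really surround q.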

Insert

  • insert(k): Insertion is almost exactly the same as a normal B-tree insertion.
1. Start with the root node as the current node.
2. Use parallel comparison to find the appropriate position for k. If the current
    node is not a leaf node, set the current node to the child at the found
    position. Repeat step 2.
3. The current node is now a leaf node. Insert the key into the correct position.
    If the number of keys is more than max-keys, split the node at the middle and
    add the key at the mid position to the parent node, with it pointing to the
    current node. If the parent node also exceeds max-keys, keep splitting until
    the root is reached.
    If the root node also exceeds max-keys, split the root and make the mid key
    the new root.
4. If the keys in a node have been modified, recalculate the node sketch and m.

Complexity

  • Time complexity:
    1. Find, Successor, Predecessor: O(log_w N)
    2. Insert: O(log2w)
    3. Delete: O(log2w)
  • Space complexity: O(N)

Where N is the count of values stored and w is the size of a word, i.e. the size of the data type used to store the values. The complexity of insertion and deletion depends heavily on the strategy used to implement the dynamic operations.

Implementations


Python 3

class Node:
    """Class for fusion tree node"""
    def __init__(self, max_keys = None):
        self.keys = []
        self.children = []
        self.key_count = 0

        self.isLeaf = True
        self.m = 0
        self.b_bits = []    # distinguishing bits
        self.m_bits = []    # bits of constant m
        self.gap = 0
        self.node_sketch = 0
        self.mask_sketch = 0
        self.mask_q = 0     # used in parallel comparison

        self.mask_b = 0
        self.mask_bm = 0

        self.keys_max = max_keys
        if max_keys is not None:
            # an extra space is assigned so that splitting can be
            # done easily
            self.keys = [None for i in range(max_keys + 1)]
            self.children = [None for i in range(max_keys + 2)]

class FusionTree:
    """Fusion tree class. initiateTree is called after all insertions in
    this example. Practically, node is recalculated if its keys are
    modified."""
    def getDiffBits(self, keys):
        res = []

        bits = 0
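        for i in range(len(keys)):
            if keys[i] is None:
                break
            for j in range(i):
                w = self.w

                # walk down from the top bit to the first bit where the two
                # keys differ; w >= 0 is tested first so that a negative
                # shift is never attempted for equal keys
                while w >= 0 and (keys[i] & 1 << w) == (keys[j] & 1 << w):
                    w -= 1
                if w >= 0:
                    bits |= 1 << w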
        
        i = 0
        while i < self.w:
            if bits & (1 << i) > 0:
                res.append(i)
            i += 1
        return res

    def getConst(self, b_bits):
        r = len(b_bits)
        m_bits = [0 for i in range(r)]
        for t in range(r):
            mt = 0
            flag = True
            while flag:
                flag = False
                for i in range(r):
                    if flag:
                        break
                    for j in range(r):
                        if flag:
                            break
                        for k in range(t):
                            if mt == b_bits[i] - b_bits[j] + m_bits[k]:
                                flag = True
                                break
                if flag == True:
                    mt += 1
            m_bits[t] = mt
        
        m = 0
        for i in m_bits:
            m |= 1 << i
        return m_bits, m
                        
    def getMask(self, mask_bits):
        res = 0
        for i in mask_bits:
            res |= 1 << i
        return res

    def initiateNode(self, node):
        if node.key_count != 0:
            node.b_bits = self.getDiffBits(node.keys)
            node.m_bits, node.m = self.getConst(node.b_bits)
            node.mask_b = self.getMask(node.b_bits)

            temp = []
            # bm[i] will be position of b[i] after its multiplication
            # with m[i]. mask_bm will isolate these bits.
            for i in range(len(node.b_bits)):
                temp.append(node.b_bits[i] + node.m_bits[i])
            node.mask_bm = self.getMask(temp)

            # used to maintain sketch lengths
            r3 = int(pow(node.key_count, 3))

            node.node_sketch = 0
            sketch_len = r3 + 1
            node.mask_sketch = 0
            node.mask_q = 0
            for i in range(node.key_count):
                sketch = self.sketchApprox(node, node.keys[i])
                temp = 1 << r3
                temp |= sketch
                node.node_sketch <<= sketch_len
                node.node_sketch |= temp
                node.mask_q |= 1 << i * (sketch_len)
                node.mask_sketch |= (1 << (sketch_len - 1)) << i * (sketch_len)
        return
    
    def sketchApprox(self, node, x):
        xx = x & node.mask_b
        res = xx * node.m

        res = res & node.mask_bm
        return res
        
        

    def __init__(self, word_len = 64, c = 1/5):
        # print(word_len)
        self.keys_max = int(pow(word_len, c))
        self.keys_max = max(self.keys_max, 2)
        self.w = int(pow(self.keys_max, 1/c))
        self.keys_min = self.keys_max // 2

        print("word_len = ", self.w, " max_keys = ", self.keys_max)

        self.root = Node(self.keys_max)
        self.root.isLeaf = True
    
    def splitChild(self, node, x):
        # a b-tree split function. Splits child of node at x index
        z = Node(self.keys_max)
        y = node.children[x]   # y is to be split
        z.isLeaf = y.isLeaf    # the new sibling sits on the same level as y

        # pos of key to propagate
        pos_key = (self.keys_max // 2)

        z.key_count = self.keys_max - pos_key - 1

        # move the upper half of the keys into z
        for i in range(z.key_count):
            z.keys[i] = y.keys[pos_key + i + 1]
            y.keys[pos_key + i + 1] = None

        # move the matching children as well
        if not y.isLeaf:
            for i in range(z.key_count + 1):
                z.children[i] = y.children[pos_key + i + 1]
                y.children[pos_key + i + 1] = None

        y.key_count = self.keys_max - z.key_count - 1

        # shift keys and children of node right to open slot x, so that
        # splitting a child in the middle does not overwrite anything
        for i in range(node.key_count, x, -1):
            node.keys[i] = node.keys[i - 1]
            node.children[i + 1] = node.children[i]

        # move the middle key up into node
        node.keys[x] = y.keys[pos_key]
        y.keys[pos_key] = None

        # insert z as child at x + 1th pos
        node.children[x + 1] = z

        node.key_count += 1

    def insertNormal(self, node, k):
        # print(node, node.keys,'\n', node.key_count)
        # insert k into node when no chance of splitting the root
        if node.isLeaf:
            i = node.key_count
            while i >= 1 and k < node.keys[i - 1]:
                node.keys[i] = node.keys[i - 1]
                i -= 1
            node.keys[i] = k
            node.key_count += 1
            return
        else:
            i = node.key_count
            while i >= 1 and k < node.keys[i - 1]:
                i -= 1
            # i = position of appropriate child

            if node.children[i].key_count == self.keys_max:
                self.splitChild(node, i)
                if k > node.keys[i]:
                    i += 1
            self.insertNormal(node.children[i], k)

    def insert(self, k):
        # This insert checks if splitting is needed
        # then it splits and calls normalInsert

        # if root needs splitting, a new node is assigned as root
        # with split nodes as children
        if self.root.key_count == self.keys_max:
            temp_node = Node(self.keys_max)
            temp_node.isLeaf = False
            temp_node.key_count = 0
            temp_node.children[0] = self.root
            self.root = temp_node
            self.splitChild(temp_node, 0)
            self.insertNormal(temp_node, k)
        else:
            self.insertNormal(self.root, k)

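    def successorSimple(self, node, k):
        # linear-scan successor (smallest key >= k) that does not use
        # sketches; returns -1 if no such key exists in this subtree
        i = 0
        while i < node.key_count and k > node.keys[i]:
            i += 1
        if node.isLeaf:
            return node.keys[i] if i < node.key_count else -1
        res = self.successorSimple(node.children[i], k)
        if res == -1 and i < node.key_count:
            return node.keys[i]
        return res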
    
    def parallelComp(self, node, k):
        # this function should basically give the index such
        # that sketch of k lies between 2 sketches
        sketch = self.sketchApprox(node, k)
        # This will give repeated sketch patterns to allow for comparison
        # in const time
        sketch_long = sketch * node.mask_q

        res = node.node_sketch - sketch_long

        # mask out unimportant bits
        res &= node.mask_sketch

        # find the leading bit. This leading bit will tell position i of
        # such that sketch(keyi-1) < sketch(k) < sketch(keyi)
        i = 0
        while (1 << i) < res:
            i += 1
        i += 1
        sketch_len = int(pow(node.key_count, 3)) + 1
        
        return node.key_count - (i // sketch_len)

    def successor(self, k, node = None):
        if node is None:
            node = self.root

        if node.key_count == 0:
            if node.isLeaf:
                return -1
            else:
                return self.successor(k, node.children[0])
       
        # the corner cases are not concretely defined.
        # other alternative to handle these would be to have
        # -inf and inf at corners of keys array
        if node.keys[0] >= k:
            if not node.isLeaf:
                res = self.successor(k, node.children[0])
                if res == -1:
                    return node.keys[0]
                else:
                    return min(node.keys[0], res)
            else:
                return node.keys[0]
        
        if node.keys[node.key_count - 1] < k:
            if node.isLeaf:
                return -1
            else:
                return self.successor(k, node.children[node.key_count])

        pos = self.parallelComp(node, k)
        # print("pos = ", pos)

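        if pos >= node.key_count:
            # sketch(k) is larger than every key sketch in this node; fail
            # loudly instead of pausing for keyboard input
            raise RuntimeError(f"pos {pos} out of range for keys {node.keys}")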
        
        if pos == 0:
            pos += 1
            # x = node.keys[pos]
        
        # find the common prefix
        # it can be guaranteed that successor of k is successor
        # of next smallest element in subtree
        x = max(node.keys[pos - 1], node.keys[pos])
        # print("x = ", x)
        common_prefix = 0
        i = self.w
        while i >= 0 and (x & (1 << i)) == (k & (1 << i)):
            # print(i)
            common_prefix |= x & (1 << i) 
            i -= 1
        if i == -1:
            return x
        
        temp = common_prefix | (1 << i)

        pos = self.parallelComp(node, temp)
        # if pos == 0:
        # possible error?
        #     pos += 1
        # print("pos = ", pos, bin(temp))
        if node.isLeaf:
            return node.keys[pos]
        else:
            res = self.successor(k, node.children[pos])
            if res == -1:
                return node.keys[pos]
            else:
                return res

    def predecessor(self, k, node = None):
        if node is None:
            node = self.root

        if node.key_count == 0:
            if node.isLeaf:
                return -1
            else:
                return self.predecessor(k, node.children[0])
       
        # the corner cases are not concretely defined.
        # other alternative to handle these would be to have
        # 0 and inf at corners of keys array
        if node.keys[0] > k:
            if not node.isLeaf:
                return self.predecessor(k, node.children[0])
            else:
                return -1
        
        if node.keys[node.key_count - 1] <= k:
            if node.isLeaf:
                return node.keys[node.key_count - 1]
            else:
                ret =  self.predecessor(k, node.children[node.key_count])
                return max(ret, node.keys[node.key_count - 1])

        pos = self.parallelComp(node, k)

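        if pos >= node.key_count:
            # sketch(k) is larger than every key sketch in this node; fail
            # loudly instead of pausing for keyboard input
            raise RuntimeError(f"pos {pos} out of range for keys {node.keys}")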
        
        if pos == 0:
            pos += 1
        
        # find the common prefix
        # it can be guaranteed that successor of k is successor
        # of next smallest element in subtree
        x = node.keys[pos]
        common_prefix = 0
        i = self.w
        while i >= 0 and (x & (1 << i)) == (k & (1 << i)):
            common_prefix |= x & (1 << i) 
            i -= 1
        if i == -1:     # i.e. if x is exactly equal to k
            return x
        
        temp = common_prefix | ((1 << i) - 1)
        pos = self.parallelComp(node, temp)
        if pos == 0:
            if node.isLeaf:
                return node.keys[pos]
            res = self.predecessor(k, node.children[1])
            if res == -1:
                return node.keys[pos]
            else:
                return res
                
        if node.isLeaf:
            return node.keys[pos - 1]
        else:
            res = self.predecessor(k, node.children[pos])
            if res == -1:
                return node.keys[pos - 1]
            else:
                return res

    def initiate(self, node):
        # recompute the sketch data of node and of every node below it
        if node is None:
            return
        self.initiateNode(node)
        if not node.isLeaf:
            for i in range(node.key_count + 1):
                self.initiate(node.children[i])
    
    def initiateTree(self):
        self.initiate(self.root)


if __name__ == "__main__":
    # create a fusion tree of degree 3
    tree = FusionTree(243)
    
    tree.insert(1)
    tree.insert(5)
    tree.insert(15)
    tree.insert(16)
    tree.insert(20)
    tree.insert(25)
    tree.insert(4)
    print(tree.root.keys)
    for i in tree.root.children:
        if i is not None:
            print (i, " = ", i.keys)
            if not i.isLeaf:
                for j in i.children:
                    if j is not None:
                        print( j.keys)
    tree.initiateTree()
    # the tree formed should be like:
    #      [| 5  | |  16 |]
    #      /      |       \
    #     /       |        \
    # [1, 4]     [15]     [20, 25]
    print("\nKeys stored are:")
    print("1, 4, 5, 15, 16, 20, 25\n")
    print("Predecessors:")
    for i in range(26):
        print(i, "------------------->", tree.predecessor(i), sep = '\t')
    print("Successor:")
    for i in range(26):
        print(i, "------------------->", tree.successor(i), sep = '\t')
    

Applications

  • Fusion trees are used extensively in database systems.
  • Fusion trees and van Emde Boas trees complement each other. When the word size is large, fusion trees are used. When the word size is smaller, van Emde Boas trees are used.

References/ Further reading

  • As an exercise, try changing the insert function to use parallel comparison.