Prim's Minimum Spanning Tree Algorithm


Reading time: 15 minutes | Coding time: 9 minutes

Prim's algorithm is a greedy algorithm that finds a minimum spanning tree for a weighted undirected graph. It finds a subset of the edges that forms a tree that includes every vertex, where the total weight of all the edges in the tree is minimized. The algorithm operates by building this tree one vertex at a time, from an arbitrary starting vertex, at each step adding the cheapest possible connection from the tree to another vertex.

prim

Unlike Kruskal's algorithm, which adds an edge at each step, Prim's algorithm adds a vertex to the growing spanning tree at each step.

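Prim's algorithm is credited to Vojtěch Jarník (1930), Robert C. Prim (1957) and Edsger W. Dijkstra (1959), who discovered it independently; it is therefore also known as the Prim–Jarník algorithm.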

Complexity

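  • Worst case time complexity: Θ(E log V) using a binary-heap priority queue.
  • Average case time complexity: Θ(E log V) using a binary-heap priority queue.
  • Best case time complexity: Θ(E log V) using a binary-heap priority queue.
  • Space complexity: Θ(E + V)
  • With an adjacency matrix and a linear scan for the minimum (as in the Python implementation below), the running time is Θ(V²) instead.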

Steps involved are:

  • Maintain two disjoint sets of vertices: one containing the vertices that are already in the growing spanning tree, and the other containing the vertices that are not.
  • Select the cheapest vertex that is not yet in the growing spanning tree but is connected to it (the endpoint of the lightest edge leaving the tree), and add it to the tree. This can be done efficiently with a priority queue: insert the vertices that are connected to the growing spanning tree into the priority queue, keyed by the weight of the connecting edge.
  • Avoid cycles. To do that, mark the nodes that have already been selected and insert only unmarked nodes into the priority queue.

Pseudocode


Prim()
	S = new empty set
	for i = 1 to n
		d[i] = inf
	while S.size() < n
		x = inf
		v = -1
		for each i in V - S // V is the set of vertices
			if x >= d[i]
				then x = d[i], v = i
		d[v] = 0
		S.insert(v)
		for each u in adj[v]
			do d[u] = min(d[u], w(v,u))

Implementations

  • C
  • C++
  • Python

C


/*
 * Part of Cosmos by OpenGenus Foundation
*/
#include <stdio.h>
#include <stdlib.h>
#include <limits.h>
// Single node of the graph
struct node
{
	int vert;           /* vertex at the other end of this edge */
	int weight;         /* weight of this edge */
	struct node *next;  /* next edge in the same adjacency list, or NULL */
};
/*
 * One adjacency-list entry. The array of list heads uses the same type;
 * for a head only `next` is read (it points at the first edge).
 */
typedef struct node node;
// Vertex of the graph
struct vertex
{
	int key;  /* cheapest known edge weight linking this vertex to the tree */
	int pos;  /* index of this vertex in the heap array; -1 once removed */
};
/* Per-vertex bookkeeping used by the heap-based Prim implementation. */
typedef struct vertex vertex;
// Establish a connection of given weight
/*
 * Establish a connection of given weight: prepend the directed edge
 * u -> v (weight w) to the adjacency list headed by AdjList[u].
 * Only one direction is added; callers add the reverse edge themselves
 * for undirected graphs. Because edges are prepended, each list ends up
 * in reverse insertion order.
 * Terminates the program if memory for the edge cannot be allocated.
 */
void connect(node *AdjList, int u, int v, int w)
{
	/* Named `edge` rather than `new` so the file also compiles as C++. */
	node *edge = (node *)malloc(sizeof(node));
	if(edge == NULL)
	{
		/* Dereferencing a failed allocation would crash anyway; fail loudly. */
		perror("connect: malloc");
		exit(EXIT_FAILURE);
	}
	edge->vert = v;
	edge->weight = w;
	edge->next = AdjList[u].next;
	AdjList[u].next = edge;
}
// Function used to propagate heap changes to parent nodes
/*
 * Function used to propagate heap changes to parent nodes (sift-up).
 * After the key of the vertex stored at heap index i has decreased, swap it
 * with its parent while it is smaller, keeping v[id].pos equal to each
 * vertex id's index in heap[]. n is the current heap size.
 * Indices outside [1, n) are ignored: index 0 is already the root, and a
 * negative index (pos == -1, a vertex already removed from the heap) must
 * never be used to index heap[].
 */
void heapify2(int heap[], vertex *v, int n, int i)
{
	int parent;
	if(i <= 0 || i >= n)
		return;
	parent = (i - 1) / 2;
	if(v[heap[i]].key < v[heap[parent]].key)
	{
		/* Swap elements */
		int j = heap[i];
		heap[i] = heap[parent];
		heap[parent] = j;
		/* Correct positions */
		v[heap[i]].pos = i;
		v[heap[parent]].pos = parent;
		/* Recurse */
		heapify2(heap, v, n, parent);
	}
}
// Main function used when creating heap
/*
 * Sift-down: starting at heap index i, repeatedly swap the element with its
 * smaller child until no child inside the first n entries has a smaller key.
 * Every swap keeps v[id].pos equal to the id's index in heap[].
 */
void heapify(int heap[], vertex *v, int n, int i)
{
	for(;;)
	{
		int smallest = i;
		int left = 2 * i + 1;
		int right = left + 1;
		if(left < n && v[heap[left]].key < v[heap[smallest]].key)
			smallest = left;
		if(right < n && v[heap[right]].key < v[heap[smallest]].key)
			smallest = right;
		if(smallest == i)
			break;
		{
			int moved = heap[smallest];
			heap[smallest] = heap[i];
			heap[i] = moved;
		}
		v[heap[i]].pos = i;
		v[heap[smallest]].pos = smallest;
		i = smallest;
	}
}
// Remove an element from top of heap
/*
 * Remove an element from top of heap.
 * Removes and returns the vertex id with the minimum key. The removed vertex
 * gets pos = -1 so callers can tell it is no longer in the heap; the last
 * element is moved to the root and sifted down, and *n shrinks by one.
 * Returns -1 (and changes nothing) if the heap is empty.
 */
int Hdel(int heap[], vertex *v, int *n)
{
	int k;
	if(*n <= 0)
		return -1;
	k = heap[0];
	(*n)--;
	v[k].pos = -1;
	/*
	 * Only refill the root if something is left. Otherwise, with a single
	 * element, heap[0] would be k itself and its pos would be reset to 0.
	 */
	if(*n > 0)
	{
		heap[0] = heap[*n];
		v[heap[0]].pos = 0;
		heapify(heap, v, *n, 0);
	}
	return k;
}
// Main logic for Prim's Algorithm
/*
 * Main logic for Prim's Algorithm.
 * Grows a minimum spanning tree from vertex 0 over the adjacency lists headed
 * by AdjList[0..n-1] and prints its total weight. m (the number of edges) is
 * not used by this implementation; it is kept so existing callers compile.
 * If some vertex is unreachable from vertex 0, no spanning tree exists and a
 * message saying so is printed instead of an overflowing total.
 */
void Prim(node *AdjList, int m, int n)
{
	int i, tot = 0, disconnected = 0;
	vertex *v;
	int *heap;
	(void)m;
	if(n <= 0)
	{
		/* Empty graph: the empty tree weighs 0, and v[0] below would not exist. */
		printf("Weight of MST = %d\n", tot);
		return;
	}
	v = (vertex *)malloc(sizeof(vertex) * n);
	heap = (int *)malloc(sizeof(int) * n);
	if(v == NULL || heap == NULL)
	{
		free(v);
		free(heap);
		perror("Prim: malloc");
		return;
	}
	/*
	 * Create heap: every vertex starts with key INT_MAX at its own index, which
	 * is already a valid min-heap. Vertex 0 then gets key 0 and stays the root.
	 */
	for(i = 0; i < n; ++i)
	{
		v[i].key = INT_MAX;
		v[i].pos = i;
		heap[i] = i;
	}
	v[0].key = 0;
	/* Hdel shrinks n by one per call, so this runs once per vertex. */
	while(n > 0)
	{
		int vert = Hdel(heap, v, &n);
		node *t;
		if(v[vert].key == INT_MAX)
		{
			/* The cheapest remaining vertex has no edge to the tree. */
			disconnected = 1;
			break;
		}
		tot += v[vert].key;
		for(t = AdjList[vert].next; t != NULL; t = t->next)
		{
			int u = t->vert;
			/*
			 * Only vertices still in the heap (0 <= pos < n) may have their key
			 * lowered. A removed vertex is already in the tree and has pos == -1,
			 * which must not be passed to heapify2 as a heap index.
			 */
			if(v[u].pos >= 0 && v[u].pos < n && t->weight < v[u].key)
			{
				v[u].key = t->weight;
				heapify2(heap, v, n, v[u].pos);
			}
		}
	}
	free(v);
	free(heap);
	if(disconnected)
		printf("Graph is disconnected: no spanning tree exists\n");
	else
		printf("Weight of MST = %d\n", tot);
}
// Main function
int main()
{
	// m = number of edges, n = number of vertices
	int m = 8, n = 6;
	node *AdjList;
    AdjList = (node *)malloc(sizeof(node) * n);
    /*
	* Create edge connection
    * Since graph is undirected,
    * connections are formed in both
	* directions
    */
	connect(AdjList, 0, 1, 3);
	connect(AdjList, 1, 0, 3);
	connect(AdjList, 1, 4, 5);
	connect(AdjList, 4, 1, 5);
	connect(AdjList, 2, 3, 11);
	connect(AdjList, 3, 2, 11);
	connect(AdjList, 0, 4, 4);
	connect(AdjList, 4, 0, 4);
	connect(AdjList, 1, 2, 7);
	connect(AdjList, 2, 1, 7);
	connect(AdjList, 3, 5, 2);
	connect(AdjList, 5, 3, 2);
	connect(AdjList, 1, 5, 4);
	connect(AdjList, 5, 1, 4);
	connect(AdjList, 2, 4, 5);
	connect(AdjList, 4, 2, 5);
	Prim(AdjList, m, n);
	return 0;
}

C++


#include <iostream>
#include <vector>
#include <utility>
#include <set>
using namespace std;
typedef long long ll;
// Part of Cosmos by OpenGenus Foundation
const int MAXN = 10000 + 5;
bool vis[MAXN];
int n, m;
std::vector<std::pair<ll, int> > adj[MAXN];   // for every vertex store all the edge weight and the adjacent vertex to it

// Returns the total weight of a minimum spanning tree of the connected
// component containing vertex x. Vertices unreachable from x are ignored.
// This is "lazy" Prim: every edge leaving the tree goes into an ordered
// multiset used as a min-priority queue. Entries whose endpoint has already
// joined the tree are skipped when popped. Returns 0 if x is outside [0, MAXN).
ll prim(int x){
    if (x < 0 || x >= MAXN)
        return 0;
    // Clear marks left by a previous call so prim() can be run more than once.
    for (bool &seen : vis)
        seen = false;
    // (weight, vertex) with the weight kept as ll, the same type adj[] stores,
    // so inserting an edge never narrows its weight.
    std::multiset<std::pair<ll, int> > S;
    ll minCost = 0;
    S.insert({0, x});
    while (!S.empty()) {
        const std::pair<ll, int> p = *S.begin();
        S.erase(S.begin());
        const int u = p.second;
        if (vis[u])
            continue;   // stale entry: u already joined the tree more cheaply
        minCost += p.first;
        vis[u] = true;
        for (const auto &edge : adj[u]) {
            if (!vis[edge.second])
                S.insert(edge);
        }
    }
    return minCost;
}
int main(){
    cin >> n >> m;  // n = number of vertices, m = number of edges
    for(int i = 0; i < m; i++){
        int x, y, weight;
        cin >> x >> y >> weight;
        adj[x].push_back({weight, y});
        adj[y].push_back({weight, x});
    }
    // Selecting any node as the starting node
    ll minCost = prim(1);
    cout << minCost << endl;
    return 0;
}

Python


# A Python program for Prim's Minimum Spanning Tree (MST) algorithm.
# The program is for adjacency matrix representation of the graph
# Part of Cosmos by OpenGenus Foundation
class Python():
    """Prim's minimum spanning tree on an adjacency-matrix graph.

    ``graph[u][v]`` is the weight of edge u-v. A value <= 0 means "no edge",
    so only positive weights are supported.
    """

    def __init__(self, vertices):
        self.V = vertices
        self.graph = [[0 for column in range(vertices)]
                      for row in range(vertices)]

    # Function to print the constructed MST stored in parent[]
    def printMST(self, parent):
        print ("Edge \tWeight")
        for i in range(1,self.V):
            print (parent[i],"-",i,"\t",self.graph[i][ parent[i] ])

    def minKey(self, key, mstSet):
        """Return the vertex not yet in the MST with the smallest finite key.

        Returns -1 when no remaining vertex has a finite key, i.e. the rest
        of the graph is unreachable from the vertices already in the tree.
        """
        # float("inf") rather than a magic bound, so large weights still work.
        min_value = float("inf")
        min_index = -1
        for v in range(self.V):
            if key[v] < min_value and not mstSet[v]:
                min_value = key[v]
                min_index = v
        return min_index

    def primMST(self):
        """Build an MST rooted at vertex 0 and print its edges.

        Raises ValueError if the graph is disconnected, because no spanning
        tree exists then.
        """
        if self.V == 0:
            # Nothing to span; print just the header.
            self.printMST([])
            return
        # Key values used to pick the minimum-weight edge crossing the cut
        key = [float("inf")] * self.V
        parent = [None] * self.V  # Array to store constructed MST
        key[0] = 0   # Make key 0 so that this vertex is picked as first vertex
        mstSet = [False] * self.V
        parent[0] = -1  # First node is always the root of the MST
        for _ in range(self.V):
            # Pick the minimum-key vertex among those not yet in the MST.
            # u is always 0 in the first iteration.
            u = self.minKey(key, mstSet)
            if u == -1:
                raise ValueError("graph is disconnected; no spanning tree exists")
            mstSet[u] = True
            # Update the key of each neighbour still outside the MST when the
            # edge from u is cheaper than its current key.
            for v in range(self.V):
                weight = self.graph[u][v]
                if weight > 0 and not mstSet[v] and weight < key[v]:
                    key[v] = weight
                    parent[v] = u
        self.printMST(parent)
# Example: 5-vertex graph given as a symmetric adjacency matrix (0 = no edge).
g = Python(5)
g.graph = [
    [0, 2, 0, 6, 0],
    [2, 0, 3, 8, 5],
    [0, 3, 0, 0, 7],
    [6, 8, 0, 0, 9],
    [0, 5, 7, 9, 0],
]
g.primMST()

Applications


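Applications of Prim's minimum spanning tree algorithm include:

  • Network design: connecting a set of sites (computers, cities, circuit pins) with cables, roads or wires of minimum total length or cost.
  • Approximation algorithms: a minimum spanning tree is the basis of the classic 2-approximation for the metric Travelling Salesman Problem.
  • Clustering: removing the k−1 heaviest edges of a minimum spanning tree yields k single-linkage clusters.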

Alexa Ryder

Alexa Ryder

Hi, I am creating the perfect textual information customized for learning. Message me for anything.

Read More