Prim Minimum Spanning Tree Algorithm
Sign up for FREE 1 month of Kindle and read all our books for free.
Get FREE domain for 1st year and build your brand new site
Reading time: 15 minutes  Coding time: 9 minutes
Prim's algorithm is a greedy algorithm that finds a minimum spanning tree for a weighted undirected graph. It finds a subset of the edges that forms a tree that includes every vertex, where the total weight of all the edges in the tree is minimized. The algorithm operates by building this tree one vertex at a time, from an arbitrary starting vertex, at each step adding the cheapest possible connection from the tree to another vertex.
Unlike Kruskal's algorithm, which adds an edge at each step, Prim's algorithm adds a vertex to the growing spanning tree at each step.
The credit for Prim's algorithm goes to Vojtěch Jarník, Robert C. Prim and Edsger W. Dijkstra.
Complexity
 Worst case time complexity: $\Theta(E \log V)$ using priority queues.
 Average case time complexity: $\Theta(E \log V)$ using priority queues.
 Best case time complexity: $\Theta(E \log V)$ using priority queues.
 Space complexity: $\Theta(E + V)$
Steps involved are:
 Maintain two disjoint sets of vertices: one containing the vertices that are in the growing spanning tree, and the other containing the vertices that are not.
 Select the cheapest vertex that is connected to the growing spanning tree but is not yet in it, and add it to the growing spanning tree. This can be done using a priority queue: insert the vertices that are connected to the growing spanning tree into the priority queue.
 Check for cycles. To do that, mark the vertices that have already been selected, and insert into the priority queue only those vertices that are not marked.
Pseudocode
Prim()
    S = new empty set
    for i = 1 to n
        d[i] = inf
    while S.size() < n
        x = inf
        v = 1
        for each i in V \ S          // V is the set of vertices
            if x >= d[i]
                then x = d[i], v = i
        d[v] = 0
        S.insert(v)
        for each u in adj[v]
            do d[u] = min(d[u], w(v,u))
Implementations
 C
 C++
 Python
C
/*
* Part of Cosmos by OpenGenus Foundation
*/
#include <stdio.h>
#include <stdlib.h>
#include <limits.h>
/*
 * One entry of an adjacency list: an edge leading to vertex `vert` with
 * cost `weight`. Entries leaving the same vertex are chained via `next`;
 * a NULL `next` ends the chain.
 */
struct node
{
    int vert;          /* destination vertex index */
    int weight;        /* edge weight */
    struct node *next; /* following edge of the same source vertex */
};
typedef struct node node;
/*
 * Per-vertex bookkeeping for Prim's algorithm:
 *   key - weight of the cheapest edge found so far that links the vertex
 *         to the growing tree
 *   pos - index of the vertex inside the heap array
 */
struct vertex
{
    int key;
    int pos;
};
typedef struct vertex vertex;
// Establish a connection of given weight: add a directed edge u -> v with
// weight w to the adjacency list of vertex u.
// AdjList[u] is a head node; the new edge is spliced in directly after it,
// so edges end up in reverse order of insertion. AdjList[u].next must already
// be NULL or point to a valid edge list. For an undirected graph, call this
// once per direction.
// Exits the program if memory for the edge cannot be allocated.
void connect(node *AdjList, int u, int v, int w)
{
    node *head = &AdjList[u];
    // Named `edge` rather than `new` so the code also compiles as C++.
    node *edge = (node *)malloc(sizeof(node));
    if(edge == NULL)
    {
        perror("connect: malloc");
        exit(EXIT_FAILURE);
    }
    edge->vert = v;
    edge->weight = w;
    edge->next = head->next;
    head->next = edge;
}
// Function used to propagate heap changes to parent nodes.
// Sifts the entry at heap index i upward while its key is smaller than its
// parent's key, keeping v[...].pos in sync with every swap. Call it after
// decreasing the key of the vertex stored at heap[i].
//   heap - min-heap of vertex ids ordered by v[id].key; first n entries valid
//   v    - per-vertex key / heap position
//   n    - current heap size
//   i    - heap index of the entry whose key was decreased
void heapify2(int heap[], vertex *v, int n, int i)
{
    // The root has no parent. Indices outside [0, n) are ignored as well, so a
    // negative or stale position can never index outside the heap array.
    if(i <= 0 || i >= n)
        return;
    int parent = (i - 1) / 2;
    if(v[heap[i]].key < v[heap[parent]].key)
    {
        // Swap elements
        int j = heap[i];
        heap[i] = heap[parent];
        heap[parent] = j;
        // Correct positions
        v[heap[i]].pos = i;
        v[heap[parent]].pos = parent;
        // Continue from the parent's slot
        heapify2(heap, v, n, parent);
    }
}
// Sift-down for the min-heap of vertex ids ordered by v[id].key.
// Starting at heap index i, repeatedly swap the entry with its smaller child
// while that child's key is strictly smaller, updating v[...].pos for both
// moved vertices. Only heap[0 .. n-1] is treated as part of the heap.
void heapify(int heap[], vertex *v, int n, int i)
{
    for(;;)
    {
        int smallest = i;
        int left = 2 * i + 1;
        int right = left + 1;

        if(left < n && v[heap[left]].key < v[heap[smallest]].key)
            smallest = left;
        if(right < n && v[heap[right]].key < v[heap[smallest]].key)
            smallest = right;

        if(smallest == i)
            return; // heap property holds at i

        int moved = heap[smallest];
        heap[smallest] = heap[i];
        heap[i] = moved;
        v[heap[i]].pos = i;
        v[heap[smallest]].pos = smallest;

        i = smallest; // continue one level down
    }
}
// Remove an element from top of heap.
// Removes and returns the vertex with the smallest key. The removed vertex's
// pos is set to -1 to record that it is no longer in the heap, and *n is
// decremented. Precondition: *n > 0.
int Hdel(int heap[], vertex *v, int *n)
{
    int top = heap[0];
    int last = heap[*n - 1];

    --(*n);
    v[top].pos = -1;

    // Move the last entry to the root and restore the heap property. Skipped
    // when the heap is now empty; otherwise `top` (== last) would get pos 0
    // back and look as if it were still in the heap.
    if(*n > 0)
    {
        heap[0] = last;
        v[last].pos = 0;
        heapify(heap, v, *n, 0);
    }
    return top;
}
// Main logic for Prim's Algorithm.
// Computes the total weight of a minimum spanning tree (a minimum spanning
// forest if the graph is disconnected) and prints "Weight of MST = <total>".
//   AdjList - array of n head nodes; the edges leaving vertex u form the list
//             starting at AdjList[u].next
//   m       - number of edges; not needed by the algorithm
//   n       - number of vertices, numbered 0 .. n-1
// Vertex 0 is the starting vertex.
void Prim(node *AdjList, int m, int n)
{
    int i, tot = 0;
    int remaining = n;
    vertex *v;
    int *heap;

    (void)m; // unused; kept so existing callers keep compiling

    if(n <= 0)
    {
        printf("Weight of MST = %d\n", tot);
        return;
    }

    v = (vertex *)malloc(sizeof(vertex) * n);
    heap = (int *)malloc(sizeof(int) * n);
    if(v == NULL || heap == NULL)
    {
        free(v);
        free(heap);
        fprintf(stderr, "Prim: out of memory\n");
        return;
    }

    // Create heap: every key starts "infinite", so the identity layout
    // heap[i] = i is already a valid min-heap.
    for(i = 0; i < n; ++i)
    {
        v[i].key = INT_MAX;
        v[i].pos = i;
        heap[i] = i;
    }
    v[0].key = 0; // heap[0] == 0 is the root, so the heap stays valid

    while(remaining > 0)
    {
        int vert = Hdel(heap, v, &remaining);
        node *t;

        // A key of INT_MAX means no tree edge reaches vert. It starts a new
        // tree of the forest and contributes no weight; adding INT_MAX would
        // overflow tot.
        if(v[vert].key != INT_MAX)
            tot += v[vert].key;

        for(t = AdjList[vert].next; t != NULL; t = t->next)
        {
            int p = v[t->vert].pos;
            // Only vertices still in the heap (0 <= pos < remaining) may be
            // relaxed. Touching an already-finished vertex would call
            // heapify2 with a position outside the heap.
            if(p >= 0 && p < remaining && v[t->vert].key > t->weight)
            {
                v[t->vert].key = t->weight;
                heapify2(heap, v, remaining, p);
            }
        }
    }

    printf("Weight of MST = %d\n", tot);
    free(v);
    free(heap);
}
// Main function
int main()
{
// m = number of edges, n = number of vertices
int m = 8, n = 6;
node *AdjList;
AdjList = (node *)malloc(sizeof(node) * n);
/*
* Create edge connection
* Since graph is undirected,
* connections are formed in both
* directions
*/
connect(AdjList, 0, 1, 3);
connect(AdjList, 1, 0, 3);
connect(AdjList, 1, 4, 5);
connect(AdjList, 4, 1, 5);
connect(AdjList, 2, 3, 11);
connect(AdjList, 3, 2, 11);
connect(AdjList, 0, 4, 4);
connect(AdjList, 4, 0, 4);
connect(AdjList, 1, 2, 7);
connect(AdjList, 2, 1, 7);
connect(AdjList, 3, 5, 2);
connect(AdjList, 5, 3, 2);
connect(AdjList, 1, 5, 4);
connect(AdjList, 5, 1, 4);
connect(AdjList, 2, 4, 5);
connect(AdjList, 4, 2, 5);
Prim(AdjList, m, n);
return 0;
}
C++
#include <iostream>
#include <vector>
#include <utility>
#include <set>
using namespace std;
typedef long long ll;
// Part of Cosmos by OpenGenus Foundation
const int MAXN = 10000 + 5;
bool vis[MAXN];  // vis[v] becomes true once v has been added to a tree
int n, m;        // number of vertices / edges, read in main()
// adj[u] holds (weight, v) for every edge u-v; weight comes first so that
// pairs order by weight.
std::vector<std::pair<ll, int> > adj[MAXN];

// Lazy Prim starting from vertex x: returns the total weight of a minimum
// spanning tree of the component containing x.
// Every vertex reached is marked in vis[]. Vertices already marked before the
// call are not revisited, so starting from an already-visited vertex returns 0.
// Weights stay `ll` throughout. The queue used to store pair<int, int>, which
// silently truncated weights that did not fit in an int.
ll prim(int x){
    // Ordered multiset used as a min-priority queue of (edge weight, vertex).
    // Entries for vertices that were reached meanwhile are skipped when popped.
    std::multiset<std::pair<ll, int> > pq;
    ll minCost = 0;
    pq.insert(std::make_pair(0LL, x));
    while(!pq.empty()){
        const std::pair<ll, int> top = *pq.begin();
        pq.erase(pq.begin());
        const int u = top.second;
        if(vis[u])
            continue;
        vis[u] = true;
        minCost += top.first;
        for(const auto& edge : adj[u]){
            if(!vis[edge.second])
                pq.insert(edge);
        }
    }
    return minCost;
}
int main(){
cin >> n >> m; // n = number of vertices, m = number of edges
for(int i = 0; i < m; i++){
int x, y, weight;
cin >> x >> y >> weight;
adj[x].push_back({weight, y});
adj[y].push_back({weight, x});
}
// Selecting any node as the starting node
ll minCost = prim(1);
cout << minCost << endl;
return 0;
}
Python
# A Python program for Prim's Minimum Spanning Tree (MST) algorithm.
# The program is for adjacency matrix representation of the graph
# Part of Cosmos by OpenGenus Foundation
class Python():
    """Prim's minimum spanning tree on a graph stored as an adjacency matrix.

    ``graph[u][v]`` is the weight of the undirected edge u-v; a value <= 0
    means "no edge". The class name is unusual but is kept so existing
    callers (``Python(5)``) keep working.
    """

    def __init__(self, vertices):
        # Number of vertices and a V x V matrix with no edges yet.
        self.V = vertices
        self.graph = [[0 for column in range(vertices)]
                      for row in range(vertices)]

    def printMST(self, parent):
        """Print the tree edges stored in ``parent`` as "p - v <TAB> weight".

        Vertex 0 is never printed. Any other vertex whose parent is None or
        negative is the root of a tree and has no edge to print.
        """
        print("Edge \tWeight")
        for i in range(1, self.V):
            p = parent[i]
            if p is None or p < 0:
                continue
            print(p, "-", i, "\t", self.graph[i][p])

    def minKey(self, key, mstSet):
        """Return the vertex not yet in the MST that has the smallest key.

        Ties go to the lowest index. Vertices whose key is still infinite are
        eligible too, so a disconnected graph starts a new tree instead of
        failing. Returns None once every vertex is in the MST.
        """
        min_index = None
        for v in range(self.V):
            if mstSet[v]:
                continue
            if min_index is None or key[v] < key[min_index]:
                min_index = v
        return min_index

    def primMST(self):
        """Build the MST with Prim's algorithm, print it and return ``parent``.

        ``parent[v]`` is v's parent in the tree and ``parent[0]`` is -1 (the
        root). If the graph is disconnected, the first vertex chosen in every
        further component is also a root and keeps ``parent`` None.
        """
        inf = float("inf")  # no fixed sentinel, so any edge weight works
        # Key values used to pick the minimum-weight edge crossing the cut
        key = [inf] * self.V
        parent = [None] * self.V
        mstSet = [False] * self.V
        if self.V > 0:
            key[0] = 0       # make vertex 0 the first vertex picked
            parent[0] = -1   # first vertex is always the root of the MST
        for _ in range(self.V):
            # Pick the cheapest vertex not yet in the MST; it is vertex 0
            # in the first iteration.
            u = self.minKey(key, mstSet)
            mstSet[u] = True
            # Offer the edges of u to every vertex still outside the MST.
            for v in range(self.V):
                w = self.graph[u][v]
                if w > 0 and not mstSet[v] and w < key[v]:
                    key[v] = w
                    parent[v] = u
        self.printMST(parent)
        return parent
# Example: a 5-vertex weighted graph given as an adjacency matrix
# (0 = no edge); build its minimum spanning tree and print the edges.
g = Python(5)
g.graph = [
    [0, 2, 0, 6, 0],
    [2, 0, 3, 8, 5],
    [0, 3, 0, 0, 7],
    [6, 8, 0, 0, 9],
    [0, 5, 7, 9, 0],
]
g.primMST()
Applications
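Applications of Prim's minimum spanning tree algorithm are:

Finding a minimum spanning tree of a weighted undirected graph with a greedy approach, for example to design the cheapest network (cables, roads or pipelines) that connects a given set of locations.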