Let's say we need to connect some objects and then query whether two given objects are connected. Below is an example test:
@Test
public void testConnectivity() {
MyConnectionDb connectDb = new MyConnectionDb();
connectDb.connect(4, 3);
connectDb.connect(3, 8);
connectDb.connect(6, 5);
connectDb.connect(9, 4);
connectDb.connect(2, 1);
assertTrue(connectDb.isConnected(8, 9));
assertFalse(connectDb.isConnected(5, 4));
connectDb.connect(5, 0);
connectDb.connect(7, 2);
connectDb.connect(6, 1);
connectDb.connect(7, 3);
assertTrue(connectDb.isConnected(5, 3));
}
So when numbers are connected they form a set, and we enlarge that set by applying union operations. We can then query whether two given numbers are connected, i.e. whether they ended up in the same set.
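Before writing any implementation, it helps to state the contract explicitly. As a minimal sketch (the interface itself is my own addition, not part of the original code), every implementation in this post offers the same two core operations:

public interface ConnectDb {
    // union: put o1 and o2, and everything already connected to them, into the same set
    void connect(int o1, int o2);

    // find: do o1 and o2 belong to the same set?
    boolean isConnected(int o1, int o2);
}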
I started with a solution based on Java collections such as HashSet. It is a good way to get a feel for the problem and arrive at a working solution; then we will see why it is not fast enough and how we can improve it.
public class MyFirstConnectDb {
List<Set<Integer>> sets = new ArrayList<>();
// O(number of sets) and O(nElements in mergeSet)
public void connect(int o1, int o2) {
boolean connectionDone = false;
List<Set<Integer>> setsToMerge = new ArrayList<>();
// O(nSet)
for (Set<Integer> set : sets) {
// System.out.println("Set count: " + sets.size());
// O(1)
if (set.contains(o1)) {
set.add(o2);
connectionDone = true;
setsToMerge.add(set);
continue;
}
if (set.contains(o2)) {
set.add(o1);
connectionDone = true;
setsToMerge.add(set);
}
}
// find sets to merge and merge them
Set<Integer> mergedSet = new HashSet<>();
// O(nSetsToMerge)
for (Set<Integer> item : setsToMerge) {
// System.out.println("Element count in setToMerge : " + item.size());
// O(nElements)
for (Integer obj : item) {
// O(1)
mergedSet.add(obj);
}
}
if (!mergedSet.isEmpty()) {
sets.add(mergedSet);
}
for (Set<Integer> item : setsToMerge) {
sets.remove(item);
}
if (connectionDone) {
return;
}
Set<Integer> newSet = new HashSet<>();
newSet.add(o1);
newSet.add(o2);
sets.add(newSet);
}
public void connect(List<Integer> elements) {
if (elements.size() != 2) {
throw new ConnectDbException("Cannot connect elements: wrong number of elements: " + elements.size());
}
connect(elements.get(0), elements.get(1));
}
public boolean isConnected(int o1, int o2) {
// O(number of sets): each contains() check is O(1)
for (Set<Integer> set : sets) {
if (set.contains(o1) && set.contains(o2)) {
return true;
}
}
return false;
}
public void connectAll(List<Integer> list) {
for (int i = 0; i < list.size() - 1; ++i) {
connect(list.get(i), list.get(i + 1));
}
}
}
Now let's run the following performance test on that implementation:
@Test
public void test() {
MyFirstConnectDb connectDb = new MyFirstConnectDb();
List<Integer> list = new ArrayList<>();
for (int i = 0; i < 1000000; ++i) {
list.add(i);
}
connectDb.connectAll(list);
long startTime = System.nanoTime();
assertTrue(connectDb.isConnected(0, 999999));
long stopTime = System.nanoTime();
System.out.println(stopTime - startTime);
}
For 1 million connections the running time is not acceptable. Every connect call scans the existing sets and, whenever it merges, copies all elements of the growing set into a new one, so connecting n elements in a chain performs on the order of n^2/2 element copies; for n = 1,000,000 that is roughly 5 * 10^11 copies.
Now let's try the first version of the Union-Find algorithm. We use a plain array as the data structure: the index is the number itself, and the value at that index is that number's parent, which is updated when there is a connection. Following parent links from a number leads to the root of its tree, and two numbers are connected exactly when they share the same root. A short trace after the implementation below illustrates the idea.
public class UnionFindConnectDb {
ArrayList<Integer> array;
public UnionFindConnectDb(int n) {
array = new ArrayList<>(n);
for (int i = 0; i < n; ++i) {
array.add(i, i);
}
}
public void connect(int o1, int o2) {
int root1 = findRoot(o1);
int root2 = findRoot(o2);
array.set(root1, array.get(root2));
}
private int findRoot(int x) {
while (array.get(x) != x) {
x = array.get(x);
}
return x;
}
public boolean isConnected(int o1, int o2) {
int root1 = findRoot(o1);
int root2 = findRoot(o2);
return root1 == root2;
}
public void connectAll(List<Integer> list) {
for (int i = 0; i < list.size() - 1; ++i) {
connect(list.get(i), list.get(i + 1));
}
}
}
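To make the array representation concrete, here is a small trace of my own (not from the original post), using a few of the connections from the connectivity test and assuming a database created for 10 elements:

    UnionFindConnectDb db = new UnionFindConnectDb(10);
    // initially every number is its own root: the value at index i is i
    db.connect(4, 3);   // roots are 4 and 3 -> parent of 4 becomes 3
    db.connect(3, 8);   // roots are 3 and 8 -> parent of 3 becomes 8
    db.connect(9, 4);   // roots are 9 and 8 -> parent of 9 becomes 8
    // now 4 -> 3 -> 8 and 9 -> 8, so 4 and 9 share the root 8
    boolean connected = db.isConnected(4, 9);   // true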
If we run the performance test on this implementation, the result is dramatically better. Given the two roots, the connect operation takes constant time, and the same is true for the isConnected query: we just check whether the roots are the same. The real cost lies in finding a root, which is proportional to the height of the tree, and one drawback of this solution is that the tree can become very tall in the worst case, degrading findRoot to O(n). So we need a weighted tree solution.
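To see how the degeneration happens (this trace is my own illustration, not from the original post), note that connecting elements in a chain, which is exactly what connectAll does, turns the tree into a linked list:

    UnionFindConnectDb db = new UnionFindConnectDb(5);
    db.connect(0, 1);   // parent of 0 becomes 1
    db.connect(1, 2);   // parent of 1 becomes 2
    db.connect(2, 3);   // parent of 2 becomes 3
    db.connect(3, 4);   // parent of 3 becomes 4
    // findRoot(0) must now follow 0 -> 1 -> 2 -> 3 -> 4: the height grows
    // linearly with the number of connections, so a lookup can cost O(n)

The weighted implementation below avoids this by always linking the smaller tree under the root of the larger one: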
public class WeightedConnectDb {
ArrayList<Integer> array;
ArrayList<Integer> depths;
public WeightedConnectDb(int n) {
array = new ArrayList<>(n);
depths = new ArrayList<>(n);
for (int i = 0; i < n; ++i) {
array.add(i, i);
depths.add(i, 1);
}
}
public void connect(int o1, int o2) {
int root1 = findRoot(o1);
int root2 = findRoot(o2);
if (root1 == root2) {
return;
}
// link the root of the smaller tree under the root of the larger one
if (depths.get(root1) >= depths.get(root2)) {
array.set(root2, root1);
depths.set(root1, depths.get(root1) + depths.get(root2));
} else {
array.set(root1, root2);
depths.set(root2, depths.get(root2) + depths.get(root1));
}
}
private int findRoot(int x) {
int root = x;
while (array.get(root) != root) {
root = array.get(root);
}
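// partial path compression: re-point the node we started from directly at the root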
array.set(x, root);
return root;
}
public boolean isConnected(int o1, int o2) {
int root1 = findRoot(o1);
int root2 = findRoot(o2);
return root1 == root2;
}
public void connectAll(List<Integer> list) {
for (int i = 0; i < list.size() - 1; ++i) {
connect(list.get(i), list.get(i + 1));
}
}
}
The query implementation is of course the same, but for the weighted tree we compare the two trees before linking them, so that the smaller tree is always attached under the root of the larger one and the overall tree stays balanced. An additional array keeps track of the weight of each root's tree (the number of elements it contains, which bounds how tall it can grow). There is also a technique called path compression applied in the findRoot method: once the root has been found, the node we started from is re-pointed directly at it, flattening the tree a little on every lookup.
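The findRoot above only re-points the node the search started from. A more aggressive variant, sketched below as my own illustration (it is not the code from this post), compresses the whole path by re-pointing every node visited on the way up:

    private int findRoot(int x) {
        // first pass: walk up to the root
        int root = x;
        while (array.get(root) != root) {
            root = array.get(root);
        }
        // second pass: point every node on the path directly at the root
        while (array.get(x) != root) {
            int parent = array.get(x);
            array.set(x, root);
            x = parent;
        }
        return root;
    }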
After these improvements, a sequence of n connect and isConnected operations runs in time very close to O(n): weighting alone keeps each findRoot at O(log n), and with path compression the amortized cost per operation becomes nearly constant. For the naive collection-based solution, the same workload costs O(n^2).