# BST Diameter

Earlier this morning I read the post “Binary Tree: The Diameter” by David Pynes. The post showed up on Medium but was initially written for Towards Data Science. After quickly reading it, I decided to spend time reading the article on my computer, thinking about it and writing a solution in Java. The original article used C++.

If I am not able to get this done in a block (I practice Deep Work and use 2-hour blocks) I will take a peek at the code this afternoon.

To start, we need a formal definition of the diameter of a tree, given that the code should produce such a number. I also want to note that the article called for a binary tree, not a BST, but the illustrations and animations are for BSTs. In practice I tend to find more uses for BSTs, so I will develop a solution using them. The diameter computation itself does not depend on the BST ordering; only the way the tree is built does.

I did an Internet search on binary tree diameter and the top Google results pointed to GeeksforGeeks. The definition there states:

“The diameter of a tree (sometimes called the width) is the number of nodes on the longest path between two end nodes.”

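Given this definition, the diameter of a BST appears to be the longest distance between two nodes that share a common ancestor; that is, both nodes descend from a single node, which is not necessarily the root. In this post I will measure that distance in edges rather than in nodes (a path with k nodes has k - 1 edges). That said, we need to locate such nodes and compute their individual distances to that common ancestor. The largest such distances on each side correspond to the heights of the ancestor’s left and right subtrees.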

Before going after a solution, note that the approach suggested by GeeksforGeeks is based on the following recursive definition:

The diameter of a tree T is the largest of the following quantities:

* the diameter of T’s left subtree

* the diameter of T’s right subtree

* the longest path between leaves that goes through the root of T (this can be computed from the heights of the subtrees of T)
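The three cases above translate almost word for word into code. Below is a minimal Java sketch, assuming the diameter is counted in edges (the convention used throughout this post; for a non-empty tree, add 1 to get the node count used in the GeeksforGeeks quote). The class NaiveDiameter and its nested Node are hypothetical names, not part of the post’s code. Because height() is recomputed at every node, this direct version can take O(n²) time on a degenerate tree, which is why a single post order pass is preferable.

```java
// Hypothetical, self-contained sketch of the three-case definition.
// Diameter is measured in edges; an empty tree and a single node both give 0.
class NaiveDiameter {

    // **** minimal BST node (hypothetical; not the post's Node class) ****
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // **** BST insert: smaller values go left, everything else goes right ****
    static Node insert(Node node, int val) {
        if (node == null) return new Node(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);
        return node;
    }

    // **** height counted in nodes (an empty subtree has height 0) ****
    static int height(Node node) {
        if (node == null) return 0;
        return 1 + Math.max(height(node.left), height(node.right));
    }

    // **** largest of: left diameter, right diameter, longest path through this node ****
    static int diameter(Node node) {
        if (node == null) return 0;
        int throughNode = height(node.left) + height(node.right); // edges on the best path bending here
        int inSubtrees  = Math.max(diameter(node.left), diameter(node.right));
        return Math.max(throughNode, inSubtrees);
    }

    // **** build a BST from the values (in the given order) and return its diameter ****
    static int diameterOf(int[] values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        return diameter(root);
    }

    public static void main(String[] args) {
        int[] arr3 = {100, 125, 75, 80, 85, 90, 70, 65, 60};
        System.out.println("diameter: " + diameterOf(arr3)); // prints "diameter: 6"
    }
}
```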

OK, so far so good. As usual I will attempt to implement a solution using Test Driven Development (TDD). I will also see if I can reuse some code from a previous solution for BSTs; that way I would not have to rewrite the BST implementation.

Let’s first take a look at a diagram of a BST that was generated by the solution.

We start at the bottom of the left side of the tree and work our way up. When we reach node 4 we have the choice of selecting 1 or 2. Given that we are looking for the largest diameter, we choose 2 and continue on our way up.

We then move to the right branch and also work our way up from the bottom. When we reach node 9 we have the choice of 1 or 2. We choose 2, add 1 for the edge up to the parent, and keep on going up.

When we reach the root we have 4 on the left branch and 5 on the right. The diameter is 4 + 5 = 9.

Let’s now look at a second diagram.

We follow a similar route from the bottom up and end up with the value of 2 on the left side of node 47 and 5 on the right side. The diameter is equal to 2 + 5 = 7. Note that this longest path never touches the root of the BST.
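To see where the longest path bends, the post order pass can also remember the node at which the best left + right sum was found. The following is a hypothetical Java sketch (ApexFinder, best and apex are names made up for this illustration), again counting edges. For the fourth test array used later in the post, the path bends at node 75, one level below the root. When several paths tie for the longest, this sketch keeps the first one found in post order, so it may report a different bend node than a diagram that highlights another path of the same length.

```java
// Hypothetical sketch: post order pass that records where the longest path bends.
class ApexFinder {

    // **** minimal BST node (hypothetical) ****
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    int best = 0;       // longest path seen so far, in edges
    Node apex = null;   // node where that path bends

    // **** BST insert: smaller values go left, everything else goes right ****
    static Node insert(Node node, int val) {
        if (node == null) return new Node(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);
        return node;
    }

    // **** returns the height in nodes; records the best bend along the way ****
    int height(Node node) {
        if (node == null) return 0;
        int left  = height(node.left);
        int right = height(node.right);
        if (apex == null || left + right > best) {   // strict '>' keeps the first tie
            best = left + right;
            apex = node;
        }
        return 1 + Math.max(left, right);
    }

    // **** build the BST and run the pass ****
    static ApexFinder run(int[] values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        ApexFinder finder = new ApexFinder();
        finder.height(root);
        return finder;
    }

    public static void main(String[] args) {
        ApexFinder f = run(new int[] {100, 125, 75, 80, 85, 90, 70, 65, 60});
        System.out.println("diameter: " + f.best + " apex: " + f.apex.val); // prints "diameter: 6 apex: 75"
    }
}
```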

Let’s now take a look at the Node class.

```java
package bst.canessa;

/*
* Binary tree node
*/
public class Node {

// **** members ****
int		val;
Node	left;
Node	right;

Node	parent;
int		depth;

// **** constructor ****
public Node(int val) {
this.val 	= val;
this.depth 	= 0;
left = right = null;
}

// ***** constructor ****
public Node(int val, int depth) {
this.val 	= val;
this.depth 	= depth;
left = right = null;
}

// **** get the value of this node ****
public int getVal() {
return this.val;
}

// **** get the depth of this node ****
public int getDepth() {
return this.depth;
}

// **** get the left child for this node ****
public Node getLeftChild() {
return this.left;
}

// **** set the left child on this node ****
public void setLeftChild(Node left) {
this.left = left;
}

// **** get the right child for this node ****
public Node getRightChild() {
return this.right;
}

// **** set the right child for this node ****
public void setRightChild(Node right) {
this.right = right;
}

// **** return the contents of the node as a string ****
public String toString() {
return "val: " + getVal();
}
}
```

Please note that we will not use the parent or the depth in the algorithm used to find the diameter. Until recently, I had been rewriting code from scratch each time I implemented a data structure. Repeating the same task over and over (i.e., creating a Node class to be used in a tree) seems like a waste of time; we can just reuse classes. In this example I was going to write the solution code in a separate solution which I had already created, but I forgot and added it to the one you will see shortly. Sorry about the clutter.

Now let’s look at the BinaryTree class:

```java
package bst.canessa;

import java.util.Deque;

/*
* Binary search tree: smaller values go to the left, all others to the right.
*/
public class BinaryTree {

private Node 	root;
private	int		diameter;

// **** constructor ****
public BinaryTree() {
this.root = null;
}

// **** get the root of the tree ****
public Node getRoot() {
return this.root;
}

// **** insert node ****
public void insert(int val) {
int depth = 0;

if (this.root == null) {
this.root = new Node(val, depth);
}
else {
insert(this.root, val, ++depth);
}
}

// **** insert node ****
private void insert(Node node, int val, int depth) {

// **** insert to the left ****
if (val < node.getVal()) {
if (node.getLeftChild() == null) {
node.setLeftChild(new Node(val, depth));
} else {
insert(node.getLeftChild(), val, ++depth);
}
}

// **** insert to the right ****
else {
if (node.getRightChild() == null) {
node.setRightChild(new Node(val, depth));
} else {
insert(node.getRightChild(), val, ++depth);
}
}
}

// **** in order traversal ****
void inorder(Node node) {

// **** base case (to end recursion) ****
if (node == null)
return;

// **** recursive rule(s) ****
inorder(node.getLeftChild());
System.out.print(node.getVal() + ":" + node.getDepth() + " ");
inorder(node.getRightChild());
}

// **** post order traversal ****
void postorder(Node node) {

// **** base case (to end recursion) ****
if (node == null)
return;

// **** recursive rule(s) ****
postorder(node.getLeftChild());
postorder(node.getRightChild());
System.out.print(node.getVal() + ":" + node.getDepth() + " ");
}

// **** max depth of BST ****
public int maxDepth2(Node node) {

// **** base case ****
if (node == null)
return 0;

// **** recursive rule(s) ****
int maxDepth = node.getDepth();
//		System.out.println("maxDepth2 <<< maxDepth: " + maxDepth);

int depthLeft = maxDepth2(node.getLeftChild());
//		System.out.println("maxDepth2 <<< depthLeft: " + depthLeft);
maxDepth = Math.max(maxDepth, depthLeft);

int depthRight = maxDepth2(node.getRightChild());
//		System.out.println("maxDepth2 <<< depthRight: " + depthRight);
maxDepth = Math.max(maxDepth, depthRight);

// **** return the max depth of the current BST ****
return maxDepth;
}

// **** find distance from root to the node with the specified value ****
public int distRootToVal(Node root, int val) {

// **** base case (to end recursion) ****
if (root == null)
return -1;

// **** initialize the distance ****
int dist = -1;

// **** check if the value is in this node or in one of its subtrees ****
if ((root.getVal() == val) ||
(dist = distRootToVal(root.getLeftChild(), val)) >= 0 ||
(dist = distRootToVal(root.getRightChild(), val)) >= 0)
return dist + 1;

// **** value not found in this subtree ****
return dist;
}

// **** find lca (Lowest Common Ancestor) ****
public Node lca(Node root, int val1, int val2, int rootDist) {

// **** base case (to end recursion) ****
if (root == null)
return null;

// **** ****
if (root.getVal() == val1)
return root;

if (root.getVal() == val2)
return root;

// **** ****
Node leftLCA  = lca(root.getLeftChild(), val1, val2, rootDist + 1);
Node rightLCA = lca(root.getRightChild(), val1, val2, rootDist + 1);

// **** ****
if ((leftLCA  != null) &&
(rightLCA != null))
return root;

// **** ****
return (leftLCA != null) ? leftLCA : rightLCA;
}

// **** find the distance between the specified nodes ****
public int distNodes(Node root, int dist1, int dist2, Node lca) {
int dist = 0;

// **** ****
if ((dist1 != -1) &&
(dist2 != -1)) {
dist = dist1 + dist2 - (2 * distRootToVal(root, lca.getVal()));
}

// **** ****
return dist;
}

// **** find distance between nodes in a BST ****
public int findDist(Node root, int val1, int val2) {
int dist = -1;

// **** step 1: find distance from root to val1 ****
int dist1 = distRootToVal(root, val1);
if (dist1 == -1)
return -1;
System.out.println("findDist <<< dist1: " + dist1);

// **** step 2: find distance from root to val2 ****
int dist2 = distRootToVal(root, val2);
if (dist2 == -1)
return -1;
System.out.println("findDist <<< dist2: " + dist2);

// **** step 3: find distance from root to LCA (Lowest Common Ancestor) ****
Node lca = lca(root, val1, val2, 0);
if (lca == null)
return -1;
System.out.println("findDist <<< lca = " + lca.toString());

// **** step 4: find the distance between nodes ****
dist = distNodes(root, dist1, dist2, lca);
System.out.println("findDist <<< dist: " + dist);

// **** ****
return dist;
}

// **** get the max depth of the tree (recursive method) ****
public int maxDepth(Node root) {

// **** base case ****
if (root == null) {
return -1;
}

// **** recursive rule(s) ****
else {

//			System.out.println("maxDepth <<< val: " + root.getVal());

// **** compute the depth of each subtree ****
int leftDepth  = maxDepth(root.left);
int rightDepth = maxDepth(root.right);

// **** use the largest value ****
//			System.out.println("maxDepth <<< leftDepth: " + leftDepth + " rightDepth: " + rightDepth);
if (leftDepth > rightDepth)
return (leftDepth + 1);
else
return (rightDepth + 1);
}
}

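// **** level order (breadth first) traversal, non-recursive; the method name is kept from earlier code ****
public void depthfirst(Node root) {

// **** nothing to traverse in an empty tree ****
if (root == null)
return;

// **** double ended queue used as a FIFO queue ****
Deque<Node> q = new java.util.ArrayDeque<Node>();

// **** add the tree root ****
q.add(root);

// **** process nodes until the queue is empty ****
while (!q.isEmpty()) {

// **** pop the head node ****
Node n = q.pop();

// **** display the node ****
System.out.print(n.getVal() + ":" + n.getDepth() + " ");

// **** insert left child at the tail ****
if (n.getLeftChild() != null)
q.add(n.getLeftChild());

// **** insert right child at the tail ****
if (n.getRightChild() != null)
q.add(n.getRightChild());
}
}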

// **** display the tree ****
public void display(Node root) {
}

// **** compute and return the diameter of the specified tree ****
public int diameter(Node root) {
this.diameter = 0;
findDiameter(root);
return this.diameter;
}

// **** return the height (in nodes) of the specified subtree and update the diameter member (post order traversal) ****
private int findDiameter(Node node) {

// **** base case ****
if (node == null)
return 0;

// **** recursive rule(s) ****
int left  = findDiameter(node.getLeftChild());
int right = findDiameter(node.getRightChild());

// **** update diameter ****
if ((left + right) > this.diameter)
this.diameter = left + right;

//		System.out.println(	"findDiameter <<< val: " + node.getVal() + " left: " + left + " right: " + right + " diameter: " + this.diameter);

// **** return the height of this subtree: its taller child subtree plus this node ****
if (left > right)
return 1 + left;
else
return 1 + right;
}
}
```

In this class I have added the diameter member. The reason for it is that in Java arguments are passed by value.

For details, see the article “Java is Pass by Value and Not Pass by Reference”.

We need to update this variable from every recursive call. We could have declared an outside variable, but using a member seems somewhat more elegant, and we can protect it as needed. Note that both class members (root and diameter) are declared private.
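The effect is easy to demonstrate. In the hypothetical sketch below (PassByValueDemo and its methods are made-up names), updating an int parameter inside a method is invisible to the caller, while updating the contents of a one-element array is visible, because the method receives a copy of the reference to the same array. A member field, like diameter in BinaryTree, works for the same reason: every recursive call reads and writes the same object.

```java
// Hypothetical demo of Java's pass-by-value semantics; not part of the post's code.
class PassByValueDemo {

    // **** the callee receives a copy of the int, so the caller never sees the update ****
    static void updateInt(int best, int candidate) {
        if (candidate > best) best = candidate;
    }

    // **** the callee receives a copy of the array reference; both copies point at the same array ****
    static void updateHolder(int[] best, int candidate) {
        if (candidate > best[0]) best[0] = candidate;
    }

    // **** returns {value after updateInt, value after updateHolder} ****
    static int[] demo() {
        int best = 0;
        updateInt(best, 7);

        int[] holder = {0};
        updateHolder(holder, 7);

        return new int[] {best, holder[0]};
    }

    public static void main(String[] args) {
        int[] r = demo();
        System.out.println("int: " + r[0] + " holder: " + r[1]); // prints "int: 0 holder: 7"
    }
}
```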

Of interest are the diameter() and findDiameter() methods. I started with different names but changed them to loosely match the C++ code in David’s post.

The diameter() method is public so it can be called by the solution code. It resets the diameter member to zero and calls findDiameter(), a private recursive method that returns the height of each subtree and updates the member along the way. As you will see, I use other convenience functions, some of which are in this class and others in the solution.

Now let’s take a look at the cluttered solution code:

```java
package bst.canessa;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;

/*
*
*/
public class Solution {

// **** limits and seed used to generate the pseudo random test data ****
final static int MAX_LEN	= 31;
final static int MAX_VAL	= 200;		// was: 16;
final static int SEED		= 1;		// was: 12;

/*
* Generate an array holding every integer in the range [from, to]
* in a pseudo random order.
*/
static int[] arrayWithUniqueValues(int from, int to) throws Exception {

System.out.println("arrayWithUniqueValues <<< from: " + from + " to: " + to);

// **** determine the number of integers to generate ****
int n 	= to - from + 1;

System.out.println("arrayWithUniqueValues <<<    n: " + n);

// **** declare and populate an array with the numbers in ascending order ****
int arr[] = new int[n];
for (int i = 0; i < n; i++)
arr[i] = from + i;

// **** declare and populate the random array ****
int[] rarr 			= new int[n];
int x				= n;

//		SecureRandom srand 	= new SecureRandom();
Random rand = new Random(SEED);

for (int i = 0; i < n; i++) {
int k 	= rand.nextInt(x);
rarr[i] = arr[k];
arr[k] 	= arr[x - 1];
x--;
}

// **** check if we have a duplicate random number ****
HashSet<Integer> set = new HashSet<Integer>();
for (int i = 0; i < rarr.length; i++)
if (!set.add(rarr[i]))
throw new Exception("arrayWithUniqueValues <<< duplicate: " + rarr[i]);
if (set.size() != n)
throw new Exception("arrayWithUniqueValues <<< set.size() != n: " + n);

// **** display the random array ****
System.out.print("arrayWithUniqueValues <<< rarr: [");
for (int i = 0; i < rarr.length; i++) {
System.out.print(rarr[i]);
if (i < rarr.length - 1)
System.out.print(", ");
}
System.out.println("]");

// **** return the random array ****
return rarr;
}

/*
* Populate array with integers in the specified range
* using a random order.
*/
static int[] populateArray(int minVal, int maxVal) {

// **** populate the list ****
List<Integer> lst = new ArrayList<Integer>();
for (int i = 0; i < (maxVal - minVal + 1); i++) {
lst.add(minVal + i);
}

// **** display the ordered list ****
System.out.print("populateArray <<<  lst: ");
for (int val : lst)
System.out.print(val + " ");
System.out.println();

// **** shuffle the list ****
Collections.shuffle(lst);

// **** convert the shuffled list to an array ****
int[] arrs = new int[lst.size()];

// **** populate the shuffled array ****
for (int i = 0; i < lst.size(); i++)
arrs[i] = lst.get(i);

// **** display the shuffled array ****
System.out.print("populateArray <<< arrs: ");
for (int i = 0; i < arrs.length; i++)
System.out.print(arrs[i] + " ");
System.out.println();

// **** return the shuffled array ****
return arrs;
}

/*
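* Generate an array of pseudo random length (0 to MAX_LEN - 1) holding
* pseudo random values in the range [0, MAX_VAL); duplicates are possible.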
*/
static int[] randomArray(Random rand) {

int n 		= rand.nextInt(MAX_LEN);

System.out.println("randomArray <<<   n: " + n);

int[] arr 	= new int[n];

// **** populate the array ****
for (int i = 0; i < n; i++)
arr[i] = rand.nextInt(MAX_VAL);

// **** display the array ****
System.out.print("randomArray <<< arr: ");
for (int i = 0; i < arr.length; i++)
System.out.print(arr[i] + " ");
System.out.println();

// **** ****
return arr;
}

/*
* Return an array out of four based on the argument.
*/
static int[] dimArrays(int n) {

int[] arr = null;

// **** populate the array ****
switch (n) {
case 0:
int[] arr0 = {100, 150, 50, 75, 125, 74, 73, 76, 124, 130, 131};
arr = arr0;
break;

case 1:
int[] arr1 = {100, 50, 150, 25, 75, 125, 175, 12, 70, 78, 123, 128, 225, 5, 65, 73, 79, 122, 126, 129, 250, 0, 275};
arr = arr1;
break;

case 2:
int[] arr2= {100, 75, 125, 50, 87, 112, 150, 25, 175};
arr = arr2;
break;

case 3:
int[] arr3= {100, 125, 75, 80, 85, 90, 70, 65, 60};
arr = arr3;
break;

default:
System.err.println("dimArrays <<< UNEXPECTED n: " + n);
arr = new int[0];
break;
}

// **** display the array ****
System.out.print("dimArrays <<<   arr: ");
for (int i = 0; i < arr.length; i++)
System.out.print(arr[i] + " ");
System.out.println();

// **** ****
return arr;
}

/*
* Build a BST and perform some operations on it.
*/
public static void main(String[] args) throws Exception {

// **** instantiate and initialize a pseudo random number generator ****
Random rand			= new Random(SEED);

//		int firstInt		= rand.nextInt();
//		System.out.println("main <<< firstInt: " + firstInt);

//		// **** select a range for the values ****
//		int[] range 		= {rand.nextInt(MAX_VAL) + 1, rand.nextInt(MAX_VAL) + 1};
//		System.out.println("main <<< range: [" + range[0] + ", " + range[1] + "]");

//		// **** generate an array of unique integers in the specified range ****
//		int[] values = arrayWithUniqueValues(Math.min(range[0], range[1]),
//											 Math.max(range[0], range[1]));
//
//		int[] values = populateArray(Math.min(range[0], range[1]),
//									 Math.max(range[0], range[1]));

//		// **** generate a random array in the specified range ****
//		int[] values = randomArray(rand);

// **** loop once per set of numbers ****
int[] values = null;
for (int i = 0; i < 5; i++) {

// **** use a specific array ****
if (i < 4)
values = dimArrays(i);

// **** use a pseudo random array ****
else
values = randomArray(rand);

// **** use the pseudo random array to create a BST ****
BinaryTree tree = new BinaryTree();
for (int val : values) {
tree.insert(val);
}

//			// **** populate the binary tree by hand ****
//			BinaryTree tree = new BinaryTree();
//			tree.root = new Node(1);
//			tree.root.left = new Node(2);
//			tree.root.right = new Node(3);
//			tree.root.left.left = new Node(4);
//			tree.root.left.right = new Node(5);
//
//			// **** display the value of the root node****
//			System.out.println("main <<<   root.val: " + tree.getRoot().getVal());
//			System.out.println();

// **** in order traversal ****
System.out.print("main <<<    inorder: ");
tree.inorder(tree.getRoot());
System.out.println();

// **** depth first traversal ****
System.out.print("main <<< depthfirst: ");
tree.depthfirst(tree.getRoot());
System.out.println();

// **** post order traversal ****
System.out.print("main <<<  postorder: ");
tree.postorder(tree.getRoot());
System.out.println();

// **** get the max depth of the BST ****
int maxDepth = tree.maxDepth(tree.getRoot());
System.out.println("main <<<   maxDepth: " + maxDepth);

// **** get the max depth of the tree ****
maxDepth = tree.maxDepth2(tree.getRoot());
System.out.println("main <<<  maxDepth2: " + maxDepth);

// **** get the diameter of the tree ****
System.out.println("main <<<   diameter: " + tree.diameter(tree.getRoot()));

// **** for the looks ****
System.out.println();
}

// **** display the BST tree ****
//		tree.display(tree.getRoot());

// ???? exit the program ????
System.exit(0);

/*
// **** generate two values to determine the distance ****
int val1 = Math.min(range[0], range[1]) + rand.nextInt(MAX_VAL / 3);
int val2 = Math.max(range[0], range[1]) - rand.nextInt(MAX_VAL / 3);
System.out.println("main <<< val1: " + val1 + " val2: " + val2);

// **** step 1: find distance from root to val1 ****
int dist1 = tree.distRootToVal(tree.getRoot(), val1);
System.out.println("main <<< dist1: " + dist1);

// **** step 2: find distance from root to val2 ****
int dist2 = tree.distRootToVal(tree.getRoot(), val2);
System.out.println("main <<< dist2: " + dist2);

// **** step 3: find distance from root to LCA (Lowest Common Ancestor) ****
Node lca = null;
if ((dist1 >= 0) &&
(dist2 >= 0)) {
lca = tree.lca(tree.getRoot(), val1, val2, 0);
if (lca != null)
System.out.println("main <<< lca = " + lca.toString());
else
System.out.println("main <<< lca == null");
}

// **** step 4: find the distance between nodes ****
if (lca != null) {
int dist = tree.distNodes(tree.getRoot(), dist1, dist2, lca);
System.out.println("main <<< dist: " + dist);
}

// **** putting all the steps together ****
int dist = tree.findDist(tree.getRoot(), val1, val2);
System.out.println("main <<< val1: " + val1 + " val2: " + val2 + " dist: " + dist);
*/
}

}
```

In earlier posts I used the functions whose calls are now commented out at the start of the main() function. Please disregard them; I also used them while developing this code.

As you can see, there is a loop that uses the dimArrays() function for the first four passes and the randomArray() function for the fifth. I developed the code using the second function. Towards the end, I decided to verify my solution with the test cases in David’s code. He did a good job of providing the arrays and results. As you can see, I used his four arrays and the results seem to match.
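Since the post follows TDD, the four arrays and their expected diameters can be captured as a small regression check. Here is a hypothetical, self-contained sketch (DiameterRegression is a made-up name). It uses a local one-element array instead of a member field to hold the running maximum; the expected values 8, 10, 6 and 6 are the diameters reported in the console output of this post for the four dimArrays() inputs.

```java
// Hypothetical regression check for the four dimArrays() inputs.
class DiameterRegression {

    // **** minimal BST node (hypothetical) ****
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // **** BST insert: smaller values go left, everything else goes right ****
    static Node insert(Node node, int val) {
        if (node == null) return new Node(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);
        return node;
    }

    // **** post order pass: returns height in nodes, keeps the best left + right in best[0] ****
    static int height(Node node, int[] best) {
        if (node == null) return 0;
        int left  = height(node.left, best);
        int right = height(node.right, best);
        best[0] = Math.max(best[0], left + right);
        return 1 + Math.max(left, right);
    }

    // **** build the BST and return its diameter in edges ****
    static int diameter(int[] values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        int[] best = {0};
        height(root, best);
        return best[0];
    }

    // **** the four test arrays and the diameters reported for them ****
    static final int[][] INPUTS = {
        {100, 150, 50, 75, 125, 74, 73, 76, 124, 130, 131},
        {100, 50, 150, 25, 75, 125, 175, 12, 70, 78, 123, 128, 225, 5, 65, 73, 79, 122, 126, 129, 250, 0, 275},
        {100, 75, 125, 50, 87, 112, 150, 25, 175},
        {100, 125, 75, 80, 85, 90, 70, 65, 60}
    };
    static final int[] EXPECTED = {8, 10, 6, 6};

    public static void main(String[] args) {
        for (int i = 0; i < INPUTS.length; i++) {
            int got = diameter(INPUTS[i]);
            if (got != EXPECTED[i])
                throw new AssertionError("case " + i + ": expected " + EXPECTED[i] + " got " + got);
            System.out.println("case " + i + ": diameter " + got);
        }
    }
}
```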

Now let’s look at the results in the console:

```
dimArrays <<<   arr: 100 150 50 75 125 74 73 76 124 130 131
main <<<    inorder: 50:1 73:4 74:3 75:2 76:3 100:0 124:3 125:2 130:3 131:4 150:1
main <<< depthfirst: 100:0 50:1 150:1 75:2 125:2 74:3 76:3 124:3 130:3 73:4 131:4
main <<<  postorder: 73:4 74:3 76:3 75:2 50:1 124:3 131:4 130:3 125:2 150:1 100:0
main <<<   maxDepth: 4
main <<<  maxDepth2: 4
main <<<   diameter: 8

dimArrays <<<   arr: 100 50 150 25 75 125 175 12 70 78 123 128 225 5 65 73 79 122 126 129 250 0 275
main <<<    inorder: 0:5 5:4 12:3 25:2 50:1 65:4 70:3 73:4 75:2 78:3 79:4 100:0 122:4 123:3 125:2 126:4 128:3 129:4 150:1 175:2 225:3 250:4 275:5
main <<< depthfirst: 100:0 50:1 150:1 25:2 75:2 125:2 175:2 12:3 70:3 78:3 123:3 128:3 225:3 5:4 65:4 73:4 79:4 122:4 126:4 129:4 250:4 0:5 275:5
main <<<  postorder: 0:5 5:4 12:3 25:2 65:4 73:4 70:3 79:4 78:3 75:2 50:1 122:4 123:3 126:4 129:4 128:3 125:2 275:5 250:4 225:3 175:2 150:1 100:0
main <<<   maxDepth: 5
main <<<  maxDepth2: 5
main <<<   diameter: 10

dimArrays <<<   arr: 100 75 125 50 87 112 150 25 175
main <<<    inorder: 25:3 50:2 75:1 87:2 100:0 112:2 125:1 150:2 175:3
main <<< depthfirst: 100:0 75:1 125:1 50:2 87:2 112:2 150:2 25:3 175:3
main <<<  postorder: 25:3 50:2 87:2 75:1 112:2 175:3 150:2 125:1 100:0
main <<<   maxDepth: 3
main <<<  maxDepth2: 3
main <<<   diameter: 6

dimArrays <<<   arr: 100 125 75 80 85 90 70 65 60
main <<<    inorder: 60:4 65:3 70:2 75:1 80:2 85:3 90:4 100:0 125:1
main <<< depthfirst: 100:0 75:1 125:1 70:2 80:2 65:3 85:3 60:4 90:4
main <<<  postorder: 60:4 65:3 70:2 90:4 85:3 80:2 75:1 125:1 100:0
main <<<   maxDepth: 4
main <<<  maxDepth2: 4
main <<<   diameter: 6

randomArray <<<   n: 13
randomArray <<< arr: 188 47 113 54 104 34 6 178 148 169 73 117 63
main <<<    inorder: 6:3 34:2 47:1 54:3 63:6 73:5 104:4 113:2 117:5 148:4 169:5 178:3 188:0
main <<< depthfirst: 188:0 47:1 34:2 113:2 6:3 54:3 178:3 104:4 148:4 73:5 117:5 169:5 63:6
main <<<  postorder: 6:3 34:2 63:6 73:5 104:4 54:3 117:5 169:5 148:4 178:3 113:2 47:1 188:0
main <<<   maxDepth: 6
main <<<  maxDepth2: 6
main <<<   diameter: 7
```

We first display the array used to construct the BST. We then perform an in-order traversal of each tree, which makes it easy to check that the tree was properly created: the values must come out in ascending order. Each entry shows the value of a node followed by a “:” and the depth of the node. The first entry in each array is used as the root, so it is always at depth zero (0).
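Reading the in-order output by eye works for small trees. As a hypothetical sketch (InorderCheck and its methods are made-up names, not part of the post’s code), the same check can be automated by collecting the in-order values into a list and confirming they never decrease. That holds for any tree built with this insert logic, since equal values go to the right.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: verifies that an in-order walk yields values in non-decreasing order.
class InorderCheck {

    // **** minimal BST node (hypothetical) ****
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // **** BST insert: smaller values go left, everything else goes right ****
    static Node insert(Node node, int val) {
        if (node == null) return new Node(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);
        return node;
    }

    static Node build(int[] values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        return root;
    }

    // **** in order traversal that collects values instead of printing them ****
    static void inorder(Node node, List<Integer> out) {
        if (node == null) return;
        inorder(node.left, out);
        out.add(node.val);
        inorder(node.right, out);
    }

    // **** a valid BST (duplicates to the right) yields a non-decreasing sequence ****
    static boolean isSortedInorder(Node root) {
        List<Integer> vals = new ArrayList<>();
        inorder(root, vals);
        for (int i = 1; i < vals.size(); i++)
            if (vals.get(i - 1) > vals.get(i)) return false;
        return true;
    }

    public static void main(String[] args) {
        Node good = build(new int[] {100, 150, 50, 75, 125, 74, 73, 76, 124, 130, 131});
        Node bad = new Node(10);
        bad.left = new Node(20);   // 20 on the left of 10 breaks the BST ordering
        System.out.println(isSortedInorder(good) + " " + isSortedInorder(bad)); // prints "true false"
    }
}
```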

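After the inorder() call, we call depthfirst(). Despite its name, its output lists the nodes level by level, from left to right, which is a breadth first traversal. That was done to help me verify that the tree diagrams I generated on paper and in Visio were correct.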
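The difference between the two kinds of traversal comes down to how the container is used. In the hypothetical sketch below (TraversalOrders and its methods are made-up names), a java.util.ArrayDeque used as a FIFO queue produces the level by level order shown on the depthfirst line of the console output, while the same kind of deque used as a stack produces a true depth first (pre-order) traversal. For the third test array the two orders are 100 75 125 50 87 112 150 25 175 and 100 75 50 25 87 125 112 150 175, respectively.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical sketch contrasting queue-based and stack-based traversals.
class TraversalOrders {

    // **** minimal BST node (hypothetical) ****
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // **** BST insert: smaller values go left, everything else goes right ****
    static Node insert(Node node, int val) {
        if (node == null) return new Node(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);
        return node;
    }

    static Node build(int[] values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        return root;
    }

    // **** FIFO: take from the head, append children at the tail -> level order ****
    static List<Integer> levelOrder(Node root) {
        List<Integer> out = new ArrayList<>();
        if (root == null) return out;
        Deque<Node> q = new ArrayDeque<>();
        q.addLast(root);
        while (!q.isEmpty()) {
            Node n = q.pollFirst();
            out.add(n.val);
            if (n.left != null) q.addLast(n.left);
            if (n.right != null) q.addLast(n.right);
        }
        return out;
    }

    // **** LIFO: push right before left so the left subtree is visited first -> pre-order ****
    static List<Integer> preorder(Node root) {
        List<Integer> out = new ArrayList<>();
        if (root == null) return out;
        Deque<Node> stack = new ArrayDeque<>();
        stack.push(root);
        while (!stack.isEmpty()) {
            Node n = stack.pop();
            out.add(n.val);
            if (n.right != null) stack.push(n.right);
            if (n.left != null) stack.push(n.left);
        }
        return out;
    }

    public static void main(String[] args) {
        Node root = build(new int[] {100, 75, 125, 50, 87, 112, 150, 25, 175});
        System.out.println("level order: " + levelOrder(root)); // [100, 75, 125, 50, 87, 112, 150, 25, 175]
        System.out.println("pre-order:   " + preorder(root));   // [100, 75, 50, 25, 87, 125, 112, 150, 175]
    }
}
```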

The postorder() method displays the nodes in post order. If you enable the print statement in the findDiameter() method of the BinaryTree class, you will see that findDiameter() visits the nodes in that same order. That helps to understand and visualize how things work.
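To see that order without editing BinaryTree, here is a hypothetical sketch (DiameterTrace is a made-up name) that records one line per node in the same format as the commented-out print in findDiameter(). For the third test array the nodes are visited in the same order as the postorder line of the console output, and the diameter is only settled at the root: the last recorded line is val: 100 left: 3 right: 3 diameter: 6.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: records what findDiameter() sees at each node, in post order.
class DiameterTrace {

    // **** minimal BST node (hypothetical) ****
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // **** BST insert: smaller values go left, everything else goes right ****
    static Node insert(Node node, int val) {
        if (node == null) return new Node(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);
        return node;
    }

    private int diameter = 0;
    private final List<String> trace = new ArrayList<>();

    // **** same logic as findDiameter(): returns the height in nodes, updates the diameter ****
    private int findDiameter(Node node) {
        if (node == null) return 0;
        int left  = findDiameter(node.left);
        int right = findDiameter(node.right);
        if (left + right > diameter) diameter = left + right;
        trace.add("val: " + node.val + " left: " + left + " right: " + right + " diameter: " + diameter);
        return 1 + Math.max(left, right);
    }

    // **** build the BST, run the pass and return the recorded lines ****
    static List<String> traceOf(int[] values) {
        Node root = null;
        for (int v : values) root = insert(root, v);
        DiameterTrace t = new DiameterTrace();
        t.findDiameter(root);
        return t.trace;
    }

    public static void main(String[] args) {
        for (String line : traceOf(new int[] {100, 75, 125, 50, 87, 112, 150, 25, 175}))
            System.out.println(line);
    }
}
```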

I displayed the depth of the tree twice, using maxDepth() and maxDepth2(), to check that both methods returned the same value. I could have commented those calls out because the diameter computation does not use them.

Finally the diameter for each tree is displayed.

I hope you enjoyed this post. Most important, try developing the code in the language of your choice, then go back to the posts and verify that all is well. Of course, if you have a comment or question regarding this or any other post, please do not hesitate to leave me a message below.

Keep on reading and practicing. It seems like the only way to learn and move ahead.

Regards;

John