Given the root of a complete binary tree, return the number of nodes in the tree.
According to Wikipedia, in a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between 1 and 2^h nodes inclusive at the last level h.
Design an algorithm that runs in less than O(n) time complexity.
Example 1:

Input: root = [1,2,3,4,5,6]
Output: 6
Example 2:
Input: root = []
Output: 0
Example 3:
Input: root = [1]
Output: 1
Constraints:
The number of nodes in the tree is in the range [0, 5 * 10^4].
0 <= Node.val <= 5 * 10^4
The tree is guaranteed to be complete.
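The array notation in the examples lists nodes in level order. Because the tree is complete there are no gaps, so the node at index i has its children at indices 2i + 1 and 2i + 2. The sketch below builds such a tree for local testing. The small `TreeNode` class stands in for the one LeetCode supplies, and `LevelOrder.build` is a hypothetical helper that assumes the array contains no null placeholders (true for complete trees, where trailing nulls are trimmed).

```java
// Minimal stand-in for LeetCode's TreeNode (assumption: LeetCode provides its own).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class LevelOrder {
    // Hypothetical helper: builds a complete binary tree from a level-order array.
    // In a complete tree, the node at index i has children at 2i + 1 and 2i + 2.
    static TreeNode build(int[] values) {
        return build(values, 0);
    }

    private static TreeNode build(int[] values, int i) {
        if (i >= values.length) return null; // index past the last node: no child here
        TreeNode node = new TreeNode(values[i]);
        node.left = build(values, 2 * i + 1);
        node.right = build(values, 2 * i + 2);
        return node;
    }
}
```

For Example 1, `LevelOrder.build(new int[]{1, 2, 3, 4, 5, 6})` gives root 1 with children 2 and 3, where 2 has children 4 and 5, and 3 has only a left child 6.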
Brute Force Approach: DFS (O(n))
Just do any traversal (for example, a recursive DFS) and count every node you visit.
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    public int countNodes(TreeNode root) {
        if (root == null) return 0;
        // 1 (root node) + nodes in left subtree + nodes in right subtree
        return 1 + countNodes(root.left) + countNodes(root.right);
    }
}
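Since any traversal works, here is a sketch of the same O(n) count using an iterative breadth-first search instead of recursion. `BfsCounter` is a hypothetical name, and the small `TreeNode` class stands in for the one LeetCode provides.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Minimal stand-in for LeetCode's TreeNode (assumption: LeetCode provides its own).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class BfsCounter {
    // Counts nodes level by level with an explicit queue.
    // Still visits every node, so it is O(n) time, like the recursive DFS.
    static int countNodes(TreeNode root) {
        if (root == null) return 0;
        Deque<TreeNode> queue = new ArrayDeque<>(); // ArrayDeque rejects nulls, so only enqueue real children
        queue.add(root);
        int count = 0;
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            count++;
            if (node.left != null) queue.add(node.left);
            if (node.right != null) queue.add(node.right);
        }
        return count;
    }
}
```

The explicit queue avoids recursion, but it holds up to one full level at a time (about n/2 nodes for the last level), whereas the recursive DFS only needs O(h) stack.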
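❌ Why not optimal?
It visits every node exactly once, so it takes O(n) time (plus O(h) recursion stack). That misses the requirement of running in less than O(n), and it ignores the shape guarantee: in a complete tree, whole subtrees are perfect, and their size can be computed without visiting their nodes.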
Optimal Approach: Use Binary Tree Height Trick (O(log²n))
💡 Key Insight:
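In a complete binary tree:
- If the leftmost height (following only left children) equals the rightmost height (following only right children), the tree is a perfect binary tree → total nodes = 2^h - 1
- Otherwise, count recursively: 1 for the root + nodes in the left subtree + nodes in the right subtree.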
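The 2^h - 1 formula follows from summing the levels of a perfect tree. Here h counts the nodes on a root-to-leaf path (the way `getLeftHeight` counts, so a single node has h = 1), and level i, counting the root as level 0, holds 2^i nodes:

```latex
\sum_{i=0}^{h-1} 2^{i} = 2^{h} - 1,
\qquad \text{e.g. } h = 3:\quad 1 + 2 + 4 = 7 = 2^{3} - 1
```

In Java this is `(1 << h) - 1`, which is safe here because h stays far below 31 for at most 5 * 10^4 nodes.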
🔁 Steps:
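- Compute the leftmost height: start at the current node and follow left children until null.
- Compute the rightmost height: start at the current node and follow right children until null.
- If the heights are equal, the subtree is perfect → use the formula 2^h - 1.
- If not, recursively count nodes in the left and right subtrees and add 1 for the root.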
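Why O(log² n)? At a node whose subtree has height h, the two height walks cost O(h). When the heights differ, at least one child subtree is still perfect: the left one if the last level reaches into the right half, otherwise the right one. That perfect child returns right after its own O(h) walks, so only one call keeps recursing, on a subtree of height at most h - 1. With a constant c for the work done per level:

```latex
T(h) \le T(h-1) + c\,h
\;\Longrightarrow\;
T(h) \le c \sum_{k=1}^{h} k = c\,\frac{h(h+1)}{2} = O(h^{2})
```

A complete tree with n ≥ 1 nodes has h = ⌊log₂ n⌋ + 1, so the total is O(log² n).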
class Solution {
    public int countNodes(TreeNode root) {
        // Base case: If tree is empty, return 0
        if (root == null) return 0;

        // Compute leftmost height (go only left)
        int leftHeight = getLeftHeight(root);
        // Compute rightmost height (go only right)
        int rightHeight = getRightHeight(root);

        // If both heights are equal, it is a perfect binary tree
        if (leftHeight == rightHeight) {
            // Total nodes in a perfect binary tree = 2^h - 1
            return (1 << leftHeight) - 1;
        }

        // Otherwise, count recursively: 1 for root + left subtree + right subtree
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Helper function to calculate height by going left
    private int getLeftHeight(TreeNode node) {
        int height = 0;
        while (node != null) {
            height++;          // Increment for each level
            node = node.left;  // Move left
        }
        return height;
    }

    // Helper function to calculate height by going right
    private int getRightHeight(TreeNode node) {
        int height = 0;
        while (node != null) {
            height++;          // Increment for each level
            node = node.right; // Move right
        }
        return height;
    }
}
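To sanity-check the height trick, this harness runs the three examples and compares it with the brute-force count on every complete tree up to 1,000 nodes. `CountCheck` and `completeTree` are hypothetical names, `TreeNode` stands in for LeetCode's class, and the two height walks are inlined as loops instead of the separate `getLeftHeight`/`getRightHeight` helpers; the logic is otherwise the same.

```java
// Minimal stand-in for LeetCode's TreeNode (assumption: LeetCode provides its own).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class CountCheck {
    // O(n) brute force from the first approach.
    static int bruteForce(TreeNode root) {
        if (root == null) return 0;
        return 1 + bruteForce(root.left) + bruteForce(root.right);
    }

    // O(log^2 n) height trick from the optimal approach.
    static int countNodes(TreeNode root) {
        if (root == null) return 0;
        int leftHeight = 0, rightHeight = 0;
        for (TreeNode n = root; n != null; n = n.left) leftHeight++;   // leftmost path
        for (TreeNode n = root; n != null; n = n.right) rightHeight++; // rightmost path
        if (leftHeight == rightHeight) return (1 << leftHeight) - 1;   // perfect tree
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Hypothetical helper: complete tree holding values 1..n in level order
    // (the node at index i has children at 2i + 1 and 2i + 2).
    static TreeNode completeTree(int n) {
        return build(n, 0);
    }

    private static TreeNode build(int n, int i) {
        if (i >= n) return null;
        TreeNode node = new TreeNode(i + 1);
        node.left = build(n, 2 * i + 1);
        node.right = build(n, 2 * i + 2);
        return node;
    }

    public static void main(String[] args) {
        // Examples 1-3 from the problem statement.
        System.out.println(countNodes(completeTree(6))); // 6
        System.out.println(countNodes(completeTree(0))); // 0
        System.out.println(countNodes(completeTree(1))); // 1

        // Cross-check both approaches on every complete tree with 0..1000 nodes.
        for (int n = 0; n <= 1000; n++) {
            TreeNode root = completeTree(n);
            if (countNodes(root) != bruteForce(root)) {
                throw new AssertionError("mismatch at n = " + n);
            }
        }
        System.out.println("all sizes agree");
    }
}
```

Testing every size, rather than only the examples, exercises both branches: sizes like 7 or 15 hit the perfect-tree shortcut at the root, while sizes like 6 force the recursive split.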