13.5 Binary Search Tree
A Binary Search Tree (BST) is a special type of binary tree: for every parent node, all nodes in its left subtree have values less than or equal to the parent's value, and all nodes in its right subtree have values greater than or equal to the parent's value. Therefore, we can determine whether a value exists in a BST in O(log n) time when the tree is balanced (more generally, in time proportional to the tree's height): starting from the root, move to the left child if the current node's value is greater than the target value, or to the right child if it is smaller. Additionally, since a BST is ordered, an in-order traversal of the tree visits the values in sorted order.
For example, the tree below is a BST, and its in-order traversal yields [1, 2, 3, 4, 5, 6].
    4
   / \
  2   5
 / \   \
1   3   6
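To make the search rule and the sorted in-order property concrete, here is a minimal Python sketch. The `TreeNode` class and the `contains` and `inorder` helpers are illustrative names assumed for this example rather than part of any library, and the tree built is the one drawn above.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node, assumed here as a stand-in for the usual LeetCode class."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def contains(root: Optional[TreeNode], target: int) -> bool:
    """Walk down from the root, discarding one subtree at every step."""
    node = root
    while node is not None:
        if node.val == target:
            return True
        # Current value too large: the target can only be on the left.
        node = node.left if node.val > target else node.right
    return False


def inorder(root: Optional[TreeNode]) -> List[int]:
    """Left subtree, node, right subtree: sorted order for a BST."""
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


# The example tree: 4 at the root, 2 and 5 as children, then 1, 3 and 6.
root = TreeNode(4,
                TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(5, None, TreeNode(6)))

print(contains(root, 3))  # True
print(contains(root, 7))  # False
print(inorder(root))      # [1, 2, 3, 4, 5, 6]
```

Each step of `contains` descends one level, so the number of comparisons is bounded by the height of the tree.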
99. Recover Binary Search Tree
Problem Description
Given a binary search tree in which the values of two nodes have been swapped by mistake, restore the tree to a valid BST without changing its structure.
Input and Output Example
The input is a binary search tree where two nodes have been swapped, and the output is the corrected tree.
Input:
  3
 / \
1   4
   /
  2
Output:
  2
 / \
1   4
   /
  3
In this example, 2 and 3 were swapped by mistake.
Solution Explanation
We can perform an in-order traversal of the BST and use a prev pointer to keep track of the previous node during traversal. If the current node's value is less than the prev node's value, it indicates an order mismatch that needs correction.
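- If there is only one mismatch in the entire traversal, the two swapped nodes are adjacent in the in-order sequence: they are the prev node and the current node of that mismatch.
- If there are two mismatches, the swapped nodes are not adjacent: they are the prev node of the first mismatch and the current node of the second mismatch.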
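Both cases can be checked on a plain list standing in for the in-order sequence. The sketch below is only an illustration: `find_swapped` is a hypothetical helper name, and it applies the same prev/current comparison to list elements instead of tree nodes, returning the two values that should be exchanged.

```python
from typing import List, Optional, Tuple


def find_swapped(seq: List[int]) -> Optional[Tuple[int, int]]:
    """Return the two out-of-place values in an almost-sorted sequence."""
    first: Optional[int] = None
    second: Optional[int] = None
    for prev, cur in zip(seq, seq[1:]):
        if cur < prev:
            if first is None:
                first = prev  # larger element of the first mismatch
            second = cur      # smaller element of the latest mismatch
    if first is None or second is None:
        return None           # no mismatch: already sorted
    return first, second


# One mismatch (3 > 2): the swapped values are neighbours in the sequence.
print(find_swapped([1, 3, 2, 4]))        # (3, 2)
# Two mismatches (5 > 3 and 4 > 2): the swapped values are 5 and 2.
print(find_swapped([1, 5, 3, 4, 2, 6]))  # (5, 2)
```

Because `second` is overwritten on every mismatch while `first` is set only once, the same loop covers both cases without branching on the number of mismatches.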
Implementations in C++ and Python follow.
#include <utility>  // std::swap

// Auxiliary function: in-order traversal that records the two misplaced nodes.
void inorder(TreeNode* root, TreeNode*& mistake1, TreeNode*& mistake2,
             TreeNode*& prev) {
    if (root == nullptr) {
        return;
    }
    inorder(root->left, mistake1, mistake2, prev);
    // Order violation: the previous node is larger than the current one.
    if (prev != nullptr && root->val < prev->val) {
        if (mistake1 == nullptr) {
            mistake1 = prev;  // first misplaced node, set only once
        }
        mistake2 = root;  // overwritten if a second violation appears
    }
    prev = root;
    inorder(root->right, mistake1, mistake2, prev);
}

// Main function
void recoverTree(TreeNode* root) {
    TreeNode *mistake1 = nullptr, *mistake2 = nullptr, *prev = nullptr;
    inorder(root, mistake1, mistake2, prev);
    if (mistake1 != nullptr && mistake2 != nullptr) {
        std::swap(mistake1->val, mistake2->val);
    }
}
from typing import List, Optional

# Auxiliary function
# A Python function cannot rebind a variable that belongs to its caller, so each
# "pointer" is wrapped in a list of length 1 to simulate passing by reference.
# TreeNode is the node class provided by LeetCode.
def inorder(
    root: Optional[TreeNode],
    mistake1: List[Optional[TreeNode]],
    mistake2: List[Optional[TreeNode]],
    prev: List[Optional[TreeNode]],
) -> None:
    if root is None:
        return
    inorder(root.left, mistake1, mistake2, prev)
    # Order violation: the previous node is larger than the current one.
    if prev[0] is not None and root.val < prev[0].val:
        if mistake1[0] is None:
            mistake1[0] = prev[0]  # first misplaced node, set only once
        mistake2[0] = root  # overwritten if a second violation appears
    prev[0] = root
    inorder(root.right, mistake1, mistake2, prev)

# Main function
def recoverTree(root: Optional[TreeNode]) -> None:
    mistake1, mistake2, prev = [None], [None], [None]
    inorder(root, mistake1, mistake2, prev)
    if mistake1[0] is not None and mistake2[0] is not None:
        mistake1[0].val, mistake2[0].val = mistake2[0].val, mistake1[0].val
669. Trim a Binary Search Tree
Problem Description
Given a binary search tree and two integers L and R, where L < R, trim the tree so that all nodes in the tree have values within the range [L, R].
Input and Output Example
Input is a binary search tree and two integers L and R. Output is the trimmed binary search tree.
Input: L = 1, R = 3, tree =
    3
   / \
  0   4
   \
    2
   /
  1
Output:
    3
   /
  2
 /
1
Solution Explanation
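By leveraging the properties of a binary search tree, we can solve this problem efficiently with recursion. If a node's value is greater than R, the node and its entire right subtree are out of range, so the trimmed left subtree takes its place. Symmetrically, if the value is less than L, only the right subtree can contain valid nodes, so the trimmed right subtree takes its place. Otherwise the node is kept and both of its subtrees are trimmed recursively. In the code, L and R are named low and high.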
Implementations in C++ and Python follow.
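TreeNode* trimBST(TreeNode* root, int low, int high) {
    if (root == nullptr) {
        return root;
    }
    // Too large: this node and its right subtree are all out of range.
    if (root->val > high) {
        return trimBST(root->left, low, high);
    }
    // Too small: this node and its left subtree are all out of range.
    if (root->val < low) {
        return trimBST(root->right, low, high);
    }
    // In range: keep the node and trim both subtrees.
    root->left = trimBST(root->left, low, high);
    root->right = trimBST(root->right, low, high);
    return root;
}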
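from typing import Optional

def trimBST(
    root: Optional[TreeNode], low: int, high: int
) -> Optional[TreeNode]:
    if root is None:
        return None
    # Too large: this node and its right subtree are all out of range.
    if root.val > high:
        return trimBST(root.left, low, high)
    # Too small: this node and its left subtree are all out of range.
    if root.val < low:
        return trimBST(root.right, low, high)
    # In range: keep the node and trim both subtrees.
    root.left = trimBST(root.left, low, high)
    root.right = trimBST(root.right, low, high)
    return root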
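As a quick sanity check, the example above can be run end to end. This sketch assumes a minimal `TreeNode` class in place of the one LeetCode provides and restates `trimBST` in compact form so that the snippet runs on its own; `inorder` is a hypothetical helper used only to inspect the result.

```python
from typing import List, Optional


class TreeNode:
    """Minimal node class, assumed here in place of the LeetCode definition."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def trimBST(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    if root is None:
        return None
    if root.val > high:   # node and right subtree are out of range
        return trimBST(root.left, low, high)
    if root.val < low:    # node and left subtree are out of range
        return trimBST(root.right, low, high)
    root.left = trimBST(root.left, low, high)
    root.right = trimBST(root.right, low, high)
    return root


def inorder(root: Optional[TreeNode]) -> List[int]:
    """Collect values in left-node-right order."""
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


# Example input: 3 has children 0 and 4; 0 has right child 2; 2 has left child 1.
root = TreeNode(3,
                TreeNode(0, None, TreeNode(2, TreeNode(1))),
                TreeNode(4))

trimmed = trimBST(root, 1, 3)
print(inorder(trimmed))  # [1, 2, 3]
print(trimmed.val, trimmed.left.val, trimmed.left.left.val)  # 3 2 1
print(trimmed.right)     # None
```

Node 0 is dropped but its in-range descendants 2 and 1 are reattached under 3, which matches the expected output tree.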