Description

Given the root of a binary tree, flatten the tree into a “linked list”:

  • The “linked list” should use the same TreeNode class where the right child pointer points to the next node in the list and the left child pointer is always null.
  • The “linked list” should be in the same order as a pre-order traversal of the binary tree.
  • Follow up: Can you flatten the tree in-place (with O(1) extra space)?

Idea

Simple solution with O(n) memory

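Run a recursive pre-order traversal that appends every node to a list, then iterate over the list and relink the nodes: set each node's left pointer to null and its right pointer to the next node in the list.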

Morris traversal for O(1) memory

In a pre-order traversal of a binary tree, each vertex is processed in (node, left, right) order. This means that the entire left subtree could be placed between the node and its right subtree.
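As a small illustration of the target order, the sketch below collects the pre-order sequence of a sample tree. The `TreeNode` class mirrors the usual LeetCode definition; the `preorder` helper and the sample tree are made up for this example and are not part of the solution itself.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def preorder(node: Optional[TreeNode]) -> List[int]:
    """Return the values in (node, left, right) order."""
    if node is None:
        return []
    return [node.val] + preorder(node.left) + preorder(node.right)


# Sample tree (hypothetical):
#       1
#      / \
#     2   5
#    / \   \
#   3   4   6
root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
order = preorder(root)
print(order)  # [1, 2, 3, 4, 5, 6]
```

The flattened tree should therefore be the chain 1 -> 2 -> 3 -> 4 -> 5 -> 6 along the right pointers: the whole left subtree (2, 3, 4) sits between node 1 and its right subtree (5, 6).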

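To do this, however, we’ll first have to locate the node in the left subtree to which the right subtree should be attached. This is easy enough: move right as many times as possible from the root of the left subtree. The node we reach is the last pre-order node of the left subtree only if it has no left child, but that is fine: any left subtree it still has will be spliced in right after it once curr reaches it, so the original right subtree still ends up last.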

So we should be able to move through the binary tree, keeping track of the current node (curr). Whenever we find a left subtree, we can dispatch a runner to find its rightmost node, then stitch both ends of the left subtree into the right path of curr, taking care to sever the left connection at curr.
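To make a single stitching step concrete, here is a minimal sketch that performs it once at the root of a sample tree. The helpers `stitch_once` and `right_chain` are hypothetical names introduced only for this illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def stitch_once(curr):
    """Splice curr's left subtree into its right path (one step only)."""
    if curr.left is None:
        return
    # The runner walks to the rightmost node of the left subtree.
    runner = curr.left
    while runner.right:
        runner = runner.right
    runner.right = curr.right  # attach the old right subtree at the far end
    curr.right = curr.left     # move the left subtree onto the right path
    curr.left = None           # sever the left connection


def right_chain(node):
    """Collect values reachable by following right pointers only."""
    vals = []
    while node:
        vals.append(node.val)
        node = node.right
    return vals


#       1
#      / \
#     2   5
#    / \   \
#   3   4   6
root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
stitch_once(root)
chain = right_chain(root)
print(chain)  # [1, 2, 4, 5, 6]
```

After this one step, node 3 still hangs off node 2 as a left child. It is spliced in between 2 and 4 when curr advances to node 2 and the same step is repeated there.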

Code

Simple solution with O(n) memory

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def flatten(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        if not root:
            return

        # Collect the nodes in pre-order (node, left, right).
        elems = []
        def traverse(node):
            elems.append(node)
            if node.left:
                traverse(node.left)
            if node.right:
                traverse(node.right)
        traverse(root)

        # Relink: each node points right to its pre-order successor.
        # The last node is a leaf, so both of its pointers are already None.
        for i in range(len(elems) - 1):
            elems[i].left = None
            elems[i].right = elems[i + 1]

Morris traversal

from typing import Optional

class Solution:
    def flatten(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        curr = root
        while curr:
            if curr.left:  # Check if the current node has a left child.
                # Find the rightmost node in the left subtree.
                runner = curr.left
                while runner.right:
                    runner = runner.right

                # Relink the rightmost node of the left subtree to the right subtree.
                runner.right = curr.right

                # Move the left subtree to the right and set the left to None.
                curr.right = curr.left
                curr.left = None

            # Move to the next node in the "linked list" formed by the right pointers.
            curr = curr.right
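
A quick way to sanity-check the in-place version is to flatten a small tree and walk the right pointers. The sketch below restates the loop as a free function so that the block runs on its own; `flatten_in_place` and `as_list` are names chosen for this example. The tree is picked so that the rightmost node of the left subtree (node 2) itself has a left child.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_in_place(root):
    """Same loop as the in-place solution, written as a standalone function."""
    curr = root
    while curr:
        if curr.left:
            runner = curr.left
            while runner.right:
                runner = runner.right
            runner.right = curr.right
            curr.right = curr.left
            curr.left = None
        curr = curr.right


def as_list(root):
    """Walk the right pointers, checking that every left pointer is None."""
    vals = []
    node = root
    while node:
        assert node.left is None
        vals.append(node.val)
        node = node.right
    return vals


#     1
#    / \
#   2   4
#  /
# 3
root = TreeNode(1, TreeNode(2, TreeNode(3)), TreeNode(4))
flatten_in_place(root)
result = as_list(root)
print(result)  # [1, 2, 3, 4]
```

Node 4 is first attached directly after node 2, and node 3 is spliced in between them when curr reaches node 2, so the final chain matches the pre-order sequence.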