TreeNode extracts the algorithms that operate on tree structures into a single trait, which avoids a large amount of duplicated code.

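Every tree operation comes down to two things: the order in which nodes are visited, and what is done with each node. The API below is organized accordingly:

  1. Each traversal order is exposed as its own API.
  2. Separate APIs are provided for consuming (inspecting) nodes and for transforming them.
  3. Whether traversal continues, skips part of the tree, or stops is decided by the value that processing each node returns. This separates business logic from the logic of walking the tree.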

API

pub trait TreeNode: Sized {
    fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
        &'n self,
        visitor: &mut V,
    ) -> Result<TreeNodeRecursion> {
        visitor
            .f_down(self)?
            .visit_children(|| self.apply_children(|c| c.visit(visitor)))?
            .visit_parent(|| visitor.f_up(self))
    }
 
    fn rewrite<R: TreeNodeRewriter<Node = Self>>(
        self,
        rewriter: &mut R,
    ) -> Result<Transformed<Self>> {
        handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| {
            rewriter.f_up(n)
        })
    }
    
    fn apply<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
        &'n self,
        mut f: F,
    ) -> Result<TreeNodeRecursion> {
        fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
            node: &'n N,
            f: &mut F,
        ) -> Result<TreeNodeRecursion> {
            f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
        }
 
        apply_impl(self, &mut f)
    }
 
    fn transform<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        f: F,
    ) -> Result<Transformed<Self>> {
        self.transform_up(f)
    }
 
    fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        fn transform_down_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
            node: N,
            f: &mut F,
        ) -> Result<Transformed<N>> {
            f(node)?.transform_children(|n| n.map_children(|c| transform_down_impl(c, f)))
        }
 
        transform_down_impl(self, &mut f)
    }
 
    fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        fn transform_up_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
            node: N,
            f: &mut F,
        ) -> Result<Transformed<N>> {
            node.map_children(|c| transform_up_impl(c, f))?
                .transform_parent(f)
        }
 
        transform_up_impl(self, &mut f)
    }
 
    fn transform_down_up<
        FD: FnMut(Self) -> Result<Transformed<Self>>,
        FU: FnMut(Self) -> Result<Transformed<Self>>,
    >(
        self,
        mut f_down: FD,
        mut f_up: FU,
    ) -> Result<Transformed<Self>> {
        fn transform_down_up_impl<
            N: TreeNode,
            FD: FnMut(N) -> Result<Transformed<N>>,
            FU: FnMut(N) -> Result<Transformed<N>>,
        >(
            node: N,
            f_down: &mut FD,
            f_up: &mut FU,
        ) -> Result<Transformed<N>> {
            handle_transform_recursion!(
                f_down(node),
                |c| transform_down_up_impl(c, f_down, f_up),
                f_up
            )
        }
 
        transform_down_up_impl(self, &mut f_down, &mut f_up)
    }
    
    fn exists<F: FnMut(&Self) -> Result<bool>>(&self, mut f: F) -> Result<bool> {
        let mut found = false;
        self.apply(|n| {
            Ok(if f(n)? {
                found = true;
                TreeNodeRecursion::Stop
            } else {
                TreeNodeRecursion::Continue
            })
        })
        .map(|_| found)
    }
 
    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
        &'n self,
        f: F,
    ) -> Result<TreeNodeRecursion>;
 
    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        f: F,
    ) -> Result<Transformed<Self>>;
}

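The TreeNode API falls into three groups:

  • Inspecting: traverse the tree without modifying it. apply, visit, exists
  • Transforming: traverse the tree and transform its nodes. transform, transform_up, transform_down, transform_down_up, rewrite
  • Internal API: apply_children and map_children. These are the two required methods (they have no default body) that every implementor provides, and all the other methods are built on top of them.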

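Traversal has a direction. "down" means pre-order: a node is processed before its children. "up" means post-order: a node is processed after its children.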
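The two orders are shown below in a minimal sketch. It uses plain recursion on a hypothetical toy tree rather than TreeNode itself. `Node`, `sample_tree`, `pre_order` and `post_order` are illustrative names, not part of the API.

```rust
/// Hypothetical toy tree, used only to illustrate traversal order.
struct Node {
    name: &'static str,
    children: Vec<Node>,
}

fn leaf(name: &'static str) -> Node {
    Node { name, children: Vec::new() }
}

/// Projection -> Join -> (ScanA, ScanB)
fn sample_tree() -> Node {
    Node {
        name: "Projection",
        children: vec![Node {
            name: "Join",
            children: vec![leaf("ScanA"), leaf("ScanB")],
        }],
    }
}

/// "down": record a node before its children (pre-order).
fn pre_order(node: &Node, out: &mut Vec<&'static str>) {
    out.push(node.name);
    for child in &node.children {
        pre_order(child, out);
    }
}

/// "up": record a node after all of its children (post-order).
fn post_order(node: &Node, out: &mut Vec<&'static str>) {
    for child in &node.children {
        post_order(child, out);
    }
    out.push(node.name);
}
```

For `sample_tree()`, pre_order yields Projection, Join, ScanA, ScanB. post_order yields ScanA, ScanB, Join, Projection. A parent is seen before its children going down and after them going up.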

Visitor pattern

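The TreeNode API is built around the visitor pattern, so it is worth understanding that first. Both visit and rewrite use it. The per-node logic is handed to a visitor (or rewriter) object, while TreeNode itself only drives the traversal.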

pub trait TreeNodeVisitor<'n>: Sized {
    /// The node type which is visitable.
    type Node: TreeNode;
 
    /// Invoked while traversing down the tree, before any children are visited.
    /// Default implementation continues the recursion.
    fn f_down(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
        Ok(TreeNodeRecursion::Continue)
    }
 
    /// Invoked while traversing up the tree after children are visited. Default
    /// implementation continues the recursion.
    fn f_up(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
        Ok(TreeNodeRecursion::Continue)
    }
}
 
pub trait TreeNodeRewriter: Sized {
    /// The node type which is rewritable.
    type Node: TreeNode;
 
    /// Invoked while traversing down the tree before any children are rewritten.
    /// Default implementation returns the node as is and continues recursion.
    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        Ok(Transformed::no(node))
    }
 
    /// Invoked while traversing up the tree after all children have been rewritten.
    /// Default implementation returns the node as is and continues recursion.
    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        Ok(Transformed::no(node))
    }
}

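The two traits are almost identical. TreeNodeVisitor's methods return a TreeNodeRecursion, while TreeNodeRewriter's return a Transformed. Transformed exists only because a rewrite produces a new node, and it still carries a TreeNodeRecursion inside: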

pub struct Transformed<T> {
    pub data: T,
    pub transformed: bool,
    pub tnr: TreeNodeRecursion,
}
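The next sketch is a simplified re-creation of Transformed, showing how its three fields travel together. `no` and `update_data` appear in the code in this article. `new`, `yes`, the exact signatures, and the `clamp_negative` rule are assumptions made for illustration, not the library's definitive API.

```rust
/// Simplified stand-in for TreeNodeRecursion (same three variants).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
enum TreeNodeRecursion {
    Continue,
    Jump,
    Stop,
}

/// Simplified stand-in for Transformed<T>: the (possibly new) data,
/// whether it changed, and how the traversal should proceed.
#[derive(Debug, PartialEq)]
struct Transformed<T> {
    data: T,
    transformed: bool,
    tnr: TreeNodeRecursion,
}

impl<T> Transformed<T> {
    fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self {
        Self { data, transformed, tnr }
    }

    /// The data was changed; keep traversing.
    fn yes(data: T) -> Self {
        Self::new(data, true, TreeNodeRecursion::Continue)
    }

    /// The data is unchanged; keep traversing.
    fn no(data: T) -> Self {
        Self::new(data, false, TreeNodeRecursion::Continue)
    }

    /// Replace the payload (e.g. wrap a rewritten child in a rebuilt parent)
    /// while keeping the `transformed` flag and the recursion decision.
    fn update_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
        Transformed::new(f(self.data), self.transformed, self.tnr)
    }
}

/// Hypothetical rewrite rule: clamp negative numbers to zero.
fn clamp_negative(v: i64) -> Transformed<i64> {
    if v < 0 {
        Transformed::yes(0)
    } else {
        Transformed::no(v)
    }
}
```

The LogicalPlan `map_children` excerpt relies on the same `update_data` step. A child's Transformed becomes the parent's Transformed, so the `transformed` flag and the recursion decision survive when the parent is rebuilt.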

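visit is implemented as follows. For each node it first calls the visitor's f_down (pre-order), then recurses into the children, and finally calls f_up on the way back up (post-order). The TreeNodeRecursion returned at each step decides whether the next step runs.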

fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
	&'n self,
	visitor: &mut V,
) -> Result<TreeNodeRecursion> {
	visitor
		.f_down(self)?
		.visit_children(|| self.apply_children(|c| c.visit(visitor)))?
		.visit_parent(|| visitor.f_up(self))
}
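The following self-contained sketch mirrors the structure of visit on a hypothetical toy tree. It shows how the TreeNodeRecursion returned by f_down and f_up drives the chain. `visit_children` and `visit_parent` appear in the code above, but their bodies here are simplified assumptions, as are `Node`, `Visitor`, `Recorder` and `record`. The `Recorder` visitor returns Jump for one chosen node. That skips the node's children while still calling its f_up.

```rust
type Result<T> = std::result::Result<T, String>;

#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
enum TreeNodeRecursion {
    Continue,
    Jump,
    Stop,
}

impl TreeNodeRecursion {
    /// After f_down: Continue visits the children, Jump skips them
    /// (f_up on this node still runs), Stop aborts everything.
    fn visit_children<F: FnOnce() -> Result<Self>>(self, f: F) -> Result<Self> {
        match self {
            Self::Continue => f(),
            Self::Jump => Ok(Self::Continue),
            Self::Stop => Ok(self),
        }
    }

    /// After the children: only Continue lets f_up run on this node.
    fn visit_parent<F: FnOnce() -> Result<Self>>(self, f: F) -> Result<Self> {
        match self {
            Self::Continue => f(),
            Self::Jump | Self::Stop => Ok(self),
        }
    }
}

/// Hypothetical toy tree.
struct Node {
    name: &'static str,
    children: Vec<Node>,
}

trait Visitor {
    fn f_down(&mut self, node: &Node) -> Result<TreeNodeRecursion>;
    fn f_up(&mut self, node: &Node) -> Result<TreeNodeRecursion>;
}

/// Visit children in order; only Stop ends the loop early.
fn apply_children<F: FnMut(&Node) -> Result<TreeNodeRecursion>>(
    node: &Node,
    mut f: F,
) -> Result<TreeNodeRecursion> {
    let mut tnr = TreeNodeRecursion::Continue;
    for child in &node.children {
        tnr = f(child)?;
        if tnr == TreeNodeRecursion::Stop {
            break;
        }
    }
    Ok(tnr)
}

/// Same shape as TreeNode::visit: f_down, then the children, then f_up.
fn visit<V: Visitor>(node: &Node, visitor: &mut V) -> Result<TreeNodeRecursion> {
    visitor
        .f_down(node)?
        .visit_children(|| apply_children(node, |c| visit(c, visitor)))?
        .visit_parent(|| visitor.f_up(node))
}

/// Hypothetical visitor: logs every call and jumps over the node named `skip`.
struct Recorder {
    skip: &'static str,
    events: Vec<String>,
}

impl Visitor for Recorder {
    fn f_down(&mut self, node: &Node) -> Result<TreeNodeRecursion> {
        self.events.push(format!("down {}", node.name));
        if node.name == self.skip {
            Ok(TreeNodeRecursion::Jump)
        } else {
            Ok(TreeNodeRecursion::Continue)
        }
    }

    fn f_up(&mut self, node: &Node) -> Result<TreeNodeRecursion> {
        self.events.push(format!("up {}", node.name));
        Ok(TreeNodeRecursion::Continue)
    }
}

/// Visits Projection -> Join -> (ScanA, ScanB) and returns the event log.
fn record(skip: &'static str) -> Vec<String> {
    let leaf = |name| Node { name, children: Vec::new() };
    let join = Node { name: "Join", children: vec![leaf("ScanA"), leaf("ScanB")] };
    let tree = Node { name: "Projection", children: vec![join] };
    let mut recorder = Recorder { skip, events: Vec::new() };
    visit(&tree, &mut recorder).expect("visit failed");
    recorder.events
}
```

With `record("Join")` the events are down Projection, down Join, up Join, up Projection, and ScanA and ScanB are never visited. With `record("")` all eight events occur.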

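rewrite is structured the same way as visit. The per-node logic lives in the visitor or rewriter that is passed in, while the traversal logic lives in the TreeNode trait. TreeNodeRecursion connects the two: the value returned for each node decides whether traversal descends into its children, moves on, or stops.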

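pub enum TreeNodeRecursion {
    /// Continue the traversal.
    Continue,
    /// Skip part of the traversal; what is skipped depends on the direction
    /// (e.g. in a top-down pass, the node's children are skipped).
    Jump,
    /// Stop the traversal entirely.
    Stop,
}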

This mechanism lets the caller control the traversal and also makes tree transformation possible.
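Stop is what lets exists return early. The sketch below re-creates apply and exists from the trait above on a hypothetical integer tree. It drops `Result` for brevity, although the real closures return `Result`. Counting calls to the predicate shows that traversal halts at the first match.

```rust
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
enum TreeNodeRecursion {
    Continue,
    Jump,
    Stop,
}

impl TreeNodeRecursion {
    /// Continue: descend into the children; Jump: skip them; Stop: propagate Stop.
    fn visit_children<F: FnOnce() -> Self>(self, f: F) -> Self {
        match self {
            Self::Continue => f(),
            Self::Jump => Self::Continue,
            Self::Stop => self,
        }
    }
}

/// Hypothetical tree of integers.
struct Node {
    value: i64,
    children: Vec<Node>,
}

/// Pre-order traversal, mirroring the default `apply`.
fn apply<F: FnMut(&Node) -> TreeNodeRecursion>(node: &Node, mut f: F) -> TreeNodeRecursion {
    fn apply_impl<F: FnMut(&Node) -> TreeNodeRecursion>(
        node: &Node,
        f: &mut F,
    ) -> TreeNodeRecursion {
        f(node).visit_children(|| {
            let mut tnr = TreeNodeRecursion::Continue;
            for child in &node.children {
                tnr = apply_impl(child, f);
                if tnr == TreeNodeRecursion::Stop {
                    break;
                }
            }
            tnr
        })
    }
    apply_impl(node, &mut f)
}

/// Mirrors the default `exists`: return Stop as soon as the predicate matches.
fn exists<F: FnMut(&Node) -> bool>(node: &Node, mut f: F) -> bool {
    let mut found = false;
    apply(node, |n| {
        if f(n) {
            found = true;
            TreeNodeRecursion::Stop
        } else {
            TreeNodeRecursion::Continue
        }
    });
    found
}

fn mk(value: i64, children: Vec<Node>) -> Node {
    Node { value, children }
}

/// 1 -> [2 -> [-3, 4], 5]; pre-order is 1, 2, -3, 4, 5.
fn sample_tree() -> Node {
    mk(1, vec![mk(2, vec![mk(-3, vec![]), mk(4, vec![])]), mk(5, vec![])])
}
```

Searching `sample_tree()` for a negative value calls the predicate only three times (on 1, 2 and -3). A search that never matches visits all five nodes.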

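LogicalPlan serves as an example of how the TreeNode trait is implemented. The trait has many methods, but almost all of them have default implementations that normally need no changes. Business-specific logic is factored out into closures, visitors, and rewriters.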

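In practice only two methods have to be implemented. apply_children walks the node's direct children and applies a closure to each. map_children transforms each child with a closure and rebuilds the node from the results.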

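Walking the children is simple. Transforming them is more involved, mainly because LogicalPlan has many node variants and each must be rebuilt from its transformed children. The rewrite logic itself still lives in the closure, so map_children is essentially one large pattern match.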

impl TreeNode for LogicalPlan {
    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
        &'n self,
        f: F,
    ) -> Result<TreeNodeRecursion> {
        // Visit every input (child plan) of this node
        self.inputs().into_iter().apply_until_stop(f)
    }
    
    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {
        Ok(match self {
            LogicalPlan::Projection(Projection {
                expr,
                input,
                schema,
            }) => rewrite_arc(input, f)?.update_data(|input| {
                LogicalPlan::Projection(Projection {
                    expr,
                    input,
                    schema,
                })
            }),
            ...
        })
    }
}
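The split between the two required methods and the generic defaults can be sketched on a hypothetical expression type. `Expr`, `fold_constants`, the `lit`/`col`/`add` helpers and the simplified `Transformed` (no `tnr`, no `Result`) are assumptions for illustration. `map_children` holds the per-variant pattern match, as in the LogicalPlan code, and `transform_up` is written once on top of it.

```rust
/// Hypothetical expression tree.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Lit(i64),
    Col(String),
    Add(Box<Expr>, Box<Expr>),
}

/// Simplified Transformed: no TreeNodeRecursion and no Result.
#[derive(Debug)]
struct Transformed<T> {
    data: T,
    transformed: bool,
}

impl<T> Transformed<T> {
    fn yes(data: T) -> Self {
        Transformed { data, transformed: true }
    }
    fn no(data: T) -> Self {
        Transformed { data, transformed: false }
    }
}

impl Expr {
    /// Node-specific part: apply `f` to each direct child and rebuild the node.
    fn map_children<F: FnMut(Expr) -> Transformed<Expr>>(self, mut f: F) -> Transformed<Expr> {
        match self {
            Expr::Add(l, r) => {
                let l = f(*l);
                let r = f(*r);
                Transformed {
                    transformed: l.transformed || r.transformed,
                    data: Expr::Add(Box::new(l.data), Box::new(r.data)),
                }
            }
            // Literals and columns have no children.
            leaf => Transformed::no(leaf),
        }
    }

    /// Generic part: post-order rewrite built only on `map_children`.
    fn transform_up<F: FnMut(Expr) -> Transformed<Expr>>(self, mut f: F) -> Transformed<Expr> {
        fn transform_up_impl<F: FnMut(Expr) -> Transformed<Expr>>(
            e: Expr,
            f: &mut F,
        ) -> Transformed<Expr> {
            // Rewrite the children first, then offer the rebuilt node to `f`.
            let children = e.map_children(|c| transform_up_impl(c, f));
            let parent = f(children.data);
            Transformed {
                data: parent.data,
                transformed: children.transformed || parent.transformed,
            }
        }
        transform_up_impl(self, &mut f)
    }
}

/// Hypothetical rule: fold `Lit + Lit` into a single literal.
fn fold_constants(e: Expr) -> Transformed<Expr> {
    match e {
        Expr::Add(l, r) => match (*l, *r) {
            (Expr::Lit(a), Expr::Lit(b)) => Transformed::yes(Expr::Lit(a + b)),
            (l, r) => Transformed::no(Expr::Add(Box::new(l), Box::new(r))),
        },
        other => Transformed::no(other),
    }
}

fn lit(v: i64) -> Expr {
    Expr::Lit(v)
}
fn col(name: &str) -> Expr {
    Expr::Col(name.to_string())
}
fn add(l: Expr, r: Expr) -> Expr {
    Expr::Add(Box::new(l), Box::new(r))
}
```

transform_up folds children before their parent, so `(1 + 2) + (3 + 4)` collapses to `10` in one pass. A pre-order pass would see the outer `Add` while its children are still unfolded and would leave `3 + 7`.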