函数也是接收输入并计算求值,所以也是表达式。在Expr中可以看到有三种函数类型

  • ScalarFunction:求值
  • AggregateFunction:聚合
  • WindowFunction:窗口

ScalarFunction

pub struct ScalarFunction {
    /// The function
    pub func: Arc<crate::ScalarUDF>,
    /// List of expressions to feed to the functions as arguments
    pub args: Vec<Expr>,
}

对于一个函数来讲,其核心当然就是函数的定义以及函数运行的参数了。参数就是表达式的列表,这里主要关注的就是函数的定义。

#[derive(Debug, Clone)]
pub struct ScalarUDF {
    inner: Arc<dyn ScalarUDFImpl>,
}

ScalarUDF是统一的包装,ScalarUDFImpl是trait约束,也就是函数签名用于声明函数的,实现了该trait也就实现了一个函数的定义。ScalarUDFImpl里面的函数比较多,但是对于一个函数来说,我们主要关注的就是其参数类型,返回值类型,函数执行。

pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
	// 获取函数参数的类型信息
    fn signature(&self) -> &Signature;
    // 根据输入参数的类型得出返回值类型
	fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
	// 作用和return_type一样,不过这里输入的参数信息比类型更丰富,所以实现了该方法后就不应该使用return_type方法了。该方法的默认实现也是调用上面的return_type方法。
	fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef>;
	// 执行函数获取结果
	fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;
}

函数参数列表的类型信息使用Signature描述

pub struct Signature {
    /// 描述了函数接收的参数列表的类型信息
    pub type_signature: TypeSignature,
    /// 对函数运行结果变化的描述。有abs这样结果固定的,有now这样每次查询不一样的,还有random这样每次运行都不一样的函数
    pub volatility: Volatility,
}

类似与函数的重载,一个函数可以有多种接收不同个数、类型参数的重载实现和返回不同的类型值。上面的trait能够描述出函数的这种能力。

找个函数示例的实现,比如NUL函数,这个函数接收两个参数,当第一个参数的值为NULL的时候就返回第二个参数,否则返回第一个参数的值。逻辑简单,下面的代码也比较简洁,就不赘述了。

pub struct NVLFunc {
    signature: Signature,
    aliases: Vec<String>,
}
 
impl NVLFunc {
    pub fn new() -> Self {
        Self {
	        // 接收2个参数,两个参数的类型也相同。
            signature: Signature::uniform(
                2,
                SUPPORTED_NVL_TYPES.to_vec(),
                Volatility::Immutable,
            ),
            aliases: vec![String::from("ifnull")],
        }
    }
}
 
impl ScalarUDFImpl for NVLFunc {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
 
    fn name(&self) -> &str {
        "nvl"
    }
 
    fn signature(&self) -> &Signature {
        &self.signature
    }
 
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }
 
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        nvl_func(&args.args)
    }
 
    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}
 
fn nvl_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    let [lhs, rhs] = take_function_args("nvl/ifnull", args)?;
    let (lhs_array, rhs_array) = match (lhs, rhs) {
        (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
            (Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?)
        }
        (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
            (Arc::clone(lhs), Arc::clone(rhs))
        }
        (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => {
            (lhs.to_array_of_size(rhs.len())?, Arc::clone(rhs))
        }
        (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => {
            let mut current_value = lhs;
            if lhs.is_null() {
                current_value = rhs;
            }
            return Ok(ColumnarValue::Scalar(current_value.clone()));
        }
    };
    // 根据lhs的值是否为null来决定最终的返回值
    let to_apply = is_not_null(&lhs_array)?;
    let value = zip(&to_apply, &lhs_array, &rhs_array)?;
    Ok(ColumnarValue::Array(value))
}

AggregateFunction

pub struct AggregateFunction {
    pub func: Arc<AggregateUDF>,
    pub params: AggregateFunctionParams,
}

agg也是函数定义和运行参数两部分。AggregateUDF同样也是一个统一的对外包装。所以从组织结构上来看和ScalarFunction差不多。

pub struct AggregateUDF {
    inner: Arc<dyn AggregateUDFImpl>,
}

区别在于AggregateUDFImpl trait定义的方法有所不同。当然参数类型和返回值类型这些还是一样的,主要在于Accumulator

pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
	fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
	fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
    }
}

聚合函数需要对每个分组(group)逐批次更新状态、合并中间状态,并在最后产出聚合结果,负责这项工作的就是AccumulatorGroupsAccumulator则是可以多个分组共享的性能更好的实现。

sum这个函数为例来看聚合函数的实现,我们主要看创建的Accumulator以及其实现

fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if args.is_distinct {
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt)))
                };
            }
            downcast_sum!(args, helper)
        } else {
            macro_rules! helper {
                ($t:ty, $dt:expr) => {
                    Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
                };
            }
            downcast_sum!(args, helper)
        }
    }

这里使用了宏来简化代码,展开宏之后就是按照参数的类型创建对应类型的累加器罢了

match args.return_field.data_type().clone() {
    DataType::UInt64 => Ok(Box::new(DistinctSumAccumulator::<UInt64Type>::new(
        &(args.return_field.data_type().clone()),
    ))),
    DataType::Int64 => Ok(Box::new(DistinctSumAccumulator::<Int64Type>::new(
        &(args.return_field.data_type().clone()),
    ))),
    ...
    }

看一下DistinctSumAccumulator这个累加器的实现

impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
	// 更新中间结果
	fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
 
        let array = values[0].as_primitive::<T>();
        match array.nulls().filter(|x| x.null_count() > 0) {
            Some(n) => {
	            // 数组包含null,那么只遍历non-null,将值插入哈希集合
                for idx in n.valid_indices() {
                    self.values.insert(Hashable(array.value(idx)));
                }
            }
            // 没有null则直接遍历插入
            None => array.values().iter().for_each(|x| {
                self.values.insert(Hashable(*x));
            }),
        }
        Ok(())
    }
    // 求得最后结果
	fn evaluate(&mut self) -> Result<ScalarValue> {
		let mut acc = T::Native::usize_as(0);
		// 遍历累加
		for distinct_value in self.values.iter() {
			acc = acc.add_wrapping(distinct_value.0)
		}
		let v = (!self.values.is_empty()).then_some(acc);
		ScalarValue::new_primitive::<T>(v, &self.data_type)
    }
}

逻辑的话也是比较清晰,就是使用一个哈希集合记录去重的值,最后遍历累加最终得到了去重累加的和。

窗口函数思路也是类似,这里就不单独说明了