Preface
This article takes a look at the AggregateFunction of the Flink Table API.
Example

    /**
     * Accumulator for WeightedAvg.
     */
    public static class WeightedAvgAccum {
        public long sum = 0;
        public int count = 0;
    }

    /**
     * Weighted Average user-defined aggregate function.
     */
    public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

        @Override
        public WeightedAvgAccum createAccumulator() {
            return new WeightedAvgAccum();
        }

        @Override
        public Long getValue(WeightedAvgAccum acc) {
            if (acc.count == 0) {
                return 0L;
            } else {
                return acc.sum / acc.count;
            }
        }

        public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
            acc.sum += iValue * iWeight;
            acc.count += iWeight;
        }

        public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
            acc.sum -= iValue * iWeight;
            acc.count -= iWeight;
        }

        public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
            Iterator<WeightedAvgAccum> iter = it.iterator();
            while (iter.hasNext()) {
                WeightedAvgAccum a = iter.next();
                acc.count += a.count;
                acc.sum += a.sum;
            }
        }

        public void resetAccumulator(WeightedAvgAccum acc) {
            acc.count = 0;
            acc.sum = 0L;
        }
    }

    // register function
    BatchTableEnvironment tEnv = ...
    tEnv.registerFunction("wAvg", new WeightedAvg());

    // use function
    tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");

- WeightedAvg extends AggregateFunction and implements createAccumulator, getValue, accumulate, retract, merge, and resetAccumulator.
AggregateFunction
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/functions/AggregateFunction.scala

    abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
      /**
        * Creates and init the Accumulator for this [[AggregateFunction]].
        *
        * @return the accumulator with the initial value
        */
      def createAccumulator(): ACC

      /**
        * Called every time when an aggregation result should be materialized.
        * The returned value could be either an early and incomplete result
        * (periodically emitted as data arrive) or the final result of the
        * aggregation.
        *
        * @param accumulator the accumulator which contains the current
        *                    aggregated results
        * @return the aggregation result
        */
      def getValue(accumulator: ACC): T

      /**
        * Returns true if this AggregateFunction can only be applied in an OVER window.
        *
        * @return true if the AggregateFunction requires an OVER window, false otherwise.
        */
      def requiresOver: Boolean = false

      /**
        * Returns the TypeInformation of the AggregateFunction's result.
        *
        * @return The TypeInformation of the AggregateFunction's result or null if the result type
        *         should be automatically inferred.
        */
      def getResultType: TypeInformation[T] = null

      /**
        * Returns the TypeInformation of the AggregateFunction's accumulator.
        *
        * @return The TypeInformation of the AggregateFunction's accumulator or null if the
        *         accumulator type should be automatically inferred.
        */
      def getAccumulatorType: TypeInformation[ACC] = null
    }

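- AggregateFunction extends UserDefinedFunction and takes two type parameters: T is the type of the aggregation result and ACC is the type of the accumulator. It declares createAccumulator, getValue, requiresOver, getResultType, and getAccumulatorType; of these, subclasses must implement createAccumulator and getValue.
- AggregateFunction does not declare accumulate, but every subclass must define and implement it. The method takes the ACC accumulator followed by the user-defined input arguments and returns void. retract, merge, and resetAccumulator are optional; a subclass defines them only when the operations it is used in require them.
- A datastream bounded OVER aggregate requires retract, which takes the ACC accumulator followed by the user-defined input arguments and returns void. A datastream session window grouping aggregate and a dataset grouping aggregate require merge, which takes an ACC and a java.lang.Iterable<ACC> and returns void. A dataset grouping aggregate also requires resetAccumulator, which takes an ACC and returns void.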
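To make the contract concrete, here is a minimal stand-alone sketch that walks one accumulator through accumulate, retract, merge, and resetAccumulator. It deliberately has no Flink dependency: the class WeightedAvgContractSketch and its nested Acc are hypothetical names that only mirror the shape of WeightedAvg and WeightedAvgAccum from the example, and the input values are made up for illustration.

```java
import java.util.Collections;

// Stand-alone sketch with no Flink dependency; all names are hypothetical
// and only mirror the shape of WeightedAvg / WeightedAvgAccum.
final class WeightedAvgContractSketch {

    /** Accumulator with the same fields as WeightedAvgAccum. */
    static final class Acc {
        long sum = 0;
        int count = 0;
    }

    static Acc createAccumulator() {
        return new Acc();
    }

    /** Materializes the current result; may be called many times. */
    static long getValue(Acc acc) {
        return acc.count == 0 ? 0L : acc.sum / acc.count;
    }

    /** Adds one input (value, weight) to the accumulator. */
    static void accumulate(Acc acc, long value, int weight) {
        acc.sum += value * weight;
        acc.count += weight;
    }

    /** Undoes a previous accumulate with the same arguments. */
    static void retract(Acc acc, long value, int weight) {
        acc.sum -= value * weight;
        acc.count -= weight;
    }

    /** Folds other partial accumulators into acc. */
    static void merge(Acc acc, Iterable<Acc> others) {
        for (Acc other : others) {
            acc.sum += other.sum;
            acc.count += other.count;
        }
    }

    /** Returns the accumulator to its initial state so it can be reused. */
    static void resetAccumulator(Acc acc) {
        acc.sum = 0L;
        acc.count = 0;
    }

    public static void main(String[] args) {
        Acc a = createAccumulator();
        accumulate(a, 10L, 2);            // sum = 20, count = 2
        accumulate(a, 20L, 3);            // sum = 80, count = 5
        System.out.println(getValue(a));  // 80 / 5 = 16

        retract(a, 20L, 3);               // back to sum = 20, count = 2
        System.out.println(getValue(a));  // 20 / 2 = 10

        Acc b = createAccumulator();
        accumulate(b, 40L, 2);            // sum = 80, count = 2
        merge(a, Collections.singletonList(b)); // sum = 100, count = 4
        System.out.println(getValue(a));  // 100 / 4 = 25

        resetAccumulator(a);
        System.out.println(getValue(a));  // empty accumulator yields 0
    }
}
```

Note that retract is the exact inverse of accumulate and that merge only touches accumulator state, which is why it takes an Iterable of accumulators rather than of input or result values.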
DataSetPreAggFunction
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala

    class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction)
      extends AbstractRichFunction
      with GroupCombineFunction[Row, Row]
      with MapPartitionFunction[Row, Row]
      with Compiler[GeneratedAggregations]
      with Logging {

      private var output: Row = _
      private var accumulators: Row = _

      private var function: GeneratedAggregations = _

      override def open(config: Configuration) {
        LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
                    s"Code:\n$genAggregations.code")
        val clazz = compile(
          getRuntimeContext.getUserCodeClassLoader,
          genAggregations.name,
          genAggregations.code)
        LOG.debug("Instantiating AggregateHelper.")
        function = clazz.newInstance()

        output = function.createOutputRow()
        accumulators = function.createAccumulators()
      }

      override def combine(values: Iterable[Row], out: Collector[Row]): Unit = {

        // reset accumulators
        function.resetAccumulator(accumulators)

        val iterator = values.iterator()

        var record: Row = null
        while (iterator.hasNext) {
          record = iterator.next()
          // accumulate
          function.accumulate(accumulators, record)
        }

        // set group keys and accumulators to output
        function.setAggregationResults(accumulators, output)
        function.setForwardedFields(record, output)

        out.collect(output)
      }

      override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = {
        combine(values, out)
      }
    }

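- The combine method of DataSetPreAggFunction calls function.accumulate(accumulators, record), where accumulators is a Row holding the accumulator objects (here a single WeightedAvgAccum) and record is an input Row. function is an instance of a generated class that extends GeneratedAggregations. Its code is carried in genAggregations, which is produced by AggregateUtil.createDataSetAggregateFunctions, and the generated accumulate in turn calls WeightedAvg's accumulate.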
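The reset, accumulate, emit sequence in combine can be sketched without Flink as follows. This is only an illustration under stated assumptions: a Row is modelled as a plain Object[], the Aggregations interface is a hypothetical, trimmed-down mirror of GeneratedAggregations, WAvgAggregations is a hand-written stand-in for the generated class, and input rows are assumed to be laid out as (user, points, level). Unlike the real pre-aggregation, which forwards the accumulator object itself, the sketch writes the accumulator's two fields into the output row.

```java
import java.util.ArrayList;
import java.util.List;

// Stand-alone sketch with no Flink dependency; every name here is hypothetical.
final class PreAggCombineSketch {

    static final class Acc {
        long sum;
        int count;
    }

    /** Trimmed-down mirror of GeneratedAggregations; a row is modelled as Object[]. */
    interface Aggregations {
        Object[] createAccumulators();
        Object[] createOutputRow();
        void resetAccumulator(Object[] accumulators);
        void accumulate(Object[] accumulators, Object[] input);
        void setAggregationResults(Object[] accumulators, Object[] output);
        void setForwardedFields(Object[] input, Object[] output);
    }

    /** Hand-written stand-in for the generated class; input rows are (user, points, level). */
    static final class WAvgAggregations implements Aggregations {
        @Override public Object[] createAccumulators() { return new Object[] {new Acc()}; }
        @Override public Object[] createOutputRow() { return new Object[3]; }
        @Override public void resetAccumulator(Object[] accumulators) {
            Acc acc = (Acc) accumulators[0];
            acc.sum = 0L;
            acc.count = 0;
        }
        @Override public void accumulate(Object[] accumulators, Object[] input) {
            Acc acc = (Acc) accumulators[0];
            long points = (Long) input[1];
            int level = (Integer) input[2];
            acc.sum += points * level;
            acc.count += level;
        }
        @Override public void setAggregationResults(Object[] accumulators, Object[] output) {
            Acc acc = (Acc) accumulators[0];
            output[1] = acc.sum;   // partial result: accumulator state, not the final average
            output[2] = acc.count;
        }
        @Override public void setForwardedFields(Object[] input, Object[] output) {
            output[0] = input[0];  // forward the grouping key
        }
    }

    /** Same control flow as combine: reset, accumulate every record, emit one row. */
    static void combine(Aggregations function, Object[] accumulators, Object[] output,
                        Iterable<Object[]> values, List<Object[]> out) {
        function.resetAccumulator(accumulators);
        Object[] record = null;
        for (Object[] value : values) {
            record = value;
            function.accumulate(accumulators, record);
        }
        function.setAggregationResults(accumulators, output);
        function.setForwardedFields(record, output); // assumes a non-empty group
        out.add(output.clone()); // copy, because the output row is reused across groups
    }

    static Object[] row(Object... fields) {
        return fields;
    }

    static List<Object[]> demo() {
        Aggregations function = new WAvgAggregations();
        Object[] accumulators = function.createAccumulators();
        Object[] output = function.createOutputRow();
        List<Object[]> out = new ArrayList<>();

        List<Object[]> alice = new ArrayList<>();
        alice.add(row("alice", 10L, 2));
        alice.add(row("alice", 20L, 3));
        combine(function, accumulators, output, alice, out);

        List<Object[]> bob = new ArrayList<>();
        bob.add(row("bob", 40L, 1));
        combine(function, accumulators, output, bob, out);
        return out;
    }

    public static void main(String[] args) {
        for (Object[] r : demo()) {
            System.out.println(r[0] + " sum=" + r[1] + " count=" + r[2]);
        }
    }
}
```

Because the accumulators and the output row are created once and reused for every group, the reset at the top of combine is what keeps one group's state from leaking into the next. That reuse is the reason a dataset grouping aggregate requires resetAccumulator.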
GeneratedAggregations
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala

    abstract class GeneratedAggregations extends Function {

      /**
        * Setup method for [[org.apache.flink.table.functions.AggregateFunction]].
        * It can be used for initialization work. By default, this method does nothing.
        *
        * @param ctx The runtime context.
        */
      def open(ctx: RuntimeContext)

      /**
        * Sets the results of the aggregations (partial or final) to the output row.
        * Final results are computed with the aggregation function.
        * Partial results are the accumulators themselves.
        *
        * @param accumulators the accumulators (saved in a row) which contains the current
        *                     aggregated results
        * @param output       output results collected in a row
        */
      def setAggregationResults(accumulators: Row, output: Row)

      /**
        * Copies forwarded fields, such as grouping keys, from input row to output row.
        *
        * @param input        input values bundled in a row
        * @param output       output results collected in a row
        */
      def setForwardedFields(input: Row, output: Row)

      /**
        * Accumulates the input values to the accumulators.
        *
        * @param accumulators the accumulators (saved in a row) which contains the current
        *                     aggregated results
        * @param input        input values bundled in a row
        */
      def accumulate(accumulators: Row, input: Row)

      /**
        * Retracts the input values from the accumulators.
        *
        * @param accumulators the accumulators (saved in a row) which contains the current
        *                     aggregated results
        * @param input        input values bundled in a row
        */
      def retract(accumulators: Row, input: Row)

      /**
        * Initializes the accumulators and save them to a accumulators row.
        *
        * @return a row of accumulators which contains the aggregated results
        */
      def createAccumulators(): Row

      /**
        * Creates an output row object with the correct arity.
        *
        * @return an output row object with the correct arity.
        */
      def createOutputRow(): Row

      /**
        * Merges two rows of accumulators into one row.
        *
        * @param a First row of accumulators
        * @param b The other row of accumulators
        * @return A row with the merged accumulators of both input rows.
        */
      def mergeAccumulatorsPair(a: Row, b: Row): Row

      /**
        * Resets all the accumulators.
        *
        * @param accumulators the accumulators (saved in a row) which contains the current
        *                     aggregated results
        */
      def resetAccumulator(accumulators: Row)

      /**
        * Cleanup for the accumulators.
        */
      def cleanup()

      /**
        * Tear-down method for [[org.apache.flink.table.functions.AggregateFunction]].
        * It can be used for clean up work. By default, this method does nothing.
        */
      def close()
    }

- GeneratedAggregations declares Row-based methods such as accumulate(accumulators: Row, input: Row) and resetAccumulator(accumulators: Row).
AggregateUtil
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala

    object AggregateUtil {

      type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R]
      type JavaList[T] = java.util.List[T]

      //......

      /**
        * Create functions to compute a [[org.apache.flink.table.plan.nodes.dataset.DataSetAggregate]].
        * If all aggregation functions support pre-aggregation, a pre-aggregation function and the
        * respective output type are generated as well.
        */
      private[flink] def createDataSetAggregateFunctions(
          generator: AggregationCodeGenerator,
          namedAggregates: Seq[CalcitePair[AggregateCall, String]],
          inputType: RelDataType,
          inputFieldTypeInfo: Seq[TypeInformation[_]],
          outputType: RelDataType,
          groupings: Array[Int],
          tableConfig: TableConfig): (
            Option[DataSetPreAggFunction],
            Option[TypeInformation[Row]],
            Either[DataSetAggFunction, DataSetFinalAggFunction]) = {

        val needRetract = false
        val (aggInFields, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
          namedAggregates.map(_.getKey),
          inputType,
          needRetract,
          tableConfig)

        val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
          namedAggregates,
          groupings,
          inputType,
          outputType
        )

        val aggOutFields = aggOutMapping.map(_._1)

        if (doAllSupportPartialMerge(aggregates)) {

          // compute preaggregation type
          val preAggFieldTypes = gkeyOutMapping.map(_._2)
            .map(inputType.getFieldList.get(_).getType)
            .map(FlinkTypeFactory.toTypeInfo) ++ accTypes
          val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)

          val genPreAggFunction = generator.generateAggregations(
            "DataSetAggregatePrepareMapHelper",
            inputFieldTypeInfo,
            aggregates,
            aggInFields,
            aggregates.indices.map(_ + groupings.length).toArray,
            isDistinctAggs,
            isStateBackedDataViews = false,
            partialResults = true,
            groupings,
            None,
            groupings.length + aggregates.length,
            needRetract,
            needMerge = false,
            needReset = true,
            None
          )

          // compute mapping of forwarded grouping keys
          val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) {
            val gkeyOutFields = gkeyOutMapping.map(_._1)
            val mapping = Array.fill[Int](gkeyOutFields.max + 1)(-1)
            gkeyOutFields.zipWithIndex.foreach(m => mapping(m._1) = m._2)
            mapping
          } else {
            new Array[Int](0)
          }

          val genFinalAggFunction = generator.generateAggregations(
            "DataSetAggregateFinalHelper",
            inputFieldTypeInfo,
            aggregates,
            aggInFields,
            aggOutFields,
            isDistinctAggs,
            isStateBackedDataViews = false,
            partialResults = false,
            gkeyMapping,
            Some(aggregates.indices.map(_ + groupings.length).toArray),
            outputType.getFieldCount,
            needRetract,
            needMerge = true,
            needReset = true,
            None
          )

          (
            Some(new DataSetPreAggFunction(genPreAggFunction)),
            Some(preAggRowType),
            Right(new DataSetFinalAggFunction(genFinalAggFunction))
          )
        }
        else {
          val genFunction = generator.generateAggregations(
            "DataSetAggregateHelper",
            inputFieldTypeInfo,
            aggregates,
            aggInFields,
            aggOutFields,
            isDistinctAggs,
            isStateBackedDataViews = false,
            partialResults = false,
            groupings,
            None,
            outputType.getFieldCount,
            needRetract,
            needMerge = false,
            needReset = true,
            None
          )

          (
            None,
            None,
            Left(new DataSetAggFunction(genFunction))
          )
        }

      }

      //......
    }

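- createDataSetAggregateFunctions in AggregateUtil generates GeneratedAggregationsFunction instances and wraps them: when every aggregate supports partial merge it creates a DataSetPreAggFunction together with a DataSetFinalAggFunction, otherwise a single DataSetAggFunction. The code is generated dynamically because the parameters of user-defined methods such as accumulate vary from function to function, whereas Flink calls through the fixed accumulate(accumulators: Row, input: Row) signature defined by GeneratedAggregations. The generated code acts as an adapter: inside accumulate(accumulators: Row, input: Row) it converts the Row fields into the arguments the user-defined accumulate expects and then calls it.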
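The adaptation step can be pictured with a hand-written equivalent of what the generated class does for wAvg(points, level). This is a sketch only and not the code Flink actually emits: GeneratedAdapterSketch, WAvgAdapter, and the Object[] stand-in for Row are hypothetical, and the aggInFields array here simply assumes that points and level sit at positions 1 and 2 of the input row.

```java
// Stand-alone sketch with no Flink dependency; names are hypothetical and the
// adapter is written by hand, whereas Flink generates the equivalent code at runtime.
final class GeneratedAdapterSketch {

    static final class WeightedAvgAccum {
        long sum = 0;
        int count = 0;
    }

    /** User-facing function: accumulate has a signature specific to this function. */
    static final class WeightedAvg {
        WeightedAvgAccum createAccumulator() {
            return new WeightedAvgAccum();
        }

        long getValue(WeightedAvgAccum acc) {
            return acc.count == 0 ? 0L : acc.sum / acc.count;
        }

        void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
            acc.sum += iValue * iWeight;
            acc.count += iWeight;
        }
    }

    /** Adapter with the fixed (accumulators row, input row) signature. */
    static final class WAvgAdapter {
        private final WeightedAvg function = new WeightedAvg();
        private final int[] aggInFields; // input positions of the function's arguments

        WAvgAdapter(int[] aggInFields) {
            this.aggInFields = aggInFields;
        }

        Object[] createAccumulators() {
            return new Object[] {function.createAccumulator()};
        }

        void accumulate(Object[] accumulators, Object[] input) {
            // Unpack and cast the row fields, then call the typed user method.
            function.accumulate(
                (WeightedAvgAccum) accumulators[0],
                (Long) input[aggInFields[0]],
                (Integer) input[aggInFields[1]]);
        }

        long result(Object[] accumulators) {
            return function.getValue((WeightedAvgAccum) accumulators[0]);
        }
    }

    static long demo() {
        // Assumed row layout: (user, points, level)
        WAvgAdapter adapter = new WAvgAdapter(new int[] {1, 2});
        Object[] accumulators = adapter.createAccumulators();
        adapter.accumulate(accumulators, new Object[] {"alice", 10L, 2});
        adapter.accumulate(accumulators, new Object[] {"alice", 20L, 3});
        return adapter.result(accumulators); // (10*2 + 20*3) / (2 + 3)
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

A function with a different accumulate signature would need a different adapter body, which is the part that cannot be written once in the framework and is therefore generated per query.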
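Summary
- AggregateFunction extends UserDefinedFunction and takes two type parameters: T is the type of the aggregation result and ACC is the type of the accumulator. It declares createAccumulator, getValue, requiresOver, getResultType, and getAccumulatorType; subclasses must implement createAccumulator and getValue. accumulate is not declared there, but every subclass must define it: it takes the ACC accumulator followed by the user-defined input arguments and returns void. retract, merge, and resetAccumulator are optional and depend on the operation: a datastream bounded OVER aggregate requires retract (ACC plus the user-defined input arguments, returns void); a datastream session window grouping aggregate and a dataset grouping aggregate require merge (ACC and java.lang.Iterable<ACC>, returns void); a dataset grouping aggregate also requires resetAccumulator (ACC, returns void).
- The combine method of DataSetPreAggFunction calls function.accumulate(accumulators, record), where accumulators is a Row holding the accumulator objects (here a single WeightedAvgAccum) and record is an input Row. function is an instance of a generated class that extends GeneratedAggregations; its code is carried in genAggregations, which is produced by AggregateUtil.createDataSetAggregateFunctions, and the generated accumulate in turn calls WeightedAvg's accumulate. GeneratedAggregations declares Row-based methods such as accumulate(accumulators: Row, input: Row) and resetAccumulator(accumulators: Row).
- createDataSetAggregateFunctions in AggregateUtil generates GeneratedAggregationsFunction instances and wraps them in a DataSetPreAggFunction plus a DataSetFinalAggFunction, or in a single DataSetAggFunction when partial merge is not supported. The code is generated dynamically because the parameters of user-defined methods such as accumulate vary from function to function, whereas Flink calls through the fixed accumulate(accumulators: Row, input: Row) signature. The generated code acts as an adapter: it converts the Row fields into the arguments the user-defined accumulate expects and then calls it.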