Defining a custom UDAF with Spark SQL 3.0.
Before 3.0 you would extend UserDefinedAggregateFunction; that approach is now deprecated.
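For context, here is a minimal sketch of the registration difference (the MySumUDAF and mySumAggregator names are placeholders, not part of the full example below):

// Pre-3.0, now deprecated: subclass UserDefinedAggregateFunction and register the instance.
//   spark.udf.register("mySum", new MySumUDAF())
// 3.0+: implement Aggregator[IN, BUF, OUT] and wrap it with functions.udaf before registering.
//   spark.udf.register("mySum", functions.udaf(mySumAggregator))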
The new approach is shown in the following code, which defines a custom sum and a custom average.
package com.cy.spark
import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession}
import org.apache.spark.{SparkConf, SparkContext}
object DemoUDAF {
  def main(args: Array[String]): Unit = {
    // silence Spark's log output
    Logger.getLogger("org").setLevel(Level.ERROR)
    val conf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[4]")
    val sc = new SparkContext(conf)
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .master("local[4]")
      .getOrCreate()
    val rdd: RDD[String] = sc.textFile("E://file/spark/student01.txt")
    val stuRdd: RDD[Stu1] = rdd.map(line => {
      // each line looks like: class01 tom 100
      val split = line.split(" ")
      val classess = split(0)
      val name = split(1)
      val score = split(2).toInt
      Stu1(classess, name, score)
    })
    // important: needed for the RDD -> DataFrame conversion below
    import spark.implicits._
    // RDD -> DataFrame
    val df: DataFrame = stuRdd.toDF
    df.createOrReplaceTempView("stu")
    import org.apache.spark.sql.functions._
    // UDAF: average
    val avgAgg1 = new Aggregator[Double, (Double, Int), Double] {
      // initial value of the aggregation buffer
      override def zero: (Double, Int) = (0.0, 0)
      // folds one input value into the buffer within a partition
      override def reduce(b: (Double, Int), a: Double): (Double, Int) = {
        (b._1 + a, b._2 + 1)
      }
      // merges buffers across partitions
      override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) = {
        (b1._1 + b2._1, b1._2 + b2._2)
      }
      // computes the final result from the buffer
      override def finish(reduction: (Double, Int)): Double = {
        reduction._1 / reduction._2
      }
      // encoder for the intermediate buffer
      override def bufferEncoder: Encoder[(Double, Int)] = {
        Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
      }
      // encoder for the output value
      override def outputEncoder: Encoder[Double] = {
        Encoders.scalaDouble
      }
    }
    // UDAF: sum
    val sumAgg = new Aggregator[Int, Int, Int] {
      // initial value of the aggregation buffer
      override def zero: Int = 0
      // folds one input value into the buffer within a partition
      override def reduce(b: Int, a: Int): Int = b + a
      // merges buffers across partitions
      override def merge(b1: Int, b2: Int): Int = b1 + b2
      // computes the final result from the buffer
      override def finish(reduction: Int): Int = reduction
      // encoder for the intermediate buffer
      override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
      // encoder for the output value
      override def outputEncoder: Encoder[Int] = Encoders.scalaInt
    }
    // register the custom Aggregators as SQL functions via functions.udaf
    spark.udf.register("sum1", udaf(sumAgg))
    spark.udf.register("avg1", udaf(avgAgg1))
    val sql =
      """
        |select classess, sum1(score) as sum_score, avg1(cast(score as double)) as avg_score
        |from stu
        |group by classess
        |""".stripMargin
    spark.sql(sql).show()
    spark.stop()
  }
}
case class Stu1(classess:String, name:String, score:Int)
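As a side note beyond the original example, the same aggregators can also be used without a temp view: the UserDefinedFunction returned by functions.udaf can be applied directly in the DataFrame API, and an Aggregator written against the row class can be used with the typed Dataset API via toColumn. A minimal sketch, assuming it is placed inside main (before spark.stop()) so it can reuse the spark, df, stuRdd, sumAgg and avgAgg1 values defined above:

    // Untyped DataFrame API: apply the wrapped aggregators directly, no temp view needed.
    val sumScore = udaf(sumAgg)
    val avgScore = udaf(avgAgg1)
    df.groupBy("classess")
      .agg(sumScore($"score").as("sum_score"), avgScore($"score".cast("double")).as("avg_score"))
      .show()

    // Typed Dataset API: an Aggregator over the row type is used via toColumn,
    // so no registration is needed at all. This variant takes whole Stu1 rows as input.
    val avgByRow = new Aggregator[Stu1, (Double, Int), Double] {
      override def zero: (Double, Int) = (0.0, 0)
      override def reduce(b: (Double, Int), s: Stu1): (Double, Int) = (b._1 + s.score, b._2 + 1)
      override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) = (b1._1 + b2._1, b1._2 + b2._2)
      override def finish(r: (Double, Int)): Double = r._1 / r._2
      override def bufferEncoder: Encoder[(Double, Int)] = Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
    val ds = stuRdd.toDS()
    ds.groupByKey(_.classess).agg(avgByRow.toColumn.name("avg_score")).show()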
Data source (make up more rows yourself):
class01 tom 100
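For example, a made-up student01.txt could look like this (space-separated: class, name, score; the rows below are purely illustrative):

class01 tom 100
class01 jerry 80
class02 anna 90
class02 bob 70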