Spark SQL 之 UDF、UDAF
1、UDF使用
// Register a scalar UDF "prefix1" that prepends "Name:" to its String argument,
// then call it from SQL on the `users` table.
// A SQL NULL input yields NULL (instead of the literal string "Name:null"),
// matching how built-in SQL string functions treat NULL.
spark.udf.register("prefix1", (name: String) => Option(name).map(n => "Name:" + n).orNull)
spark.sql("select *,prefix1(name) from users").show()
2、UDAF使用
2.1 弱类型（基于 UserDefinedAggregateFunction，该 API 自 Spark 3.0.0 起已被标记为废弃）
package com.shufang.rdd_ds_df
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
/**
 * Untyped (Row-based) UDAF that computes the average of a Long column.
 *
 * Aggregation buffer layout: slot 0 = running sum ("total"),
 * slot 1 = number of non-null values aggregated ("count").
 *
 * NULL inputs are skipped (like SQL `avg`). When no non-null value was
 * aggregated, the result is SQL NULL instead of failing with a division by
 * zero. The average is computed with Long division, so it is truncated.
 *
 * Note: `UserDefinedAggregateFunction` is deprecated since Spark 3.0.0;
 * prefer an `Aggregator` registered with `functions.udaf`.
 */
class MyUDAF extends UserDefinedAggregateFunction {

  /** One Long input column. */
  override def inputSchema: StructType =
    StructType(Array(StructField("age", LongType)))

  /** Buffer schema: (total, count). */
  override def bufferSchema: StructType =
    StructType(
      Array(
        StructField("total", LongType),
        StructField("count", LongType)
      )
    )

  /** Result type; may be NULL for an empty or all-NULL group. */
  override def dataType: DataType = LongType

  // The same input always produces the same output.
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    // Row.getLong throws on a NULL cell, so NULL inputs are skipped explicitly.
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      buffer.update(1, buffer.getLong(1) + 1L)
    }

  /** Combines two partial buffers (e.g. from different partitions) into `buffer1`. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  /** Returns total / count, or NULL when nothing was aggregated. */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0) / count
  }
}
// Expose the untyped UDAF to SQL under the name "ageAvg", then average the id column.
spark.udf
  .register("ageAvg", new MyUDAF)
spark
  .sql("select ageAvg(id) as av from users")
  .show()
2.2 强类型（Spark 3.0.0 及之后版本推荐使用）
package com.shufang.rdd_ds_df
import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
/**
 * Aggregation buffer for the typed aggregators: the running `total` of the
 * input values and the `count` of values folded in so far.
 *
 * Both fields are `var` so `reduce`/`merge` can update a buffer in place and
 * return it rather than allocating a new one.
 */
case class Buff(
  var total: Long,
  var count: Long
)
/**
 * Strongly typed aggregator that averages Long values.
 *
 * IN = Long (one input value), BUF = [[Buff]] (running total and count),
 * OUT = Long. Register it for SQL with `functions.udaf` (Spark 3.0.0+).
 *
 * The average is computed with Long division, so it is truncated. When
 * nothing was aggregated (e.g. a global aggregate over an empty table), the
 * result is 0 rather than an ArithmeticException. The Long output type cannot
 * represent SQL NULL.
 */
class MyUDAF1 extends Aggregator[Long, Buff, Long] {

  /** Empty buffer: merging it with any buffer `b` yields `b`. */
  override def zero: Buff = Buff(0L, 0L)

  /** Folds one input value into `b`, updating it in place and returning it. */
  override def reduce(b: Buff, a: Long): Buff = {
    b.total += a
    b.count += 1
    b
  }

  /** Adds partial buffer `b2` into `b1` and returns `b1`. */
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.total += b2.total
    b1.count += b2.count
    b1
  }

  /** total / count, or 0 when no value was aggregated (avoids division by zero). */
  override def finish(buff: Buff): Long =
    if (buff.count == 0L) 0L else buff.total / buff.count

  override def bufferEncoder: Encoder[Buff] = Encoders.product
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
// Wrap the typed Aggregator with functions.udaf (Spark 3.0.0+) so SQL can call it as "ageAvg".
// `functions` is fully qualified because this snippet's imports do not bring it into scope.
spark.udf.register("ageAvg", org.apache.spark.sql.functions.udaf(new MyUDAF1()))
spark.sql("select ageAvg(id) as av from users").show()
2.3 早期版本使用强类型UDAF
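如果 Spark 版本早于 3.0.0 又需要使用强类型 UDAF，由于无法通过 functions.udaf 将 Aggregator 注册到 SQL 中，只能结合 DSL（Spark SQL 的领域特定语言）使用：先把聚合器转换为 TypedColumn，再在 Dataset 上 select。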
package com.shufang.rdd_ds_df
import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
/**
 * Strongly typed aggregator that averages the `id` field of [[User]] records.
 * Intended for the Dataset DSL (via `toColumn`) on Spark versions before 3.0.0.
 *
 * IN = User, BUF = [[Buff]] (running total and count), OUT = Long.
 * `User` is defined outside this snippet; it is assumed to have a numeric `id`
 * field that can be added to a Long (TODO: confirm against the User definition).
 *
 * The average is computed with Long division, so it is truncated. On an empty
 * Dataset the result is 0 rather than an ArithmeticException.
 */
class MyUDAF2 extends Aggregator[User, Buff, Long] {

  /** Empty buffer: merging it with any buffer `b` yields `b`. */
  override def zero: Buff = Buff(0L, 0L)

  /** Folds one record's id into `b`, updating it in place and returning it. */
  override def reduce(b: Buff, a: User): Buff = {
    b.total += a.id
    b.count += 1
    b
  }

  /** Adds partial buffer `b2` into `b1` and returns `b1`. */
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.total += b2.total
    b1.count += b2.count
    b1
  }

  /** total / count, or 0 when no record was aggregated (avoids division by zero). */
  override def finish(buff: Buff): Long =
    if (buff.count == 0L) 0L else buff.total / buff.count

  override def bufferEncoder: Encoder[Buff] = Encoders.product
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
// Apply the typed aggregator through the Dataset DSL: turn it into a TypedColumn
// and select it on a Dataset[User].
// The column is named "av" to match the SQL examples above.
// TypedColumn, Dataset and the User encoder are fully qualified / explicit because
// this snippet's imports do not bring them (or spark.implicits) into scope.
val column: org.apache.spark.sql.TypedColumn[User, Long] = new MyUDAF2().toColumn.name("av")
val ds: org.apache.spark.sql.Dataset[User] = df.as[User](org.apache.spark.sql.Encoders.product[User])
ds.select(column).show()
|