UDF(User-Defined-Function)
UDF是用于处理一行数据的,接受一行输入产生一个输出,类似与map()算子,
UDAF(User- Defined Aggregation Funcation)
UDAF用于接收一组输入数据然后产生一个输出结果。
UDAF需要使用继承UserDefinedAggregateFunction的自定义类来实现功能,UserDefinedAggregateFunction中提供了8个抽象方法来帮助我们实现UDAF的构建。
public StructType inputSchema()
:
用于指定UDAF所输入数据的schmema的,也就是需要在这个方法类定义UDAF输入数据的字段的名称合字段的类型。
StructType bufferSchema()
:
因为UDAF是将数据进行聚合的,因此会使用到中间的临时变量进行数据存储,这个方法是用于定义这些中间的临时变量的Schema的。
DataType dataType()
:
这个方法是用于定义UDAF的返回结果的数据结构的。
boolean deterministic()
:
这个方法用于返回聚合函数是否是幂等的,即相同输入是否总是能得到相同输出。
为什么会有这个方法呢?这源于spark的推测执行(spark.speculation=true推测执行开启):推测执行是指对于Spark程序里面少部分运行慢的Task,会在其他节点的Executor上再次启动这个task,如果其中一个Task实例运行成功则将这个最先完成的Task的计算结果作为最终结果,同时会干掉其他Executor上运行的实例,从而加快运行速度。但是推测执行只有在函数是幂等的情况下才会这样运作,如果不是幂等的函数只会一直等待该Task执行。
void initialize(MutableAggregationBuffer buffer)
:
该方法用于初始化缓冲区的字段。
void update(MutableAggregationBuffer buffer, Row row)
:
该方法用于处理相同的executor间的数据合并,当有新的输入数据时,update用户更新缓存变量。
"void merge(MutableAggregationBuffer buffer, Row row)":
该方法用于不同excutor间已经进行初步聚合的数据进行合并。
"Object evaluate(Row row)":
通过前面的缓冲区完成聚合后,在这个方法里对聚合的字段进行最终的运算。
实例:
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
public class MyUDAF extends UserDefinedAggregateFunction {
private StructType inputSchema;
private StructType bufferSchema;
public MyUDAF() {
List<StructField> inputFields = new ArrayList<>();
inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.DoubleType, true));
inputSchema = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
bufferSchema = DataTypes.createStructType(bufferFields);
}
//1、该聚合函数的输入参数的数据类型
public StructType inputSchema() {
return inputSchema;
}
//2、聚合缓冲区中的数据类型.(有序性)
public StructType bufferSchema() {
return bufferSchema;
}
//3、返回值的数据类型
public DataType dataType() {
return DataTypes.DoubleType;
}
//4、这个函数是否总是在相同的输入上返回相同的输出,一般为true
public boolean deterministic() {
return true;
}
//5、初始化给定的聚合缓冲区,在索引值为0的sum=0;索引值为1的count=1;
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, 0D);
buffer.update(1, 0D);
}
//6、更新
public void update(MutableAggregationBuffer buffer, Row input) {
//如果input的索引值为0的值不为0
if (!input.isNullAt(0)) {
double updateSum = buffer.getDouble(0) + input.getDouble(0);
double updateCount = buffer.getDouble(1) + 1;
buffer.update(0, updateSum);
buffer.update(1, updateCount);
}
}
//7、合并两个聚合缓冲区,并将更新后的缓冲区值存储回“buffer1”
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
double mergeSum = buffer1.getDouble(0) + buffer2.getDouble(0);
double mergeCount = buffer1.getDouble(1) + buffer2.getDouble(1);
buffer1.update(0, mergeSum);
buffer1.update(1, mergeCount);
}
//8、计算出最终结果
public Double evaluate(Row buffer) {
return buffer.getDouble(0) / buffer.getDouble(1);
}
}
main函数:
import org.apache.spark.SparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import java.math.BigDecimal;
public class UDAFJAVA {
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("RunMyUDAF")
.master("local")
.getOrCreate();
SparkContext sc = spark.sparkContext();
sc.setLogLevel("ERROR");
// Register the function to access it
spark.udf().register("myAverage", new MyUDAF());
Dataset<Row> df = spark.read().json("D:\\02Code\\0901\\sd_demo\\src\\data\\udaf.json");
df.createOrReplaceTempView("employees");
df.show();
//保留两位小数,四舍五入
spark.udf().register("twoDecimal", new UDF1<Double, Double>() {
@Override
public Double call(Double in) throws Exception {
BigDecimal b = new BigDecimal(in);
double res = b.setScale(2, BigDecimal.ROUND_HALF_DOWN).doubleValue();
return res;
}
}, DataTypes.DoubleType);
Dataset<Row> result = spark
.sql("SELECT name,twoDecimal(myAverage(salary)) as avg_salary FROM employees group by name");
result.show();
spark.stop();
}
}
udaf.json:
{"name":"Michael","salary":0}
{"name":"Andy","salary":4537}
{"name":"Justin","salary":3500.0}
{"name":"Berta","salary":0}
{"name":"Michael","salary":3000.0}
{"name":"Andy","salary":4500.0}
{"name":"Justin","salary":3500.0}
{"name":"Berta","salary":4000.0}
{"name":"Andy","salary":4500.0}
标签:JAVA,buffer,apache,UDF,org,spark,import,Spark,public
From: https://www.cnblogs.com/liuyechang/p/17002996.html