PySpark: joining a large table with a small table using a broadcast variable, and updating the broadcast variable
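The script below expands outwards from a set of seed nodes: each round it collects the current frontier on the driver, "updates" a broadcast variable with it (unpersist the old copy, then broadcast the new set), and filters the large edge table against it. The update pattern on its own, as a minimal runnable sketch (the values here are placeholders, not the real node set):

from pyspark import SparkConf, SparkContext

sc = SparkContext.getOrCreate(SparkConf().setMaster("local[2]"))
bc = sc.broadcast(set())                  # initial (empty) broadcast
for round_no in range(3):
    new_value = {round_no, round_no + 1}  # stand-in for the freshly collected node set
    bc.unpersist(False)                   # drop the stale copies cached on the executors
    bc = sc.broadcast(new_value)          # re-broadcast the new value under a new handle
    print(round_no, bc.value)
sc.stop()

The full script: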
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext
# from pyspark.sql.functions import broadcast  # broadcast-join hint; see the note at the end

# Starting from a set of seed nodes, expand along the edges and collect the nodes reached at each level.
# df_edges_all: DataFrame of all edges [Source, Target, ...]; loop_nodes: DataFrame of seed nodes [nodes];
# loop_num: number of expansion iterations; sqlContext: a SQLContext instance.
def extend_fun(sc, sqlContext, df_edges_all, loop_nodes, loop_num):
    # Register the DataFrames as temp views for SQL access
    df_edges_all.createOrReplaceTempView("nodes_edges_all")
    loop_nodes.createOrReplaceTempView("loop_nodes")
    # Broadcast variable holding the current frontier of nodes;
    # it is re-broadcast (updated) on every iteration
    bc_nodes = sc.broadcast(set())
    # Expand level by level
    for i in range(loop_num):
        print(i)
        # Collect the current node set on the driver
        bd = set()
        for _, row in loop_nodes.toPandas().iterrows():
            bd.add(row["nodes"])
        # Update the broadcast variable: remove the old copy from the executors,
        # then broadcast the new node set
        bc_nodes.unpersist(False)
        bc_nodes = sc.broadcast(bd)
        # Keep only the edges that touch the current node set
        df_edges_all_filter = df_edges_all.filter(
            df_edges_all["Source"].isin(bc_nodes.value) | df_edges_all["Target"].isin(bc_nodes.value))
        df_edges_all_filter.createOrReplaceTempView("loop_edges")
        if i < loop_num - 1:
            # The next frontier: every node that appears at either end of the edges matched in this round
            loop_nodes = sqlContext.sql("""
                select distinct nodes from
                  (select Source nodes from loop_edges where Source is not null and Source <> ''
                   union
                   select Target nodes from loop_edges where Target is not null and Target <> '') t
            """)
        sqlContext.sql("select count(1) from loop_edges").show(10)
    return df_edges_all_filter

if __name__ == '__main__':
    conf = SparkConf().setMaster('local[2]').set("spark.executor.memory", "3g")
    sc = SparkContext.getOrCreate(conf)
    sqlContext = SQLContext(sc)
    # Load all edges
    df_edges_all = sqlContext.read.option("header", "true").option("inferSchema", "true").csv("E:/AAA-2/过程/(边)1.csv")
    # Load the initial seed nodes
    df_nodes = sqlContext.read.option("header", "true").option("inferSchema", "true").csv("E:/AAA-2/aaa.csv")
    # Expand 5 levels out from the seed nodes
    df_edges_all_filter = extend_fun(sc, sqlContext, df_edges_all, df_nodes, 5)
    df_edges_all_filter.show(10)
    sc.stop()
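The import commented out at the top points at the other common way to handle a big-table/small-table join: let Spark broadcast the small DataFrame itself with the join hint from pyspark.sql.functions, instead of managing a broadcast variable by hand. A rough sketch using the same files and column names as above (it only matches on Source, while the loop above matches Source or Target):

from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import broadcast

sc = SparkContext.getOrCreate(SparkConf().setMaster("local[2]"))
sqlContext = SQLContext(sc)
df_edges_all = sqlContext.read.option("header", "true").csv("E:/AAA-2/过程/(边)1.csv")
df_nodes = sqlContext.read.option("header", "true").csv("E:/AAA-2/aaa.csv")
# Broadcast the small node table so the join does not shuffle the large edge table
joined = df_edges_all.join(broadcast(df_nodes),
                           df_edges_all["Source"] == df_nodes["nodes"],
                           "inner")
joined.show(10)
sc.stop()

The hint only applies to joins; the explicit broadcast variable remains useful when the same small set has to be reused across several operations or inside RDD code.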