In Spark SQL, saving data to a database supports only four modes: Append, Overwrite, ErrorIfExists, and Ignore. None of them fits our project requirement. Briefly, the requirement is this: when data changes in the business database, the corresponding rows in the ODS layer of the warehouse must be updated, inserted, or deleted, so the save path has to be reworked.
Below, modeled on the spark.write.save() source code, is a reworked version that saves data in batches: rows that already exist are updated, rows that do not are inserted.
import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.getCommonJDBCType
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.Properties
object TestInsertOrUpdateMysql {
val url: String = "jdbc:mysql://192.168.1.1:3306/test?useUnicode=true&characterEncoding=UTF-8&useSSL=false&allowMultiQueries=true&autoReconnect=true&failOverReadOnly=false"
val driver: String = "com.mysql.jdbc.Driver"
val user: String = "123"
val password: String = "123"
val sql: String = "select * from testserver "
val table: String = "testinsertorupdate"
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.master("local[*]")
.appName("testSqlServer")
//Kryo must be configured before the SparkContext starts; see the note at the end
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.getOrCreate()
val dbtable = "(" + sql + ") AS Temp"
val jdbcDF = spark.read
.format("jdbc")
.option("driver", driver)
.option("user", user)
.option("password", password)
.option("url", url)
.option("dbtable", dbtable)
.load()
jdbcDF.show()
//plain write to the database
//commonWrite(jdbcDF)
//insert or update
insertOrUpdateToMysql("id", jdbcDF, spark)
println("====================== job finished ======================")
}
1. First, look at how a plain insert is written:
def commonWrite(jdbcDF: DataFrame): Unit = {
val properties = new Properties()
properties.put("user", user)
properties.put("password", password)
properties.put("driver", driver)
jdbcDF.write.mode(SaveMode.Append).jdbc(url, table, properties)
}
This approach is limited: it can only do simple writes (append, overwrite, and the other SaveMode values such as SaveMode.Append).
So what does the new version look like? Start with MySQL's insert-or-update syntax:
INSERT INTO t_name ( c1, c2, c3 )
VALUES
( 1, '1', '1')
ON DUPLICATE KEY UPDATE
c2 = '2';
Note that the target table must have a primary key or unique index; ON DUPLICATE KEY UPDATE has nothing to match duplicates against otherwise, so no update can ever happen.
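To make the template concrete, here is a minimal sketch (the column names are hypothetical) of the statement that the method below assembles from a DataFrame schema. Every column appears once in the VALUES list and once more in the UPDATE list, so each row binds twice as many parameters as it has columns:
// Hypothetical column names; the real method derives these from jdbcDF.schema.
val cols = Seq("id", "c1", "c2")
val insertPart = s"INSERT INTO testinsertorupdate (${cols.mkString(",")}) " +
  s"VALUES (${cols.map(_ => "?").mkString(",")})"
val updatePart = cols.map(c => s"$c=?").mkString(",")
println(s"$insertPart on duplicate key update $updatePart")
// prints: INSERT INTO testinsertorupdate (id,c1,c2) VALUES (?,?,?)
//         on duplicate key update id=?,c1=?,c2=?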
2. Now look at the insert-or-update version:
//Write to the database, batch inserting or updating rows; modeled on the spark.write.save() source
//Rule:
//a row whose key does not yet exist in the table is inserted; a row whose key exists is updated
def insertOrUpdateToMysql(primaryKey: String, jdbcDF: DataFrame, spark: SparkSession): Unit = {
val sc: SparkContext = spark.sparkContext
//1. load the JDBC driver
Class.forName(driver)
//2. open the database connection (on the driver)
val conn: Connection = DriverManager.getConnection(url, user, password)
val tableSchema = jdbcDF.schema
val columns = tableSchema.fields.map(_.name).mkString(",")
val placeholders = tableSchema.fields.map(_ => "?").mkString(",")
val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders) on duplicate key update "
val update = tableSchema.fields.map(x => x.name + "=?").mkString(",")
//the PostgreSQL equivalent of this clause would be: on conflict($primaryKey) do update set
val realsql = sql.concat(update)
conn.setAutoCommit(false)
val dialect = JdbcDialects.get(conn.getMetaData.getURL)
//broadcast the PreparedStatement so every partition can reuse it
//(this only behaves in local mode; see the note and the cluster-safe sketch below)
val broad_ps = sc.broadcast(conn.prepareStatement(realsql))
//each row binds every column twice: once for the INSERT values, once for the UPDATE clause
val numFields = tableSchema.fields.length * 2
//use Spark's own helpers to resolve each field's JDBC type and setter
val nullTypes = tableSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
val setters = tableSchema.fields.map(f => makeSetter(conn, f.dataType))
val batchSize = 2000
val updateindex = numFields / 2
try {
jdbcDF.foreachPartition(iterator => {
//iterate over the partition, flushing a batch every batchSize rows
val ps = broad_ps.value
//batch counter, declared inside the partition: a driver-side var would never be updated on executors
var rowCount = 0
try {
while (iterator.hasNext) {
val row = iterator.next()
var i = 0
while (i < numFields) {
if (i < updateindex) {
//first half of the parameters: the INSERT values
if (row.isNullAt(i)) {
ps.setNull(i + 1, nullTypes(i))
} else {
setters(i).apply(ps, row, i, 0)
}
} else {
//second half: the same values again, for the UPDATE clause
if (row.isNullAt(i - updateindex)) {
ps.setNull(i + 1, nullTypes(i - updateindex))
} else {
setters(i - updateindex).apply(ps, row, i, updateindex)
}
}
i = i + 1
}
ps.addBatch()
rowCount += 1
if (rowCount % batchSize == 0) {
ps.executeBatch()
rowCount = 0
}
}
if (rowCount > 0) {
ps.executeBatch()
}
} finally {
ps.close()
}
})
conn.commit()
} catch {
case e: Exception =>
println("Error in execution of insert. " + e.getMessage)
conn.rollback()
} finally {
conn.close()
}
}
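One caveat on the method above: broadcasting a PreparedStatement only behaves in local[*] mode, because the statement is tied to a connection that lives on the driver. For cluster deployments, a common alternative is to open a connection inside each partition. The following is a sketch, not part of the original code: it assumes the same imports and url/user/password vals are in scope, takes the same realsql template, and uses setObject instead of the typed setters purely for brevity.
def upsertPerPartition(jdbcDF: DataFrame, realsql: String): Unit = {
  val numFields = jdbcDF.schema.fields.length
  jdbcDF.foreachPartition { (iterator: Iterator[Row]) =>
    //each task owns its connection, so nothing JDBC-related crosses the driver/executor boundary
    val conn = DriverManager.getConnection(url, user, password)
    conn.setAutoCommit(false)
    val ps = conn.prepareStatement(realsql)
    try {
      var rowCount = 0
      iterator.foreach { row =>
        var i = 0
        while (i < numFields) {
          //setObject(_, null) is accepted by MySQL Connector/J for null values
          ps.setObject(i + 1, row.get(i))             //INSERT half
          ps.setObject(numFields + i + 1, row.get(i)) //UPDATE half
          i += 1
        }
        ps.addBatch()
        rowCount += 1
        if (rowCount % 2000 == 0) ps.executeBatch()
      }
      ps.executeBatch()
      conn.commit()
    } finally {
      ps.close()
      conn.close()
    }
  }
}
This also spreads the write across executors instead of funneling every row through the driver's single connection.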
Helper methods adapted from the Spark source:
private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}"))
}
private type JDBCValueSetter_add = (PreparedStatement, Row, Int, Int) => Unit
private def makeSetter(conn: Connection, dataType: DataType): JDBCValueSetter_add = dataType match {
case IntegerType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setInt(pos + 1, row.getInt(pos - currentpos))
case LongType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setLong(pos + 1, row.getLong(pos - currentpos))
case DoubleType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setDouble(pos + 1, row.getDouble(pos - currentpos))
case FloatType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setFloat(pos + 1, row.getFloat(pos - currentpos))
case ShortType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setInt(pos + 1, row.getShort(pos - currentpos))
case ByteType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setInt(pos + 1, row.getByte(pos - currentpos))
case BooleanType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setBoolean(pos + 1, row.getBoolean(pos - currentpos))
case StringType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setString(pos + 1, row.getString(pos - currentpos))
case BinaryType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - currentpos))
case TimestampType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - currentpos))
case DateType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - currentpos))
case t: DecimalType =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
stmt.setBigDecimal(pos + 1, row.getDecimal(pos - currentpos))
case _ =>
(stmt: PreparedStatement, row: Row, pos: Int, currentpos: Int) =>
throw new IllegalArgumentException(
s"Can't translate non-null value for field $pos")
}
One setting here is critical. Without it the job dies with Exception in thread "main" java.io.NotSerializableException: com.mysql.jdbc.JDBC42PreparedStatement, because broadcasting the PreparedStatement fails under the default JavaSerializer:
spark.serializer = org.apache.spark.serializer.KryoSerializer
Switching to the third-party Kryo serializer works around this. Note that spark.serializer is read once when the SparkContext starts, so it must be supplied when the session is built; calling spark.conf.set afterwards has no effect.
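Concretely, supply it in the builder, as done in main above:
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("testSqlServer")
  //must be set here, before the SparkContext exists; spark.conf.set later is a no-op
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .getOrCreate()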
The result comes out as expected. There are a few pitfalls along the way; anyone who needs this is welcome to reach out and discuss.
Appendix: the equivalent insert-or-update syntax for PostgreSQL:
INSERT INTO test_001 ( c1, c2, c3 )
VALUES ( ?, ?, ? )
ON CONFLICT ( id ) DO UPDATE
SET c1 = ?, c2 = ?, c3 = ?;
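To point the Scala code above at PostgreSQL, only the statement template changes. A sketch, assuming the same tableSchema, table, and primaryKey values as in insertOrUpdateToMysql:
//unlike MySQL, PostgreSQL requires the conflict target to be named explicitly,
//which is exactly what the primaryKey parameter is for
val columns = tableSchema.fields.map(_.name).mkString(",")
val placeholders = tableSchema.fields.map(_ => "?").mkString(",")
val update = tableSchema.fields.map(f => s"${f.name}=?").mkString(",")
val realsql = s"INSERT INTO $table ($columns) VALUES ($placeholders) " +
  s"ON CONFLICT ($primaryKey) DO UPDATE SET $update"
The parameter-binding loop stays the same, since the column list and its order are identical.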
My QQ is 907044657; everyone is welcome to get in touch to discuss and learn together. If you spot a problem, please point it out, and please credit the source when reposting. Thanks!