def main(args: Array[String]): Unit = {
val impala_db = args(2)   // target Impala database
val impala_tab = args(3)  // table to query
val query_where = args(4) // where clause; pass an empty string to scan the whole table
val LOGGER = LoggerFactory.getLogger(RimpalaDemo2.getClass) // set up the logger
// JDBC URL, logging in with Kerberos authentication
val impalaUrl = s"jdbc:impala://host:port/${impala_db};AuthMech=1;KrbRealm=EXAMPLE.COM;KrbHostFQDN=host;KrbServiceName=impala"
// Impala JDBC driver class
val impala_driver = "com.cloudera.impala.jdbc41.Driver"
println("Connecting to Impala in a Kerberos environment via JDBC")
val spark = SparkUnit.getLocal("test_impala", true)
// log in with the Kerberos account
// path to the krb5.conf (KDC) file
System.setProperty("java.security.krb5.conf", args(0))
// enable Kerberos debug output
System.setProperty("sun.security.krb5.debug", "true")
System.setProperty("javax.security.auth.useSubjectCredsOnly", "false")
// Hadoop configuration
val configuration = new Configuration()
// enable Kerberos security authentication
configuration.set("hadoop.security.authentication", "Kerberos")
UserGroupInformation.setConfiguration(configuration)
// log in with the Kerberos principal and its keytab file
UserGroupInformation.loginUserFromKeytab("USER@EXAMPLE.COM", args(1))
// print the principal to verify the login
println(UserGroupInformation.getCurrentUser() + "------" + UserGroupInformation.getLoginUser())
// Spark 2.x uses getCurrentUser; 1.x uses getLoginUser
val getCurrentUser = UserGroupInformation.getCurrentUser()
// fetch the result set
val resultSet = getResultSet(getCurrentUser, impala_driver, impalaUrl, impala_tab, query_where)
// convert the result set to a DataFrame and print it; show() displays 20 rows by default
val df = createResultSetToDF(resultSet, spark)
df.show()
// spark.sql("use g_scbqy_simba")
// // write the data in df into a Hive table
// df.write.format("Hive").mode(SaveMode.Overwrite).insertInto("tab")
// spark.sql("select * from tab").show()
}
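// A hedged invocation sketch (the jar name and argument values are hypothetical; the
// argument order follows main above: krb5.conf path, keytab path, database, table, where clause):
//   spark-submit --class RimpalaDemo2 impala-demo.jar \
//     /etc/krb5.conf /home/user/user.keytab my_db my_tab "dt = '2021-01-01'"
// Pass an empty string as the last argument to scan the whole table.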
/**
* Open a JDBC connection to Impala.
* @author xxx xxx@startdt.com
* @param impalaJdbcDriver Impala JDBC driver class name
* @param impalaJdbcUrl Impala JDBC URL
* @param currUser current Kerberos-authenticated user
* @return the Connection, or null when no driver class is supplied
*/
def getImpalaConnection(impalaJdbcDriver: String, impalaJdbcUrl: String, currUser: UserGroupInformation): Connection = {
if (impalaJdbcDriver.isEmpty) return null // no driver class supplied
try {
Class.forName(impalaJdbcDriver) // loading the class registers the driver with DriverManager
currUser.doAs(
new PrivilegedAction[Connection] {
// run() opens the connection as the authenticated user
override def run(): Connection = DriverManager.getConnection(impalaJdbcUrl)
}
)
} catch {
case e: Exception =>
e.printStackTrace()
throw e
}
}
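// Usage sketch: after loginUserFromKeytab has run in main, the logged-in UGI is passed in:
//   val conn = getImpalaConnection("com.cloudera.impala.jdbc41.Driver", impalaUrl,
//     UserGroupInformation.getCurrentUser())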
/**
* Run the query and return its result set.
* @author xxx xxx@startdt.com
* @param currUser current Kerberos-authenticated user
* @param impala_driver Impala JDBC driver class name
* @param impalaUrl Impala JDBC URL
* @param tab table to query
* @param exp where clause; an empty string scans the whole table
* @return the query ResultSet
*/
def getResultSet(currUser: UserGroupInformation, impala_driver: String, impalaUrl: String, tab: String, exp: String): ResultSet = {
// obtain the Impala connection
val connection = getImpalaConnection(impala_driver, impalaUrl, currUser)
// run a filtered query when a where clause was supplied; an empty string scans the whole table
if (exp.trim.nonEmpty) {
connection.createStatement().executeQuery(s"select * from ${tab} where ${exp}")
} else {
connection.createStatement().executeQuery(s"select * from ${tab}")
}
}
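// Example ("orders" is a hypothetical table name for illustration):
//   tab = "orders" with exp = "dt = '2021-01-01'" issues
//     select * from orders where dt = '2021-01-01'
//   while an empty exp issues a full scan: select * from orders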
/**
* Create a StructField from a column's Java class name.
* @author xxx xxx@startdt.com
* @param name column name
* @param colType fully qualified Java class name of the column type
* @return the corresponding nullable StructField
*/
def createStructField(name: String, colType: String): StructField = {
colType match {
case "java.lang.String" => StructField(name, StringType, nullable = true)
case "java.lang.Integer" => StructField(name, IntegerType, nullable = true)
case "java.lang.Long" => StructField(name, LongType, nullable = true)
case "java.lang.Boolean" => StructField(name, BooleanType, nullable = true)
case "java.lang.Double" => StructField(name, DoubleType, nullable = true)
case "java.lang.Float" => StructField(name, FloatType, nullable = true)
case "java.sql.Date" => StructField(name, DateType, nullable = true)
case "java.sql.Time" => StructField(name, TimestampType, nullable = true)
case "java.sql.Timestamp" => StructField(name, TimestampType, nullable = true)
case "java.math.BigDecimal" => StructField(name, DecimalType(10, 0), nullable = true)
// fall back to string for any JDBC type not mapped above, instead of throwing a MatchError
case _ => StructField(name, StringType, nullable = true)
}
}
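// Example ("price" is a hypothetical column name): createStructField("price", "java.math.BigDecimal")
// yields StructField("price", DecimalType(10,0), true). The fixed (10,0) precision keeps no
// fractional digits, so widen it if your decimal columns carry a scale.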
/**
* Convert a JDBC ResultSet into a Spark DataFrame.
* @author xxx xxx@startdt.com
* @param rs the ResultSet to convert
* @param sparkSession active SparkSession
* @return DataFrame holding all rows of the ResultSet
*/
def createResultSetToDF(rs: ResultSet, sparkSession: SparkSession): DataFrame = {
val rsmd = rs.getMetaData
val columnTypeList = new util.ArrayList[String]
val rowSchemaList = new util.ArrayList[StructField]
for (i <- 1 to rsmd.getColumnCount) {
// strip the package prefix, e.g. "java.lang.Integer" -> "Integer"
var temp = rsmd.getColumnClassName(i)
temp = temp.substring(temp.lastIndexOf(".") + 1)
if ("Integer".equals(temp)) {
temp = "Int" // ResultSet's accessor is getInt, not getInteger
}
columnTypeList.add(temp)
rowSchemaList.add(createStructField(rsmd.getColumnName(i), rsmd.getColumnClassName(i)))
}
val rowSchema = StructType(rowSchemaList)
// reflect on the ResultSet class so typed getXxx accessors can be looked up by name
val rsClass = rs.getClass
var count = 0
val resultList = new util.ArrayList[Row]
var totalDF = sparkSession.createDataFrame(new util.ArrayList[Row], rowSchema)
while (rs.next()) {
count = count + 1
val buffer = new ArrayBuffer[Any]()
for (i <- 0 until columnTypeList.size()) {
// look up the typed accessor by reflection, e.g. getString(String columnLabel)
val method = rsClass.getMethod("get" + columnTypeList.get(i), classOf[String])
buffer += method.invoke(rs, rsmd.getColumnName(i + 1))
}
resultList.add(Row(buffer: _*))
// flush every 100000 rows to keep the local buffer small; note that
// distinct() also removes legitimate duplicate rows across batches
if (count % 100000 == 0) {
val tempDF = sparkSession.createDataFrame(resultList, rowSchema)
totalDF = totalDF.union(tempDF).distinct()
resultList.clear()
}
}
// append whatever rows remain in the buffer
val tempDF = sparkSession.createDataFrame(resultList, rowSchema)
totalDF = totalDF.union(tempDF)
totalDF
}
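// Note: for plain table reads, Spark's built-in JDBC source can replace this manual
// conversion (a minimal sketch, assuming the same Kerberos login has already run and
// the Impala driver jar is on the classpath; impalaUrl and tab stand in for the values
// used in main above):
//   val df = sparkSession.read.format("jdbc")
//     .option("driver", "com.cloudera.impala.jdbc41.Driver")
//     .option("url", impalaUrl)
//     .option("dbtable", tab)
//     .load()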