用 Scala 版 Spark 构建 LR(逻辑回归)模型:
一、问题背景
需要构建一个 LR 模型来预测物品的 CTR。
二、解决方案
由于训练数据量较大,我们首先考虑用 Spark 来构建、训练和测试模型,这样效率更高。模型接口的详细说明可参考 Spark 的 Scala API 文档:https://spark.apache.org/docs/latest/api/scala/org/apache/spark/index.html 。整体代码如下:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorAssembler}
import org.apache.spark.sql.SparkSession
import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.PMMLBuilder
import javax.xml.transform.stream.StreamResult
object CargoClinchLR {
  /** Trains a logistic-regression pipeline, prints train/test metrics and exports it as PMML. */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    // Training data location (placeholder path).
    val path = "hdfs://xxxxxxxx"
    // Fixed seed so the train/test split, and therefore the reported metrics, are reproducible.
    val splitSeed = 42L

    // Categorical feature columns; each is turned into a numeric index by a StringIndexer.
    val strCols = Array("start_city_id", "end_city_id", "start_prov_id", "end_prov_id"
      /* TODO: the remaining categorical columns ("……") were elided in the original */)
    // Numeric feature columns, passed to the assembler as-is.
    val numCols = Array("weight", "capacity", "distance")

    // na.fill(String) only replaces nulls in string columns; the replace additionally maps
    // empty strings in the categorical columns onto the same "unknown" bucket.
    // NOTE(review): numeric nulls are NOT filled here, and VectorAssembler's default
    // handleInvalid ("error") fails on them — confirm weight/capacity/distance have no nulls.
    val data = spark.read.parquet(path)
      .na.fill("unknown")
      .na.replace(strCols, Map("" -> "unknown"))
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = splitSeed)

    // One indexer per categorical column, writing "<col>_idx". "skip" drops rows whose
    // category was not seen during fitting, so such rows are excluded from the test metric.
    val indexers = strCols.map { c =>
      new StringIndexer()
        .setInputCol(c)
        .setOutputCol(s"${c}_idx")
        .setHandleInvalid("skip")
        .setStringOrderType("frequencyAsc")
    }
    // Derive the index column names from the indexers so the two cannot drift apart.
    val idxCols = indexers.map(_.getOutputCol)
    // NOTE(review): the *_idx values are category ids, which a linear model treats as
    // ordinal magnitudes; consider adding a OneHotEncoder stage before the assembler.
    val assembler = new VectorAssembler().setInputCols(idxCols ++ numCols).setOutputCol("fea")
    val lr = new LogisticRegression().setFeaturesCol("fea").setLabelCol("label")
    val pipelineModel = new Pipeline().setStages(indexers ++ Array(assembler, lr)).fit(train)

    // The last pipeline stage is the fitted form of `lr`.
    val lrModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel]
    println(lrModel.coefficients)

    // Metrics on the training data, taken from the model's training summary.
    val trainSummary = lrModel.binarySummary
    println(s"train_auc =${trainSummary.areaUnderROC}")
    println(s"train_acc =${trainSummary.accuracy}")
    println(s"train_weighted_precision =${trainSummary.weightedPrecision}")
    println(s"train_weighted_recall =${trainSummary.weightedRecall}")

    // AUC on the held-out split; the test set is scored exactly once.
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setMetricName("areaUnderROC")
    val testAuc = evaluator.evaluate(pipelineModel.transform(test))
    println(s"eval_auc =${testAuc}")

    // Export the whole fitted pipeline (indexers + assembler + LR) as PMML.
    // NOTE(review): StreamResult(String) takes a system id, so this presumably writes a file
    // named "model" relative to the driver's working directory — confirm the intended path.
    val pmml = new PMMLBuilder(data.schema, pipelineModel).build()
    JAXBUtil.marshalPMML(pmml, new StreamResult("model"))
    // NOTE(review): the closing braces of `main` and of the object are not part of this
    // excerpt; make sure they follow.
|