I. Computing only the AUC value
Computing the AUC on its own is straightforward: use pyspark.ml.evaluation directly:
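from pyspark.ml.evaluation import BinaryClassificationEvaluator
evaluator = BinaryClassificationEvaluator(rawPredictionCol="probability", labelCol="label")
# metricName must be "areaUnderROC" or "areaUnderPR"; the bare string "PR" is not accepted
areaUnderROC = evaluator.evaluate(dataSetWithProba, {evaluator.metricName: "areaUnderROC"})
areaUnderPR = evaluator.evaluate(dataSetWithProba, {evaluator.metricName: "areaUnderPR"})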
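To make concrete what the area under the ROC curve measures, here is a minimal pure-Python sketch: the AUC equals the fraction of (positive, negative) pairs in which the positive example gets the higher score. The function name auc_by_ranking and the toy data are illustrative assumptions, not part of Spark. Spark's evaluator derives the value from the curve itself and may bin thresholds, so its result can differ slightly.

```python
def auc_by_ranking(scores, labels):
    """AUC-ROC as the fraction of (positive, negative) pairs ranked correctly.

    A tie between a positive and a negative score counts as half a correct pair.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    correct = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(pos) * len(neg))


# Toy data (hypothetical): predicted positive-class probabilities and true labels
scores = [0.9, 0.8, 0.35, 0.6, 0.2]
labels = [1, 0, 1, 1, 0]
print(auc_by_ranking(scores, labels))
```

The double loop is quadratic in the number of examples, so this is only meant for building intuition on small samples.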
II. Getting the ROC curve
1. If you only need the ROC curve points for the training set, use the model's summary directly
trainingSummary = LrModel.summary
print('ROC curve points:')
trainingSummary.roc.show()  # DataFrame with columns FPR and TPR
print("AUC:", trainingSummary.areaUnderROC)  # output here: 0.646777751909773
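2. If you need the ROC curve for the model's predictions on another dataset
Method 1: using sklearn
Note: the Spark DataFrame must first be converted into plain Python lists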
from sklearn import metrics

def get_prob_target(dataSetWithProba):
    # Collect (positive-class probability, label) pairs to the driver
    return dataSetWithProba.select('label', 'probability').rdd.map(
        lambda row: (float(row['probability'][1]), float(row['label']))).collect()

def get_roc_curve(dataSetWithProba):
    prob_target = get_prob_target(dataSetWithProba)
    y_label = [i[1] for i in prob_target]
    y_pred = [i[0] for i in prob_target]
    return y_label, y_pred

y_label, y_pred = get_roc_curve(dataSetWithProba)
fpr, tpr, thresholds = metrics.roc_curve(y_label, y_pred)
auc = metrics.auc(fpr, tpr)
print('ROC curve points:')
print(list(zip(fpr, tpr)))
print("AUC: ", auc)
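For readers who want to see where the curve points come from, the following is a rough pure-Python sketch of the threshold sweep that an ROC routine performs conceptually. The name roc_points and the toy data are assumptions made for illustration; this is not sklearn's implementation, which additionally drops some intermediate points by default.

```python
def roc_points(scores, labels):
    """Return (FPR, TPR) points by lowering a threshold over the distinct scores."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need at least one positive and one negative example")

    # Highest score first, so each step admits more examples as "predicted positive"
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Consume every example tied at this score before emitting a point
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


# Toy data (hypothetical): positive-class probabilities and true labels
scores = [0.9, 0.8, 0.35, 0.6, 0.2]
labels = [1, 0, 1, 1, 0]
points = roc_points(scores, labels)
print(points)
```

Each distinct score contributes one point, so the curve always starts at (0, 0) and ends at (1, 1).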
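Method 2: using pyspark.mllib.evaluation
Scala's BinaryClassificationMetrics provides methods for extracting the ROC curve, but the PySpark wrapper does not expose them, so we have to reach through to the underlying Scala object first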
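from pyspark.mllib.evaluation import BinaryClassificationMetrics

class CurveMetrics(BinaryClassificationMetrics):
    def __init__(self, *args):
        super(CurveMetrics, self).__init__(*args)

    def _to_list(self, rdd):
        # Each row is a Scala Tuple2, so read its fields with _1() and _2()
        points = []
        for row in rdd.collect():
            points += [(float(row._1()), float(row._2()))]
        return points

    def get_curve(self, method):
        # Call the named curve method (e.g. 'roc') on the wrapped JVM object
        rdd = getattr(self._java_model, method)().toJavaRDD()
        return self._to_list(rdd)

# sdf is the scored DataFrame; label holds the name of its label column.
# The score column should be continuous (e.g. the positive-class probability);
# hard 0/1 predictions produce only a degenerate curve.
preds = sdf.select(label, 'prediction').rdd.map(lambda row: (float(row['prediction']), float(row[label])))
points = CurveMetrics(preds).get_curve('roc')
fpr = [x[0] for x in points]
tpr = [x[1] for x in points]
auc_roc = BinaryClassificationMetrics(preds).areaUnderROC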
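import matplotlib.pyplot as plt

# title, xlabel and ylabel are plot strings supplied by the caller
plt.figure()
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.plot(fpr, tpr)
# result is a dict defined by the caller; 'x' stores specificity, i.e. 1 - FPR
result['figure'] = {'title': 'ROC curve',
                    'AUROC': auc_roc,
                    'x': [1 - f for f in fpr],
                    'y': tpr,
                    'xlabel': 'Specificity',
                    'ylabel': 'Sensitivity'}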
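Comparison: on the same test data, both methods give the same AUC, but pyspark.mllib.evaluation returns more curve points than sklearn.metrics. One likely reason is that sklearn's roc_curve defaults to drop_intermediate=True, which discards thresholds that do not change the shape of the plotted curve.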
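A small sketch may help show why fewer curve points can still give an identical AUC: points that lie on a straight segment between their neighbours add nothing to the trapezoidal area. The helpers trapezoid_auc and drop_collinear and the sample points are hypothetical, written only for illustration; they are similar in spirit to, but not the same as, what sklearn does internally.

```python
def trapezoid_auc(points):
    """Area under a piecewise-linear curve given as (x, y) points sorted by x."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


def drop_collinear(points, tol=1e-12):
    """Remove interior points that lie on the line through their kept neighbours."""
    if len(points) <= 2:
        return list(points)
    kept = [points[0]]
    for candidate, nxt in zip(points[1:], points[2:]):
        x0, y0 = kept[-1]
        x1, y1 = candidate
        x2, y2 = nxt
        # Cross product is zero when the three points are on one straight line
        cross = (x1 - x0) * (y2 - y0) - (x2 - x0) * (y1 - y0)
        if abs(cross) > tol:
            kept.append(candidate)
    kept.append(points[-1])
    return kept


# Hypothetical dense ROC curve with redundant points on straight segments
dense = [(0.0, 0.0), (0.0, 0.25), (0.0, 0.5), (0.5, 0.5), (1.0, 0.5), (1.0, 1.0)]
sparse = drop_collinear(dense)
print(len(dense), len(sparse))
print(trapezoid_auc(dense), trapezoid_auc(sparse))
```

Both curves plot identically and integrate to the same area; only the number of stored points differs.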
Reference: Extracting the ROC curve in PySpark