IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 嵌套交叉验证 -> 正文阅读

[人工智能]嵌套交叉验证

  • 20210911

0. 引言

平时进行机器学习实验,大多数情况下都是使用train-test直接划分的方法,这种方法一般来说,对于数据量比较的数据集,影响不是很大,但是对于数据集比较小的数据集来说,就有所偏颇。(我记得这是某个书上说的,深度学习的课程上也有所提及)。而对于数据量比较少的数据集,更多的是用K折交叉验证。当然,这种方法,本质上也是一样的。对于编码实现来说,基本上就是几行代码的事情。

而且,平时一般来说,还会在训练集中划分一个验证集,通过验证集的效果来进行具体的参数选择。

但是如果这些代码都自己来进行编程的话,就有点太伤脑筋了。所以,一般都是直接调用库函数来实现。

而且,有一个问题,在于还要进行一些归一化的内容,所以需要考虑。这些也都是能够进行自动化的。

本篇文章主要介绍了交叉验证的代码,同时还包含参数选择和Pipeline等内容,这样可以保证对于预处理或者其他的一些参数都能有优化选择。

1. 交叉验证

如果是进行普通的交叉验证的话,其实处理完数据之后,直接将这部分数据按照一行代码:

cross_val_score( model, X, y)

即可。但是前面也提到,还要进行相关的参数选择,所以要对代码进行一些调整。

在文章[1]中给出了简单的步骤,具体代码如下:

# automatic nested cross-validation for random forest on a classification dataset
from numpy import mean
from numpy import std
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
# create dataset
X, y = make_classification(n_samples=1000, n_features=20, random_state=1, n_informative=10, n_redundant=10)
# configure the cross-validation procedure
cv_inner = KFold(n_splits=3, shuffle=True, random_state=1)
# define the model
model = RandomForestClassifier(random_state=1)
# define search space
space = dict()
space['n_estimators'] = [10, 100, 500]
space['max_features'] = [2, 4, 6]
# define search
search = GridSearchCV(model, space, scoring='accuracy', n_jobs=1, cv=cv_inner, refit=True)
# configure the cross-validation procedure
cv_outer = KFold(n_splits=10, shuffle=True, random_state=1)
# execute the nested cross-validation
scores = cross_val_score(search, X, y, scoring='accuracy', cv=cv_outer, n_jobs=-1)
# report performance
print('Accuracy: %.3f (%.3f)' % (mean(scores), std(scores)))

文章中,这段代码之前是通过自己来编程实现相同的功能;利用KFlod实现外层的数据集循环。具体可以看文章内容。上述代码中,比较容易发生的歧义的地方就是GridSearchCV;但是这个部分是可以看做一个单独的模型。而且从官方的文档来看,其中refit参数的默认值是True,使用效果最好的参数,并且重新在整个数据集上进行训练。这样就能理解了。

2. Pipeline

先上代码:


#X,y是整个数据集

parameters = {
       "lda__n_components" : list(range(1, 35)),
        "rf__min_samples_leaf": [1, 2, 4],
        "rf__min_samples_split":[2, 5, 10],
        "rf__max_depth": [int(x) for x in np.linspace(10, 110, num = 10)]
}

steps = [
            ("min", StandardScaler()), 
            ('lda', LinearDiscriminantAnalysis()),
            ('rf', RandomForestClassifier(n_jobs = -1)),
]
    
model = Pipeline(steps = steps)
    
inner_cv = StratifiedKFold(n_splits=10, shuffle= True)
grid_model = GridSearchCV(
        model, 
        parameters,
        scoring='accuracy', n_jobs=-1, cv = inner_cv,
)
outer_cv = StratifiedKFold(n_splits=10, shuffle= True)#, random_state=1)

n_scores = cross_val_score(
        grid_model, X, y, 
        scoring = 'accuracy', cv = outer_cv, 
        n_jobs=-1, error_score='raise'
)

print('Accuracy: %.3f (%.3f)' % (mean(n_scores), std(n_scores)))

虽然前面的代码能够实现交叉验证,但是一般来说,对于数据集还要进行一些预处理,这个问题就需要Pipeline来解决;但是同时还要进行参数的选择,因为在弄Pipeline的时候,已经进行了命名。所以参数在选择的时候,参数字典的键值是在Pipeline的名字作为前缀。在文章[2]中,代码差不多。
文章[3]中实现了相同功能的代码。

在文章[4]中,简单介绍了嵌套交叉验证的具体内容,而且说明了,一般使用嵌套交叉验证之后,效果很比没使用嵌套交叉验证差一点,但是这种效果也是更符合规则的。

参考

[1]Nested Cross-Validation for Machine Learning with Python
[2]Putting together sklearn pipeline+nested cross-validation for KNN regression
[3]Python – Nested Cross Validation for Algorithm Selection
[4]Nested cross-validation

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-12 13:09:55  更:2021-09-12 13:11:51 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/21 18:26:44-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码