IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> Python:非线性SVM(KernelSVM)SMO算法 -> 正文阅读

[Python知识库]Python:非线性SVM(KernelSVM)SMO算法

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import rbf_kernel, polynomial_kernel


class KernelSVM():
    def __init__(self):
        self.C = None
        self.X = self.y = None
        self.gram = None
        self.tol = 1E-3
        self._alpha = self._w = self._prediction_cache = None

    def fit(self, X, y, C=1, kernel='rbf', gamma=0.1, epoch=200):
        self.X = np.asarray(X, np.float)
        self.y = np.asarray(y, np.float)
        self.nSample, self.nDim = self.X.shape
        self.C = C
        self.gamma = gamma

        if kernel == "rbf" or kernel == "gaussian":
            self.gram = rbf_kernel(X=X,Y=X,gamma=gamma)
        elif kernel == " ploy":
            self.gram = polynomial_kernel(X=X,Y=X,degree=1)

        self._alpha, self._w, self._prediction_cache = (np.zeros(self.nSample), np.zeros(self.nSample), np.zeros(self.nSample))
        self._b = 0.

        for i in range(epoch):
            idx1 = self._pick_alpha_1()
            if idx1 is None:
                return True
            idx2 = self._pick_alpha_2(idx1)
            self._update_alpha(idx1, idx2)

    def _pick_alpha_1(self):
        con1 = self._alpha > 0
        con2 = self._alpha < self.C
        err1 = self.y * self._prediction_cache - 1
        err2 = err1.copy()
        err3 = err1.copy()

        err1[(con1 & (err1 <= 0)) | (~con1 & (err1 > 0))] = 0
        err2[((~con1 | ~con2) & (err2 != 0)) | ((con1 & con2) & (err2 == 0))] = 0
        err3[(con2 & (err3 >= 0)) | (~con2 & (err3 < 0))] = 0
        err = err1 ** 2 + err2 ** 2 + err3 ** 2
        idx = np.argmax(err)
        if err[idx] < self.tol:
            return
        return idx

    def _pick_alpha_2(self, idx1):
        idx = np.random.randint(self.nSample)
        while idx == idx1:
            idx = np.random.randint(self.nSample)
        return idx

    def _get_lower_bound(self, idx1, idx2):
        if self.y[idx1] != self.y[idx2]:
            return max(0., self._alpha[idx2] - self._alpha[idx1])
        return max(0., self._alpha[idx2] + self._alpha[idx1] - self.C)

    def _get_upper_bound(self, idx1, idx2):
        if self.y[idx1] != self.y[idx2]:
            return min(self.C, self.C + self._alpha[idx2] - self._alpha[idx1])
        return min(self.C, self._alpha[idx2] + self._alpha[idx1])

    def _update_alpha(self, idx1, idx2):
        L, H = self._get_lower_bound(idx1, idx2), self._get_upper_bound(idx1, idx2)
        y1, y2 = self.y[idx1], self.y[idx2]
        e1 = self._prediction_cache[idx1] - self.y[idx1]
        e2 = self._prediction_cache[idx2] - self.y[idx2]
        eta = self.gram[idx1][idx1] + self.gram[idx2][idx2] - 2 * self.gram[idx1][idx2]
        a2_new = self._alpha[idx2] + (y2 * (e1 - e2)) / eta
        if a2_new > H:
            a2_new = H
        elif a2_new < L:
            a2_new = L

        a1_old, a2_old = self._alpha[idx1], self._alpha[idx2]
        da2 = a2_new - a2_old
        da1 = -y1 * y2 * da2
        self._alpha[idx1] += da1
        self._alpha[idx2] = a2_new
        self._update_dw_cache(idx1, idx2, da1, da2, y1, y2)
        self._update_db_cache(idx1, idx2, da1, da2, y1, y2, e1, e2)
        self._update_pred_cache(idx1, idx2)

    def _update_dw_cache(self, idx1, idx2, da1, da2, y1, y2):
        self._dw_cache = np.array([da1 * y1, da2 * y2])
        self._w[idx1] += self._dw_cache[0]
        self._w[idx2] += self._dw_cache[1]

    def _update_db_cache(self, idx1, idx2, da1, da2, y1, y2, e1, e2):
        gram_12 = self.gram[idx1][idx2]
        b1 = -e1 - y1 * self.gram[idx1][idx1] * da1 - y2 * gram_12 * da2
        b2 = -e2 - y1 * gram_12 * da1 - y2 * self.gram[idx2][idx2] * da2
        self._db_cache = (b1 + b2) * 0.5
        self._b += self._db_cache

    def _update_pred_cache(self, *args):
        self._prediction_cache += self._db_cache
        if len(args) == 1:
            self._prediction_cache += self._dw_cache * self.gram[args[0]]
        elif len(args) == len(self.gram):
            self._prediction_cache = self._dw_cache.dot(self.gram)
        else:
            self._prediction_cache += self._dw_cache.dot(self.gram[args, ...])

    def predict(self,X_test, get_raw_result=False):
        test_gram = rbf_kernel(X=self.X, Y=X_test, gamma=self.gamma)
        y_pred = self._w.dot(test_gram) + self._b
        if get_raw_result:
            return y_pred
        return np.sign(y_pred)


if __name__ == '__main__':
    # X,y = datasets.make_blobs(n_samples=1000, n_features=2, centers=2, cluster_std=[3.0,3.0],random_state=13)
    # X,y = datasets.make_moons(n_samples=1000,noise=0.2,random_state=10)
    X,y = datasets.make_circles(n_samples=1000, noise=0.2, factor=0.2, random_state=10)

    y[y==0] = -1
    plt.scatter(X[:,0],X[:,1],c=y)
    plt.show()
    Acc_List = []
    SKF = StratifiedKFold(n_splits=5, shuffle=True)
    for train_idx, test_idx in SKF.split(X=X,y=y):
        X_train = X[train_idx]
        y_train = y[train_idx]
        X_test = X[test_idx]
        y_test = y[test_idx]
        model = KernelSVM()
        model.fit(X=X_train, y=y_train,gamma=0.9, epoch=1000)
        y_pred = model.predict(X_test)
        Acc = accuracy_score(y_pred, y_test)
        Acc_List.append(Acc)
        print("Acc=", Acc)
    print("平均精度={}".format(np.mean(Acc_List)))
    print("标准差={}".format(np.std(Acc_List)))

感谢 射命丸咲 的知乎博客。

SMO算法的推导过程:

  • 北京大学射命丸咲的推导简洁明了(推荐)

https://zhuanlan.zhihu.com/p/27662928

  • Natavidad的推导详细清晰?

https://space.bilibili.com/501498946/

  • 大海老师的推导超级详细

https://www.bilibili.com/video/BV1mE411p7HE?from=search&seid=4583097772244902275

关于KKT条件的推导,推荐参考麦克马斯特大学RookieJ博士的文章。?

?https://zhuanlan.zhihu.com/p/65453337

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-07-24 11:23:48  更:2021-07-24 11:25:58 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/25 14:41:16-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计