IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【树模型与集成学习】(task2)代码实现CART树(更新ing) -> 正文阅读

[人工智能]【树模型与集成学习】(task2)代码实现CART树(更新ing)

作者:token keyword

学习心得

task2学习GYH大佬的回归CART树,并在此基础上改为分类CART树。
更新ing。。

一、回顾决策树算法

在这里插入图片描述

在这里插入图片描述

二、代码实践

from CART import DecisionTreeRegressor
from CARTclassifier import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor as dt
from sklearn.tree import DecisionTreeClassifier as dc
from sklearn.datasets import make_regression
from sklearn.datasets import make_classification


if __name__ == "__main__":

    # 模拟回归数据集
    X, y = make_regression(
        n_samples=200, n_features=10, n_informative=5, random_state=0
    )
    # 回归树
    my_cart_regression = DecisionTreeRegressor(max_depth=2)
    my_cart_regression.fit(X, y)
    res1 = my_cart_regression.predict(X)
    importance1 = my_cart_regression.feature_importances_
    
    sklearn_cart_r = dt(max_depth=2)
    sklearn_cart_r.fit(X, y)
    res2 = sklearn_cart_r.predict(X)
    importance2 = sklearn_cart_r.feature_importances_

    # 预测一致的比例
    print(((res1-res2)<1e-8).mean())
    # 特征重要性一致的比例
    print(((importance1-importance2)<1e-8).mean())
    
    
    
    # 模拟分类数据集
    X, y = make_classification(
        n_samples=200, n_features=10, n_informative=5, random_state=0
    )
    # 分类树
    my_cart_classification = DecisionTreeClassifier(max_depth=2)
    my_cart_classification.fit(X, y)
    res3 = my_cart_classification.predict(X)
    importance3 = my_cart_classification.feature_importances_
    
    sklearn_cart_c = dc(max_depth=2)
    sklearn_cart_c.fit(X, y)
    res4 = sklearn_cart_c.predict(X)
    importance4 = sklearn_cart_c.feature_importances_

    # 预测一致的比例
    print(((res3-res4)<1e-8).mean())
    # 特征重要性一致的比例
    print(((importance3-importance4)<1e-8).mean())
# -*- coding: utf-8 -*-
"""
Created on Sun Oct 17 10:46:08 2021

@author: 86493
"""
import numpy as np
from collections import Counter

def MSE(y):
    return ((y - y.mean())**2).sum() / y.shape[0]

# 基尼指数
def Gini(y):
    c = Counter(y)
    return 1 - sum([(val / y.shape[0]) ** 2 for val in c.values()])

class Node:
    def __init__(self, depth, idx):
        self.depth = depth
        self.idx = idx

        self.left = None
        self.right = None
        self.feature = None
        self.pivot = None


class Tree:
    def __init__(self, max_depth):
        self.max_depth = max_depth
        self.X = None
        self.y = None
        self.feature_importances_ = None

    def _able_to_split(self, node):
        return (node.depth < self.max_depth) & (node.idx.sum() >= 2)

    def _get_inner_split_score(self, to_left, to_right):
        total_num = to_left.sum() + to_right.sum()
        left_val = to_left.sum() / total_num * Gini(self.y[to_left])
        right_val = to_right.sum() / total_num * Gini(self.y[to_right])
        return left_val + right_val

    def _inner_split(self, col, idx):
        data = self.X[:, col]
        best_val = np.infty
        for pivot in data[:-1]:
            to_left = (idx==1) & (data<=pivot)
            to_right = (idx==1) & (~to_left)
            if to_left.sum() == 0 or to_left.sum() == idx.sum():
                continue
            Hyx = self._get_inner_split_score(to_left, to_right)
            if best_val > Hyx:
                best_val, best_pivot = Hyx, pivot
                best_to_left, best_to_right = to_left, to_right
        return best_val, best_to_left, best_to_right, best_pivot

    def _get_conditional_entropy(self, idx):
        best_val = np.infty
        for col in range(self.X.shape[1]):
            Hyx, _idx_left, _idx_right, pivot = self._inner_split(col, idx)
            if best_val > Hyx:
                best_val, idx_left, idx_right = Hyx, _idx_left, _idx_right
                best_feature, best_pivot = col, pivot
        return best_val, idx_left, idx_right, best_feature, best_pivot

    def split(self, node):
        # 首先判断本节点是不是符合分裂的条件
        if not self._able_to_split(node):
            return None, None, None, None
        # 计算H(Y)
        entropy = Gini(self.y[node.idx==1])
        # 计算最小的H(Y|X)
        (
            conditional_entropy,
            idx_left,
            idx_right,
            feature,
            pivot
        ) = self._get_conditional_entropy(node.idx)
        # 计算信息增益G(Y, X)
        info_gain = entropy - conditional_entropy
        # 计算相对信息增益
        relative_gain = node.idx.sum() / self.X.shape[0] * info_gain
        # 更新特征重要性
        self.feature_importances_[feature] += relative_gain
        # 新建左右节点并更新深度
        node.left = Node(node.depth+1, idx_left)
        node.right = Node(node.depth+1, idx_right)
        self.depth = max(node.depth+1, self.depth)
        return idx_left, idx_right, feature, pivot

    def build_prepare(self):
        self.depth = 0
        self.feature_importances_ = np.zeros(self.X.shape[1])
        self.root = Node(depth=0, idx=np.ones(self.X.shape[0]) == 1)

    def build_node(self, cur_node):
        if cur_node is None:
            return
        idx_left, idx_right, feature, pivot = self.split(cur_node)
        cur_node.feature, cur_node.pivot = feature, pivot
        self.build_node(cur_node.left)
        self.build_node(cur_node.right)

    def build(self):
        self.build_prepare()
        self.build_node(self.root)

    def _search_prediction(self, node, x):
        if node.left is None and node.right is None:
            # return self.y[node.idx].mean()
            return self.y[node.idx].min()
        if x[node.feature] <= node.pivot:
            node = node.left
        else:
            node = node.right
        return self._search_prediction(node, x)

    def predict(self, x):
        return self._search_prediction(self.root, x)


class DecisionTreeClassifier:
    """
    max_depth控制最大深度,类功能与sklearn默认参数下的功能实现一致
    """

    def __init__(self, max_depth):
        self.tree = Tree(max_depth=max_depth)

    def fit(self, X, y):
        self.tree.X = X
        self.tree.y = y
        self.tree.build()
        self.feature_importances_ = (
            self.tree.feature_importances_ 
            / self.tree.feature_importances_.sum()
        )
        return self

    def predict(self, X):
        return np.array([self.tree.predict(x) for x in X])

输出结果如下,可见在误差范围内,实现的分类树和回归树均和sklearn实现的模块近似。

1.0
1.0
1.0
1.0

Reference

(0)datawhale notebook
(1)CART决策树(Decision Tree)的Python源码实现
(2)https://github.com/RRdmlearning/Decision-Tree
(3)《机器学习技法》—决策树

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-18 17:24:05  更:2021-10-18 17:24:16 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 11:05:40-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码