IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 数据结构与算法 -> 《机器学习》西瓜书课后习题9.4——python实现K-means算法 -> 正文阅读

[数据结构与算法]《机器学习》西瓜书课后习题9.4——python实现K-means算法

作者:recommend-box insert-baidu-box

《机器学习》西瓜书课后习题9.4——python实现K-means算法

9.4 试编程实现k均值算法,设置三组不同的k值、三组不同的初始中心点,在西瓜数据集4.0上进行实验比较,
并讨论什么样的初始中心有利于取得好结果.

本文主要适用python语言编程实现了K-means算法的过程,并使用了西瓜数据集4.0作为测试数据,在初始化均值向量时使用随机选择的方法,因此相同参数的情况下代码每次运行的结果可能会有所不同。最后,为了验证聚类效果,可视化了最终的结果集,可以发现的是,随着迭代的此时的增多,聚类的效果更好,直至趋于稳定。

具体的算法伪代码和原理部分参见《机器学习》周志华 P203。

西瓜数据集4.0内容:

编号,密度,含糖率
1,0.697,0.460
2,0.774,0.376
3,0.634,0.264
4,0.608,0.318
5,0.556,0.215
6,0.403,0.237
7,0.481,0.149
8,0.437,0.211
9,0.666,0.091
10,0.243,0.267
11,0.245,0.057
12,0.343,0.099
13,0.639,0.161
14,0.657,0.198
15,0.360,0.370
16,0.593,0.042
17,0.719,0.103
18,0.359,0.188
19,0.339,0.241
20,0.282,0.257
21,0.748,0.232
22,0.714,0.346
23,0.483,0.312
24,0.478,0.437
25,0.525,0.369
26,0.751,0.489
27,0.532,0.472
28,0.473,0.376
29,0.725,0.445
30,0.446,0.459

下面我将对代码逐段解释……

一、数据预处理

先将数据集下载并加载进来!代码如下:

def loadData(filename):
    data = open(filename, 'r', encoding='GBK')
    reader = csv.reader(data)
    headers = next(reader)
    dataset = []
    for row in reader:
        row[1] = float(row[1])
        row[2] = float(row[2])
        dataset.append([row[1],row[2]])

    return dataset

二、k-means算法

1、创建一个Kmeans类,所有的操作均封装

Kmeans类中以下的参数是必不可少的:

  • k:聚类的个数
  • train_data:数据集
  • epoch:迭代次数

在该类中,首先是对均值向量进行初始化,随机选择k个数据作为均值向量。其次,我们需要计算数据和均值向量之间的距离,并按最近距离进行划分类。然后,我们需要更新均值向量,具体做法是取每个类中的均值作为新的均值向量。如此,直到达到要求的迭代次数。

class Kmeans:
    k = 0
    train_data = []
    category_k = {}
    vector_k = []
    epoch = 0
    def __init__(self,k,train_data,epoch):
        self.train_data = np.array(train_data, dtype=float)
        self.epoch = epoch
        #  初始化均值向量
        for i in range(0,k):
           self.vector_k.append(self.train_data[random.randint(0,len(train_data)-1)])
        self.vector_k = np.array(self.vector_k, dtype=float)

        #  第i轮迭代
        for epoch_i in range(0,epoch):
            self.category_k = {}
            #  计算距离并分类
            for data_i in self.train_data:
                category_i = self.dist(data_i)
                if category_i not in self.category_k:
                    self.category_k[category_i] = [data_i]
                else:
                    self.category_k[category_i].append(data_i)

            #  更新均值向量
            self.update_category()

2、计算距离并分类

    #  计算第i个数据和所有均值向量的距离,并
    def dist(self,data_i):
        dist = (self.vector_k-data_i)**2
        mean_dist = dist.mean(axis=1)
        return mean_dist.argmin()

3、更新均值向量

    #  计算每个类别的均值并更新均值向量
    def update_category(self):
        for i in range(0,len(self.vector_k)):
            self.vector_k[i] = np.array(self.category_k[i]).mean(axis=0)

4、绘制图像

    #  绘制散点图
    def draw_scatter(self):
        for i in self.category_k:
            x = np.array(self.category_k[i])[:,0]
            y = np.array(self.category_k[i])[:,1]
            plt.scatter(x,y)
        #  绘制均值向量点的位置
        print(self.vector_k)
        x_mean = self.vector_k[:,0]
        y_mean = self.vector_k[:,1]

        plt.title("epoch = "+str(self.epoch))
        plt.xlabel('密度')
        plt.ylabel('含糖率')
        plt.scatter(x_mean,y_mean,marker='+')
        plt.show()

三、完整的程序源代码

'''
  9.4 试编程实现k均值算法,设置三组不同的k值、三组不同的初始中心点,在西瓜数据集4.0上进行实验比较,
      并讨论什么样的初始中心有利于取得好结果.
'''

import random
import numpy as np
import csv
import matplotlib.pyplot as plt

class Kmeans:
    k = 0
    train_data = []
    category_k = {}
    vector_k = []
    epoch = 0
    def __init__(self,k,train_data,epoch):
        self.train_data = np.array(train_data, dtype=float)
        self.epoch = epoch
        #  初始化均值向量
        for i in range(0,k):
           self.vector_k.append(self.train_data[random.randint(0,len(train_data)-1)])
        self.vector_k = np.array(self.vector_k, dtype=float)

        #  第i轮迭代
        for epoch_i in range(0,epoch):
            self.category_k = {}
            #  计算距离并分类
            for data_i in self.train_data:
                category_i = self.dist(data_i)
                if category_i not in self.category_k:
                    self.category_k[category_i] = [data_i]
                else:
                    self.category_k[category_i].append(data_i)

            #  更新均值向量
            self.update_category()

    def get_category(self):
        return self.category_k


    #  计算第i个数据和所有均值向量的距离,并
    def dist(self,data_i):
        dist = (self.vector_k-data_i)**2
        mean_dist = dist.mean(axis=1)
        return mean_dist.argmin()
        # print(mean_dist.argmin())

    #  计算每个类别的均值并更新均值向量
    def update_category(self):
        for i in range(0,len(self.vector_k)):
            self.vector_k[i] = np.array(self.category_k[i]).mean(axis=0)

    #  绘制散点图
    def draw_scatter(self):
        for i in self.category_k:
            x = np.array(self.category_k[i])[:,0]
            y = np.array(self.category_k[i])[:,1]
            plt.scatter(x,y)
        #  绘制均值向量点的位置
        print(self.vector_k)
        x_mean = self.vector_k[:,0]
        y_mean = self.vector_k[:,1]

        plt.title("epoch = "+str(self.epoch))
        plt.xlabel('密度')
        plt.ylabel('含糖率')
        plt.scatter(x_mean,y_mean,marker='+')
        plt.show()




def loadData(filename):
    data = open(filename, 'r', encoding='GBK')
    reader = csv.reader(data)
    headers = next(reader)
    dataset = []
    for row in reader:
        row[1] = float(row[1])
        row[2] = float(row[2])
        dataset.append([row[1],row[2]])

    return dataset

filename = '西瓜数据集4.0.csv'
traindata = loadData(filename)
kmeans = Kmeans(3,traindata,100)
print(kmeans.get_category())
kmeans.draw_scatter()

四、结果分析

在这里我们令k=3,迭代次数依次为100,500,1000次,然后根据图像来观察聚类效果,如图所示:

图中相同颜色的点表示划分为同一个簇,红色的“+”号代表最终的均值向量,标题上标注了迭代的次数
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  数据结构与算法 最新文章
【力扣106】 从中序与后续遍历序列构造二叉
leetcode 322 零钱兑换
哈希的应用:海量数据处理
动态规划|最短Hamilton路径
华为机试_HJ41 称砝码【中等】【menset】【
【C与数据结构】——寒假提高每日练习Day1
基础算法——堆排序
2023王道数据结构线性表--单链表课后习题部
LeetCode 之 反转链表的一部分
【题解】lintcode必刷50题<有效的括号序列
上一篇文章      下一篇文章      查看所有文章
加:2022-04-26 12:01:52  更:2022-04-26 12:03:59 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 7:46:04-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码