1.关于聚类
import pandas as pd
import numpy as np
from pandas import Series,DataFrame
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
聚类的基本使用
from sklearn.datasets import make_blobs
# Synthetic data: 150 samples with 2 features drawn around 3 blob centers
# (per-cluster standard deviation 1.5); the seed is fixed for reproducibility.
X, y = make_blobs(n_samples=150, n_features=2, centers=3 ,cluster_std=1.5, random_state=2)
# Notebook display of the samples and their ground-truth blob labels.
X,y
(array([[ 1.53194956, -0.36022153],
[ -1.56430585, -9.59730336],
[ -1.08878137, -0.53673972],
[ 2.43442247, -0.15599663],
[ 0.48019529, -12.99688015],
[ 0.57898032, -2.06887799],
[ -0.72443515, -7.44202457],
[ -1.85225071, -3.98632318],
[ 2.54279316, -1.7870558 ],
[ -0.52731615, -10.74779592],
[ -2.29661532, -11.6406339 ],
[ 0.92141506, -9.98499137],
[ -5.35083116, -0.65431189],
[ -0.88989127, 0.11369336],
[ -0.59631184, -2.29097658],
[ -1.50195028, -5.4011869 ],
[ -0.76364841, -4.6539681 ],
[ 2.19201955, 0.60036835],
[ 7.15628849, -0.06187083],
[ 4.81890691, -3.50331202],
[ 0.65087822, -4.39797054],
[ -3.24158992, -4.41559955],
[ -1.13496627, -1.67121333],
[ 0.03606565, -2.04003449],
[ -2.01792323, -2.58566719],
[ 1.355409 , -0.54741367],
[ -3.12791644, -4.06556581],
[ -5.26927614, -9.6186543 ],
[ -0.02442698, -1.33977954],
[ -2.86703029, -10.84498679],
[ 2.29764685, -2.92418801],
[ -1.18679697, -1.80057881],
[ -0.73325486, -1.93333585],
[ -0.52577983, -11.34940749],
[ -0.50461407, -3.93251527],
[ 1.90846569, -0.6583068 ],
[ -2.06104996, -0.17628645],
[ -0.47151448, -10.37571491],
[ 1.26386427, -0.46380574],
[ -0.36309079, -9.40951948],
[ -0.53887254, -0.6449586 ],
[ -2.23212091, -8.718881 ],
[ -4.16374507, -3.50826293],
[ -0.70730261, -8.6320622 ],
[ 1.18048503, -0.15879893],
[ -1.41098559, -4.66354671],
[ -1.90907668, -9.67996871],
[ 2.03754653, -0.24742774],
[ -0.97378999, -7.371431 ],
[ 2.99659881, -0.83960814],
[ -3.4119278 , -9.71171816],
[ 0.10647516, -2.83784632],
[ 3.68213884, -1.93707213],
[ -1.88594036, -11.55825336],
[ 0.46218028, -8.90235829],
[ -0.04304745, -7.60915598],
[ -1.78833491, -9.83575141],
[ -0.95592795, -12.26939394],
[ -1.32676236, -4.41753005],
[ 0.22044687, -10.05311414],
[ -4.0709444 , -4.40679626],
[ -4.62768987, -2.99134472],
[ -4.36824992, -2.89757148],
[ -4.08223794, -6.88469836],
[ 0.77102877, 0.95860323],
[ -0.05463537, -2.68689003],
[ 0.44653092, -2.60752136],
[ -0.49365731, -8.378556 ],
[ 0.65278373, -5.68997024],
[ -3.32769271, -1.54225156],
[ -0.36011954, -2.18001056],
[ 2.2374372 , -0.3476192 ],
[ 1.5880298 , -1.7654783 ],
[ -2.31262163, -4.92277723],
[ -0.28638281, -2.50409338],
[ -0.62985746, -7.56390652],
[ 1.30709149, -4.99949807],
[ -3.88704121, -7.92023943],
[ 0.31190778, -0.52199607],
[ -0.70822817, -2.35468348],
[ -2.44971637, -2.95465548],
[ -1.71601202, -3.85030346],
[ -2.06618377, -5.41830673],
[ -1.46459731, -2.39530216],
[ -2.52380489, -9.34991004],
[ -1.83223015, -2.56988374],
[ -1.02782509, -3.59652323],
[ -2.51078608, -3.92019727],
[ -2.63990045, -3.03337678],
[ -1.73623162, -5.60353306],
[ -0.89524628, -10.96464394],
[ -5.15424798, -3.30552368],
[ -3.3851438 , -4.1251994 ],
[ 2.67007966, -1.70491528],
[ 2.12119683, -2.78419362],
[ -2.25997736, -8.21779094],
[ 1.74015978, -1.10379588],
[ 3.29089003, -4.27232081],
[ -1.54379575, -5.85414392],
[ -1.75036425, -8.32495776],
[ -1.33945732, -8.99247021],
[ -2.92821038, -7.10474478],
[ -1.00719928, -1.93003946],
[ 0.29073017, -3.17563261],
[ -1.28008731, -8.66794651],
[ 1.54082983, -0.1324291 ],
[ -1.84360609, -9.59318151],
[ -0.597949 , -0.40605237],
[ -2.23658448, -11.26289379],
[ -3.19324464, -4.3727003 ],
[ -1.21779287, -11.15836353],
[ -2.86763721, -4.67181627],
[ 0.82161761, -2.04081344],
[ -0.45292089, -6.04316334],
[ -0.709394 , -9.80717827],
[ -4.93225332, -9.31238561],
[ -0.23742255, -12.53167518],
[ -2.40190838, -9.46793749],
[ 2.65696448, -3.94092874],
[ 0.10261618, 0.4306987 ],
[ 0.77075118, -7.65464691],
[ -1.97310998, -8.95514262],
[ -1.23044866, -0.02408431],
[ -0.83889419, 1.41316281],
[ 1.89552328, -1.28806291],
[ 3.74624864, -0.63251734],
[ 0.50567512, -2.13390391],
[ -3.56899486, -6.43169397],
[ -1.30528349, -4.3866171 ],
[ -2.97980187, -8.83183653],
[ -1.85237668, -9.38174185],
[ 3.60596784, -1.96480346],
[ -1.59595363, -3.6022414 ],
[ -1.32254393, -3.49370015],
[ -4.34058653, -9.41209208],
[ -0.05036661, -4.47612317],
[ -1.60404567, -4.79404957],
[ -2.52020719, -2.82511188],
[ -2.5972638 , -9.71612662],
[ -0.82276679, -3.89556977],
[ 5.91286766, 0.16273983],
[ -1.85513922, -5.54901873],
[ 0.7183647 , 0.23622995],
[ -1.6836874 , -6.13442518],
[ -0.0856312 , -2.16867404],
[ -1.3087977 , -7.71897353],
[ -2.42206812, -2.94401336],
[ 0.50787945, -0.65781509],
[ -3.18469257, -3.55607882],
[ -1.50676754, -3.15467085]]),
array([1, 0, 2, 1, 0, 1, 0, 2, 1, 0, 0, 0, 2, 1, 2, 2, 2, 1, 1, 1, 1, 2,
1, 2, 2, 1, 2, 0, 2, 0, 1, 1, 1, 0, 2, 1, 2, 0, 1, 0, 1, 0, 2, 0,
1, 2, 0, 1, 0, 1, 0, 2, 1, 0, 0, 0, 0, 0, 1, 0, 2, 2, 2, 0, 1, 1,
1, 0, 2, 2, 1, 1, 1, 2, 2, 0, 2, 0, 1, 1, 2, 2, 2, 2, 0, 2, 2, 2,
2, 0, 0, 2, 2, 1, 1, 0, 1, 1, 2, 0, 0, 0, 2, 1, 0, 1, 0, 1, 0, 2,
0, 2, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 2, 2, 0, 0, 1,
2, 2, 0, 2, 2, 2, 0, 2, 1, 2, 1, 0, 1, 0, 2, 1, 2, 2]))
# Switch matplotlib to seaborn's default theme for the plots in this notebook.
sns.set()
# Sanity check: (n_samples, n_features).
X.shape
(150, 2)
# One ground-truth label per sample.
y.shape
(150,)
# Scatter the two features against each other, colored by ground-truth label.
plt.scatter(X[:,0],X[:,1],c=y,cmap=plt.cm.Accent_r)
<matplotlib.collections.PathCollection at 0x1a3478924f0>
?
?
from sklearn.cluster import KMeans
# Cluster the samples into 3 groups; fit_predict fits the model and returns
# one cluster id per sample.
# NOTE: no random_state is passed, so the numbering of the clusters can differ
# between runs. Cluster ids are arbitrary and are not aligned with the ids in y.
km = KMeans(n_clusters=3)
y_ = km.fit_predict(X)
y_
array([2, 1, 2, 2, 1, 2, 1, 0, 2, 1, 1, 1, 0, 2, 0, 0, 0, 2, 2, 2, 0, 0,
0, 2, 0, 2, 0, 1, 2, 1, 2, 0, 0, 1, 0, 2, 0, 1, 2, 1, 2, 1, 0, 1,
2, 0, 1, 2, 1, 2, 1, 0, 2, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 2, 0,
2, 1, 0, 0, 0, 2, 2, 0, 0, 1, 0, 1, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 2, 2, 1, 2, 2, 0, 1, 1, 1, 0, 0, 1, 2, 1, 2, 1, 0,
1, 0, 2, 0, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 0, 0, 1, 1, 2,
0, 0, 1, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 1, 0, 2, 0, 0])
# Coordinates of the fitted cluster centers, one row per cluster.
cluster_centers = km.cluster_centers_
cluster_centers
array([[-1.80632868, -3.67173199],
[-1.62473796, -9.4792349 ],
[ 1.58887503, -1.06495221]])
# Side-by-side comparison: ground-truth labels (left) versus K-means labels
# together with the fitted cluster centers (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
feature_0, feature_1 = X[:, 0], X[:, 1]

ax1.scatter(feature_0, feature_1, c=y, cmap=plt.cm.cool)
ax1.set_title('True')

ax2.scatter(feature_0, feature_1, c=y_, cmap=plt.cm.rainbow_r)
# Mark each cluster center with a large red star.
ax2.scatter(*cluster_centers.T, color='red', marker='*', s=300)
ax2.set_title('Kmeans')

plt.show()
?
?
球队综合实力聚类分析
# Load the team table. header=None because the file has no header row, so the
# columns are named 0..3 until they are renamed.
data = pd.read_csv('AsiaZoo.txt',header=None)
data
| 0 | 1 | 2 | 3 |
---|
0 | 中国 | 50 | 50 | 9 |
---|
1 | 日本 | 28 | 9 | 4 |
---|
2 | 韩国 | 17 | 15 | 3 |
---|
3 | 伊朗 | 25 | 40 | 5 |
---|
4 | 沙特 | 28 | 40 | 2 |
---|
5 | 伊拉克 | 50 | 50 | 1 |
---|
6 | 卡塔尔 | 50 | 40 | 9 |
---|
7 | 阿联酋 | 50 | 40 | 9 |
---|
8 | 乌兹别克斯坦 | 40 | 40 | 5 |
---|
9 | 泰国 | 50 | 50 | 9 |
---|
10 | 越南 | 50 | 50 | 5 |
---|
11 | 阿曼 | 50 | 50 | 9 |
---|
12 | 巴林 | 40 | 40 | 9 |
---|
13 | 朝鲜 | 40 | 32 | 17 |
---|
14 | 印尼 | 50 | 50 | 9 |
---|
# Give the columns descriptive names: country, 2006 World Cup, 2010 World Cup,
# 2007 Asian Cup.
data.columns = ['国家','2006年世界杯','2010年世界杯','2007年亚洲杯']
data
| 国家 | 2006年世界杯 | 2010年世界杯 | 2007年亚洲杯 |
---|
0 | 中国 | 50 | 50 | 9 |
---|
1 | 日本 | 28 | 9 | 4 |
---|
2 | 韩国 | 17 | 15 | 3 |
---|
3 | 伊朗 | 25 | 40 | 5 |
---|
4 | 沙特 | 28 | 40 | 2 |
---|
5 | 伊拉克 | 50 | 50 | 1 |
---|
6 | 卡塔尔 | 50 | 40 | 9 |
---|
7 | 阿联酋 | 50 | 40 | 9 |
---|
8 | 乌兹别克斯坦 | 40 | 40 | 5 |
---|
9 | 泰国 | 50 | 50 | 9 |
---|
10 | 越南 | 50 | 50 | 5 |
---|
11 | 阿曼 | 50 | 50 | 9 |
---|
12 | 巴林 | 40 | 40 | 9 |
---|
13 | 朝鲜 | 40 | 32 | 17 |
---|
14 | 印尼 | 50 | 50 | 9 |
---|
from mpl_toolkits.mplot3d import Axes3D
# Plain white seaborn style for the 3D figure.
sns.set_style(style='white')
# Use a CJK-capable font so the Chinese axis labels render, and disable the
# Unicode minus sign so negative tick labels still display with that font.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(10,6))
ax = plt.subplot(projection='3d')
# One point per team, one axis per competition. No `c` values are supplied
# here, so a `cmap` argument would be ignored by matplotlib (newer versions
# warn about it); it is therefore omitted and the default color is used.
ax.scatter3D(data['2006年世界杯'],data['2010年世界杯'],data['2007年亚洲杯'],s = 200)
ax.set_xlabel('2006年世界杯')
ax.set_ylabel('2010年世界杯')
ax.set_zlabel('2007年亚洲杯')
plt.show()
?
?
# Feature matrix: every column except the first one (the country name).
X = data.iloc[:,1:]
X
| 2006年世界杯 | 2010年世界杯 | 2007年亚洲杯 |
---|
0 | 50 | 50 | 9 |
---|
1 | 28 | 9 | 4 |
---|
2 | 17 | 15 | 3 |
---|
3 | 25 | 40 | 5 |
---|
4 | 28 | 40 | 2 |
---|
5 | 50 | 50 | 1 |
---|
6 | 50 | 40 | 9 |
---|
7 | 50 | 40 | 9 |
---|
8 | 40 | 40 | 5 |
---|
9 | 50 | 50 | 9 |
---|
10 | 50 | 50 | 5 |
---|
11 | 50 | 50 | 9 |
---|
12 | 40 | 40 | 9 |
---|
13 | 40 | 32 | 17 |
---|
14 | 50 | 50 | 9 |
---|
# Group the teams into 3 clusters based on the three competition columns.
# NOTE: no random_state is passed, so the cluster numbering can change between
# runs even when the grouping itself is the same.
km = KMeans(n_clusters=3)
y_ = km.fit_predict(X)
y_
array([0, 1, 1, 2, 2, 0, 0, 0, 2, 0, 0, 0, 2, 2, 0])
# Same 3D scatter as before, now colored by the K-means cluster id.
plt.figure(figsize=(10,6))
ax = plt.subplot(projection='3d')
ax.scatter3D(data['2006年世界杯'],data['2010年世界杯'],data['2007年亚洲杯'],s = 200,c = y_,cmap = plt.cm.rainbow)
ax.set_xlabel('2006年世界杯')
ax.set_ylabel('2010年世界杯')
ax.set_zlabel('2007年亚洲杯')
plt.show()
# Store the cluster id of each team in a new column ('簇类' = cluster).
data['簇类'] = y_
data
| 国家 | 2006年世界杯 | 2010年世界杯 | 2007年亚洲杯 | 簇类 |
---|
0 | 中国 | 50 | 50 | 9 | 0 |
---|
1 | 日本 | 28 | 9 | 4 | 1 |
---|
2 | 韩国 | 17 | 15 | 3 | 1 |
---|
3 | 伊朗 | 25 | 40 | 5 | 2 |
---|
4 | 沙特 | 28 | 40 | 2 | 2 |
---|
5 | 伊拉克 | 50 | 50 | 1 | 0 |
---|
6 | 卡塔尔 | 50 | 40 | 9 | 0 |
---|
7 | 阿联酋 | 50 | 40 | 9 | 0 |
---|
8 | 乌兹别克斯坦 | 40 | 40 | 5 | 2 |
---|
9 | 泰国 | 50 | 50 | 9 | 0 |
---|
10 | 越南 | 50 | 50 | 5 | 0 |
---|
11 | 阿曼 | 50 | 50 | 9 | 0 |
---|
12 | 巴林 | 40 | 40 | 9 | 2 |
---|
13 | 朝鲜 | 40 | 32 | 17 | 2 |
---|
14 | 印尼 | 50 | 50 | 9 | 0 |
---|
# Mapping from cluster id to the row labels of the teams in that cluster.
data.groupby('簇类').groups
{0: [0, 5, 6, 7, 9, 10, 11, 14], 1: [1, 2], 2: [3, 4, 8, 12, 13]}
# The same mapping viewed as (cluster id, row labels) pairs.
data.groupby('簇类').groups.items()
dict_items([(0, Int64Index([0, 5, 6, 7, 9, 10, 11, 14], dtype='int64')), (1, Int64Index([1, 2], dtype='int64')), (2, Int64Index([3, 4, 8, 12, 13], dtype='int64'))])
# Print the member countries of each cluster, one cluster per line.
for member_rows in data.groupby('簇类').groups.values():
    # Every name is followed by a single space (including the last one),
    # then the line is terminated.
    print(''.join(f'{name} ' for name in data.loc[member_rows, '国家']))
中国 伊拉克 卡塔尔 阿联酋 泰国 越南 阿曼 印尼
日本 韩国
伊朗 沙特 乌兹别克斯坦 巴林 朝鲜
2.kmeans中常见的错误
import pandas as pd
import numpy as np
from pandas import Series,DataFrame
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
a. k值选择不合理
# 150 samples, 2 features, 3 blob centers with standard deviation 2; fixed seed.
X ,y = make_blobs(n_samples=150, n_features=2, centers=3, random_state=2, cluster_std=2)
sns.set()
def show_scatter(X, y):
    """Scatter-plot the first two columns of ``X`` colored by ``y`` and show it.

    Args:
        X: 2-D array indexed as ``X[:, 0]`` / ``X[:, 1]`` for the two plot axes.
        y: Per-sample values used for coloring (passed as ``c``).
    """
    horizontal, vertical = X[:, 0], X[:, 1]
    plt.scatter(horizontal, vertical, c=y, cmap=plt.cm.rainbow)
    plt.show()
# Plot the generated blobs colored by their ground-truth label.
show_scatter(X,y)
?
?
# Deliberately use k=2 although the data was generated from 3 centers.
km = KMeans(n_clusters=2)
y_ = km.fit_predict(X)
def show_predict(X, y, y_):
    """Show ground-truth and predicted labels side by side on the same points.

    Args:
        X: 2-D array; columns 0 and 1 are used as the plot coordinates.
        y: Ground-truth labels, drawn in the left panel ('True').
        y_: Predicted cluster ids, drawn in the right panel
            ('Kmeans prediction').
    """
    plt.figure(figsize=(12, 4))
    # (subplot position, colors, colormap, title) for each panel.
    panels = (
        (121, y, plt.cm.rainbow, 'True'),
        (122, y_, plt.cm.cool, 'Kmeans prediction'),
    )
    for position, labels, colormap, heading in panels:
        plt.subplot(position)
        plt.scatter(X[:, 0], X[:, 1], c=labels, s=100, cmap=colormap)
        plt.title(heading)
    plt.show()
# Compare the true 3-blob labelling with the k=2 clustering.
show_predict(X,y,y_)
? ?
# Opposite mistake: k=7 is far more clusters than the 3 generated blobs.
km = KMeans(n_clusters=7)
y_ = km.fit_predict(X)
show_predict(X,y,y_)
?
b. 数据存在偏差 (必然存在)
# Apply a fixed 2x2 linear transformation to every sample (matrix product),
# producing a distorted copy X1 of the blobs; the labels y stay the same.
trans = [[0.6,-0.6],[-0.4,0.8]]
X1 = np.dot(X, trans)
X1
array([[ 1.04655426, -1.06619879],
[ 2.85894079, -6.71330586],
[ -0.7186756 , 0.88485489],
[ 1.65961265, -1.57033723],
[ 6.30764932, -11.97512202],
[ 1.19546232, -2.12639029],
[ 2.38135532, -5.08623838],
[ 0.51032678, -2.18392533],
[ 2.61620742, -3.39683023],
[ 4.30212858, -8.77008969],
[ 3.3628695 , -8.30701087],
[ 5.05428445, -9.11541647],
[ -4.06561027, 4.1690844 ],
[ -1.14367301, 1.37678308],
[ 0.61089302, -1.38030672],
[ 1.5451611 , -3.97335364],
[ 1.7372859 , -3.76696174],
[ 1.06229566, -0.56962559],
[ 5.38690504, -5.2474292 ],
[ 5.35243507, -7.04839453],
[ 2.49516333, -4.668274 ],
[ -0.3721972 , -1.53034875],
[ -0.38778277, -0.33105738],
[ 0.98295923, -1.61853714],
[ -0.36922777, -0.55735425],
[ 1.00515763, -1.12463797],
[ -0.46794307, -1.24791822],
[ -0.09364828, -3.77210397],
[ 0.56109582, -0.82320444],
[ 2.4821924 , -7.00198864],
[ 3.02656089, -4.41365421],
[ -0.3602524 , -0.42758268],
[ 0.07338504, -0.93202387],
[ 4.62421714, -9.41303776],
[ 1.55973854, -3.20463953],
[ 1.50674598, -1.68536932],
[ -1.68873222, 2.04715326],
[ 4.14832672, -8.41784462],
[ 0.88733094, -0.96222038],
[ 3.71976143, -7.4739751 ],
[ -0.45824364, 0.28673935],
[ 1.85619681, -5.24206996],
[ -1.59383419, 0.1752011 ],
[ 3.0297481 , -6.36931789],
[ 0.65795725, -0.5701764 ],
[ 1.22452476, -3.25930919],
[ 2.62721231, -6.52566558],
[ 1.39087515, -1.35036299],
[ 2.14422155, -4.8114547 ],
[ 2.47394653, -2.74926392],
[ 1.44186445, -5.35725076],
[ 1.46478647, -2.52586403],
[ 3.60769267, -4.46832419],
[ 3.64747318, -8.54767826],
[ 4.10949232, -7.59322003],
[ 3.01560224, -5.80962205],
[ 2.80688984, -6.78842721],
[ 4.77075809, -10.05023814],
[ 0.92348261, -3.10702502],
[ 4.52984205, -8.62730621],
[ -1.04037586, -0.857475 ],
[ -2.24067973, 1.09773636],
[ -2.08314016, 0.99020918],
[ -0.60212755, -1.80551487],
[ -0.26555557, 0.94928424],
[ 1.01817619, -2.27871059],
[ 1.37677927, -2.59498371],
[ 3.06546103, -6.26982751],
[ 3.42296609, -6.00517641],
[ -1.97353169, 1.60343801],
[ 0.50345313, -1.49365181],
[ 1.60422314, -1.61714642],
[ 1.84088873, -2.61000354],
[ 0.6414722 , -2.81451291],
[ 0.97249854, -1.85557453],
[ 2.52202118, -5.29190794],
[ 3.57816048, -5.79211897],
[ 0.1063184 , -3.06624939],
[ 0.15680059, -0.26272488],
[ 0.31812512, -1.40148269],
[ -0.51786853, -0.60550725],
[ 0.54677387, -2.14782857],
[ 1.10290489, -3.540228 ],
[ -0.02809505, -0.79695896],
[ 1.95939845, -5.68182042],
[ -0.22909115, -0.68907304],
[ 0.96197396, -2.42767921],
[ -0.05176867, -1.58656273],
[ -0.6280311 , -0.53732937],
[ 0.59138935, -2.3157436 ],
[ 4.12343675, -8.70705014],
[ -2.49436411, 1.18385863],
[ -0.64192038, -1.10574549],
[ 2.67422834, -3.41104287],
[ 2.8107372 , -4.12316684],
[ 1.56666362, -4.68528874],
[ 1.60969543, -2.02591295],
[ 4.54015959, -6.64625707],
[ 1.7532618 , -4.42303141],
[ 2.03150974, -5.2072905 ],
[ 2.71624194, -6.24802934],
[ 0.43845259, -2.96345309],
[ 0.08968327, -0.66659717],
[ 1.55513133, -3.07632843],
[ 2.59065863, -5.94936671],
[ 0.93216919, -0.83032442],
[ 2.63330227, -6.48546903],
[ -0.63292147, 0.58883382],
[ 3.20943279, -7.95211276],
[ -0.35640058, -1.52326577],
[ 3.9687166 , -8.6556471 ],
[ 0.06361389, -2.10280876],
[ 1.37460439, -2.29056461],
[ 1.85250741, -3.81133115],
[ 3.65480355, -7.62110191],
[ 0.01262668, -3.71503562],
[ 5.48544573, -10.90480912],
[ 2.11986363, -5.90523358],
[ 3.85627671, -5.78563176],
[ -0.51873656, 0.92091616],
[ 3.69090297, -6.50918461],
[ 2.18941176, -5.70129111],
[ -1.34263749, 1.50226615],
[ -1.79592572, 2.72208617],
[ 1.83226198, -2.34675524],
[ 2.96321796, -3.12808692],
[ 1.17149864, -2.13710711],
[ 0.44112921, -3.41892551],
[ 1.16139064, -3.04847928],
[ 1.318295 , -4.7644111 ],
[ 2.51351799, -6.25291692],
[ 3.56154592, -4.43696748],
[ 0.51052083, -1.97927577],
[ 0.67135991, -2.08222618],
[ 0.53913689, -4.29472262],
[ 2.21306072, -4.14788593],
[ 1.13967822, -3.24406418],
[ -0.6433511 , -0.41093476],
[ 2.09594684, -6.01368432],
[ 1.28551209, -2.91070882],
[ 4.27237602, -4.01310782],
[ 1.34145359, -3.84848977],
[ 0.07757892, 0.220884 ],
[ 0.91656719, -2.9240639 ],
[ 0.71699766, -1.70115019],
[ 2.0615714 , -4.91416057],
[ -0.50142572, -0.61627426],
[ 0.38601474, -0.56437584],
[ -0.78509038, -0.65904451],
[ 0.3431654 , -1.57321605]])
# Show the transformed data and the original data, both colored by true label.
show_scatter(X1,y)
show_scatter(X,y)
# Cluster the transformed data.
km2 = KMeans(n_clusters=3)
y2_ = km2.fit_predict(X1)
# Plot the result on the coordinates the model was actually fitted on (X1).
# Drawing the X1-based cluster ids on the untransformed X would show points
# at positions the model never saw.
show_predict(X1, y, y2_)
?
c. 各簇的标准差不相同（cluster_std）
例如：同一金额以“元”为单位是 8000，以“万元”为单位是 0.8 —— 不同量纲下数值尺度差异很大。
# Blobs with different spreads: cluster_std gives one standard deviation per
# cluster (1, 2 and 4). `centers` is not passed, so make_blobs falls back to
# its default of 3 centers, which matches the three std values.
X , y = make_blobs(n_samples=150, n_features=2, random_state=2, cluster_std=[1,2,4])
show_scatter(X,y)
?
# Cluster the unequal-variance blobs and compare with the true labels.
km3 = KMeans(n_clusters=3)
y3_ = km3.fit_predict(X)
show_predict(X,y,y3_)
from sklearn.preprocessing import StandardScaler
# Standardize each feature; the check below confirms that every column then
# has unit standard deviation.
# NOTE(review): in the visible source ss_X is only inspected here and is never
# passed to KMeans -- confirm whether re-clustering on ss_X was intended.
ss_X = StandardScaler().fit_transform(X)
ss_X.std(axis=0)
array([1., 1.])
d. 样本数量不同
# Draw 1500 samples from 3 blobs, then keep an unbalanced subset per class.
X,y = make_blobs(n_samples=1500, n_features=2, centers=3, random_state=5)
class_sizes = (100, 35, 15)
XA, XB, XC = [X[y == label][:size] for label, size in enumerate(class_sizes)]
XX = np.concatenate([XA, XB, XC])
# Rebuild the labels to match the stacked subset: 100 zeros, 35 ones, 15 twos.
y = np.repeat([0, 1, 2], class_sizes)
XX.shape
(150, 2)
# The rebuilt label vector has one entry per retained sample.
y.shape
(150,)
# Plot the unbalanced subset colored by true class.
show_scatter(XX,y)
?
# Cluster the unbalanced subset and compare with the true classes.
km4 = KMeans(n_clusters=3)
y4_ = km4.fit_predict(XX)
show_predict(XX, y, y4_)
?
?
e. 使用轮廓系数来判断聚类的效果
# Heavily overlapping blobs (standard deviation 3) for the silhouette experiments.
X ,y = make_blobs(n_samples=150, n_features=2, centers=3, random_state=3,cluster_std=3)
show_scatter(X,y)
?
from sklearn.metrics import silhouette_score
# Cluster with k=3 and score the result with the silhouette coefficient,
# which is computed from X and the predicted cluster ids only.
kmeans = KMeans(n_clusters=3)
y_ = kmeans.fit_predict(X)
silhouette_score(X,y_)
0.39860514202079084
def show_clusters_edge(kmeans, X):
    """Fit ``kmeans`` on ``X`` and plot its cluster regions under the samples.

    A 200 x 200 grid spanning the bounding box of ``X`` is classified by the
    fitted model and drawn as a background, the samples are drawn on top
    colored by their own cluster id, and the silhouette score of the sample
    labelling is shown in the title.

    Args:
        kmeans: Estimator providing ``fit`` and ``predict``; it is (re)fitted
            in place on ``X``.
        X: 2-D array; columns 0 and 1 are the plot coordinates.
    """
    # Grid covering the bounding box of the data.
    col0, col1 = X[:, 0], X[:, 1]
    grid_x = np.linspace(col0.min(), col0.max(), 200)
    grid_y = np.linspace(col1.min(), col1.max(), 200)
    xx, yy = np.meshgrid(grid_x, grid_y)
    X_test = np.c_[xx.ravel(), yy.ravel()]

    kmeans.fit(X)
    sample_labels = kmeans.predict(X)
    grid_labels = kmeans.predict(X_test)

    # Background: grid points colored by predicted cluster; foreground: samples.
    plt.scatter(X_test[:, 0], X_test[:, 1], c=grid_labels, s=100, cmap=plt.cm.Accent)
    plt.scatter(col0, col1, c=sample_labels, s=100, cmap=plt.cm.cool)
    plt.title('silhouette_score:%.4f' % silhouette_score(X, sample_labels))
    plt.show()
# Cluster regions and silhouette score for k=3.
kmeans = KMeans(n_clusters=3)
show_clusters_edge(kmeans,X)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:240: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0.0, flags=flags)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:203: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0, flags=flags)
# Cluster regions and silhouette score for k=4.
kmeans = KMeans(n_clusters=4)
show_clusters_edge(kmeans,X)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:240: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0.0, flags=flags)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:203: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0, flags=flags)
# Cluster regions and silhouette score for k=5.
kmeans = KMeans(n_clusters=5)
show_clusters_edge(kmeans,X)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:240: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0.0, flags=flags)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:203: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0, flags=flags)
# Cluster regions and silhouette score for k=6.
kmeans = KMeans(n_clusters=6)
show_clusters_edge(kmeans,X)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:240: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0.0, flags=flags)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:203: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0, flags=flags)
# Cluster regions and silhouette score for k=7.
kmeans = KMeans(n_clusters=7)
show_clusters_edge(kmeans,X)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:240: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0.0, flags=flags)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:203: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0, flags=flags)
# Cluster regions and silhouette score for k=2.
kmeans = KMeans(n_clusters=2)
show_clusters_edge(kmeans,X)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:240: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0.0, flags=flags)
D:\software\anaconda\lib\site-packages\matplotlib\backends\backend_agg.py:203: RuntimeWarning: Glyph 65306 missing from current font.
font.set_text(s, 0, flags=flags)
|