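![Plotting tips for reinforcement learning experiments: drawing paper-quality figures with seaborn](https://img-blog.csdnimg.cn/img_convert/5f04b12a5afa6a7d7e16223852947892.png)
seaborn is a higher-level plotting library built on top of matplotlib. When drawing line plots with seaborn, the data can be passed either as NumPy ndarrays or as pandas objects.
1. Starting with a demo
1.1 A minimal example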
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set()
rewards = np.array([0, 0.1,0,0.2,0.4,0.5,0.6,0.9,0.9,0.9])
plt.plot(rewards)
plt.show()
![](https://img-blog.csdnimg.cn/bf07fc0d67824795a9f998d1a3443229.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
For comparison, here is the result with sns.set() commented out:
![](https://img-blog.csdnimg.cn/615e51ae29784c1a89f2eb56061b71d3.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
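To see what sns.set() actually changes, one can inspect matplotlib's rcParams before and after the call. The sketch below assumes seaborn's default "darkgrid" style and uses the non-interactive Agg backend so that it runs without a display; the helper name grid_enabled is only illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import seaborn as sns

def grid_enabled():
    """Return whether new axes draw a grid under the current rcParams."""
    return bool(plt.rcParams["axes.grid"])

matplotlib.rcdefaults()   # start from matplotlib's built-in defaults
before = grid_enabled()
sns.set()                 # apply seaborn's default theme
after = grid_enabled()
print("grid before sns.set():", before)
print("grid after sns.set():", after)
```

The grid is only one of the settings the theme touches; background color, fonts, and the color cycle are changed in the same way.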
1.2 Using sns.lineplot
Add labels for the x and y axes and a title:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set()
rewards = np.array([0, 0.1,0,0.2,0.4,0.5,0.6,0.9,0.9,0.9])
sns.lineplot(x=range(len(rewards)),y=rewards)
plt.xlabel("episode")
plt.ylabel("reward")
plt.title("data")
plt.show()
![](https://img-blog.csdnimg.cn/b7fc10e8851d43b584374423f70bcf82.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
1.3 Plotting aggregated rewards
When we repeat the same experiment several times, we get a set of reward curves, for example:
import numpy as np
rewards1 = np.array([0, 0.1,0,0.2,0.4,0.5,0.6,0.9,0.9,0.9])
rewards2 = np.array([0, 0,0.1,0.4,0.5,0.5,0.55,0.8,0.9,1])
rewards3 = np.vstack((rewards1,rewards2))
rewards4 = np.concatenate((rewards1,rewards2))
print(np.shape(rewards3))
print(rewards3)
print(np.shape(rewards4))
print(rewards4)
![](https://img-blog.csdnimg.cn/501a9d7574f34437a498edc268bbada4.png)
We want an aggregated plot, but the x and y arguments of sns.lineplot each expect a one-dimensional vector. One workaround, although a bit clumsy, is to flatten everything to 1-D. When the same x value occurs more than once, sns.lineplot by default draws the mean and shades a 95% confidence interval around it:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set()
rewards1 = np.array([0, 0.1,0,0.2,0.4,0.5,0.6,0.9,0.9,0.9])
rewards2 = np.array([0, 0,0.1,0.4,0.5,0.5,0.55,0.8,0.9,1])
rewards=np.concatenate((rewards1,rewards2))
episode1=range(len(rewards1))
episode2=range(len(rewards2))
episode=np.concatenate((episode1,episode2))
sns.lineplot(x=episode,y=rewards)
plt.xlabel("episode")
plt.ylabel("reward")
plt.show()
![](https://img-blog.csdnimg.cn/2a374fc74c6a40fda5af24a5e9f2c84e.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
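For intuition, the aggregation can be reproduced by hand with NumPy and plain matplotlib. The sketch below shades one standard deviation around the mean, which is an assumption made for simplicity; seaborn's default band is a bootstrapped confidence interval, so the shaded areas will not match exactly.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

rewards1 = np.array([0, 0.1, 0, 0.2, 0.4, 0.5, 0.6, 0.9, 0.9, 0.9])
rewards2 = np.array([0, 0, 0.1, 0.4, 0.5, 0.5, 0.55, 0.8, 0.9, 1])
runs = np.vstack((rewards1, rewards2))  # shape (n_runs, n_episodes)
episodes = np.arange(runs.shape[1])
mean = runs.mean(axis=0)                # per-episode mean across runs
std = runs.std(axis=0)                  # per-episode population std across runs

fig, ax = plt.subplots()
ax.plot(episodes, mean, label="mean reward")
# Shade mean +/- one standard deviation (an assumption, see the note above)
ax.fill_between(episodes, mean - std, mean + std, alpha=0.3, label="+/- 1 std")
ax.set(xlabel="episode", ylabel="reward")
ax.legend()
```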
1.4 Passing data with pandas
All the examples above pass ndarrays. To pass pandas data instead, first convert the array to a DataFrame:
import numpy as np
import pandas as pd
rewards1 = np.array([0, 0.1,0,0.2,0.4,0.5,0.6,0.9,0.9,0.9])
rewards2 = np.array([0, 0,0.1,0.4,0.5,0.5,0.55,0.8,0.9,1])
rewards=np.vstack((rewards1,rewards2))
df = pd.DataFrame(rewards).melt(var_name='episode',value_name='reward')
print(df)
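With this conversion, the plotting code stays the same no matter how many runs (rows) rewards contains. The melt method stacks all columns into a single value column and records the original column label, here the episode index, in a second column; var_name='episode' and value_name='reward' set the names of those two columns. The converted result looks like this: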
![](https://img-blog.csdnimg.cn/b2d76f9f65de41f99552f332eed6f7b1.png)
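A smaller array makes the reshaping easier to follow. The values below are made up for illustration: two runs (rows) of three episodes (columns) become six rows in long form, and grouping by episode recovers the per-episode mean that sns.lineplot draws by default.

```python
import numpy as np
import pandas as pd

# Two runs (rows) of three episodes (columns); values are illustrative only
rewards = np.array([[0.0, 0.1, 0.2],
                    [0.0, 0.3, 0.4]])
wide = pd.DataFrame(rewards)
long = wide.melt(var_name="episode", value_name="reward")
print(long)
# Each episode now appears once per run: n_runs * n_episodes rows in total
per_episode_mean = long.groupby("episode")["reward"].mean()
print(per_episode_mean)
```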
The complete plotting program:
import seaborn as sns
sns.set()
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
rewards1 = np.array([0, 0.1,0,0.2,0.4,0.5,0.6,0.9,0.9,0.9])
rewards2 = np.array([0, 0,0.1,0.4,0.5,0.5,0.55,0.8,0.9,1])
rewards=np.vstack((rewards1,rewards2))
df = pd.DataFrame(rewards).melt(var_name='episode',value_name='reward')
sns.lineplot(x="episode", y="reward", data=df)
plt.show()
Here x and y are no longer arrays but the names of the corresponding DataFrame columns, much like keys in a Python dict. The result:
![](https://img-blog.csdnimg.cn/bbe9e358e4ac4a87880084b28a9158cc.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
1.5 A slightly more complex example
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
sns.set()
def get_data():
    '''Return the raw data: one array per algorithm, shaped (runs, episodes).'''
    basecond = np.array([[18, 20, 19, 18, 13, 4, 1],[20, 17, 12, 9, 3, 0, 0],[20, 20, 20, 12, 5, 3, 0]])
    cond1 = np.array([[18, 19, 18, 19, 20, 15, 14],[19, 20, 18, 16, 20, 15, 9],[19, 20, 20, 20, 17, 10, 0]])
    cond2 = np.array([[20, 20, 20, 20, 19, 17, 4],[20, 20, 20, 20, 20, 19, 7],[19, 20, 20, 19, 19, 15, 2]])
    cond3 = np.array([[20, 20, 20, 20, 19, 17, 12],[18, 20, 19, 18, 13, 4, 1], [20, 19, 18, 17, 13, 2, 0]])
    return basecond, cond1, cond2, cond3
data = get_data()
label = ['algo1', 'algo2', 'algo3', 'algo4']
df = []
for i in range(len(data)):
    # Convert each (runs, episodes) array to long form and tag it with its algorithm name
    df.append(pd.DataFrame(data[i]).melt(var_name='episode', value_name='loss'))
    df[i]['algo'] = label[i]
# Combine everything into one long DataFrame with a fresh, unique index
df = pd.concat(df, ignore_index=True)
print(df)
# hue and style both follow the algorithm, so each curve gets its own color and dash pattern
sns.lineplot(x="episode", y="loss", hue="algo", style="algo", data=df)
plt.title("some loss")
plt.show()
![](https://img-blog.csdnimg.cn/1094649cf8694066b2dc724423a0269a.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
![](https://img-blog.csdnimg.cn/efd0e8f4ca944cef84c0e42f86b4fc61.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_15,color_FFFFFF,t_70,g_se,x_16)
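The loop above can be wrapped in a small reusable helper. The function name runs_to_long and the dict-based input are assumptions made for this sketch, and the numbers are toy values; the only requirement is that each array is shaped (n_runs, n_episodes).

```python
import numpy as np
import pandas as pd

def runs_to_long(results, value_name="loss"):
    """Stack {algo_name: array of shape (n_runs, n_episodes)} into one long DataFrame."""
    frames = []
    for algo, runs in results.items():
        frame = pd.DataFrame(np.asarray(runs)).melt(var_name="episode", value_name=value_name)
        frame["algo"] = algo  # tag every row with the algorithm it came from
        frames.append(frame)
    return pd.concat(frames, ignore_index=True)

# Toy numbers for illustration only
results = {
    "algo1": [[18, 20, 19], [20, 17, 12]],
    "algo2": [[18, 19, 18], [19, 20, 18]],
}
df = runs_to_long(results)
print(df.groupby(["algo", "episode"])["loss"].mean())
```

The returned frame has the columns episode, loss, and algo, so it can be passed to sns.lineplot with hue="algo" exactly as in the example above.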
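2. Reading a CSV file and plotting
The data is a hotel booking dataset from Kaggle. Both the data and the code for this article can be downloaded from this link: https://www.jianguoyun.com/p/Ddc6RhEQnNm0CRjc2aAE
2.1 A simple example
Load the data: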
import pandas as pd
df=pd.read_csv('hotel_bookings.csv')
print(df.head())
![](https://img-blog.csdnimg.cn/fdcf1cd4a082413084792cd3a6a2c134.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
![](https://img-blog.csdnimg.cn/2d140b4be4154ddca708c003eedd2a4e.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
We focus on two columns: arrival_date_month, the month in which the guest arrived, and stays_in_week_nights, the length of the stay. Here sns.lineplot is called a little differently: x and y are column names of our data (x refers to df['arrival_date_month']), and the data parameter specifies the DataFrame in which those columns are looked up.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
df=pd.read_csv('hotel_bookings.csv')
sns.lineplot(x="arrival_date_month",y="stays_in_week_nights",data=df)
plt.show()
![](https://img-blog.csdnimg.cn/70c00d25df8c40e690fbc991e3086a55.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
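2.2 A more complex example
Now for a more complex example. We want to visualize stays over the months while also taking the arrival year into account, so the plot has to show three variables: month, year, and length of stay.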
import pandas as pd
df=pd.read_csv('hotel_bookings.csv')
df=df[['arrival_date_year','arrival_date_month','stays_in_week_nights']]
print(df)
![](https://img-blog.csdnimg.cn/247693a4e7cd42c8a555791c44d648fc.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
We use pivot_table, the pandas counterpart of a pivot table in Excel, to reshape the data. pivot_table takes the row labels from the column passed as index and the column labels from the column passed as columns; the column passed as values fills the cells, aggregated with the mean by default.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
df=pd.read_csv('hotel_bookings.csv')
df=df[['arrival_date_year','arrival_date_month','stays_in_week_nights']]
df_wide=df.pivot_table(index='arrival_date_month',columns='arrival_date_year',values='stays_in_week_nights')
print(df_wide)
sns.lineplot(data=df_wide)
plt.show()
![](https://img-blog.csdnimg.cn/c44aa3c62624493982700a002d7368d8.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
![](https://img-blog.csdnimg.cn/860c16f43d254883b64820228372e4f3.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
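The reshaping done by pivot_table is easier to check on a handful of rows. The frame below is synthetic (only the column names match the hotel data). Rows that share a month and year are averaged, and a combination that never occurs shows up as NaN.

```python
import pandas as pd

# Synthetic rows; only the column names match hotel_bookings.csv
df = pd.DataFrame({
    "arrival_date_year": [2015, 2015, 2016, 2016, 2016],
    "arrival_date_month": ["July", "July", "July", "August", "August"],
    "stays_in_week_nights": [1, 3, 2, 4, 6],
})
df_wide = df.pivot_table(index="arrival_date_month",
                         columns="arrival_date_year",
                         values="stays_in_week_nights")  # aggfunc defaults to the mean
print(df_wide)
```

In the printed table the months are the row labels, sorted alphabetically, and each year is a column, which is the wide-form layout that sns.lineplot(data=df_wide) draws as one line per column.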
We can also plot the months in the order in which they appear in the original CSV file; that is the purpose of the order variable in the code below.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
df=pd.read_csv('hotel_bookings.csv')
df=df[['arrival_date_year','arrival_date_month','stays_in_week_nights']]
order=df['arrival_date_month'].unique()  # each month once, in order of first appearance in the CSV
df_wide=df.pivot_table(index='arrival_date_month',columns='arrival_date_year',values='stays_in_week_nights')
df_wide=df_wide.reindex(order,axis=0)
print(df_wide)
sns.lineplot(data=df_wide)
plt.show()
![](https://img-blog.csdnimg.cn/0e000779b54e4647aafdbde51716fa38.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
![](https://img-blog.csdnimg.cn/40f0f1f6968348618c41af43d48d0dd5.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
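If calendar order is preferred over the order of appearance in the file, the same reindex call can be given an explicit month list. This is a variant of the approach above rather than what the original code does; the MONTHS list and the toy table are assumptions for the sketch, and the English month names are assumed to match the spelling used in the CSV.

```python
import pandas as pd

MONTHS = ["January", "February", "March", "April", "May", "June",
          "July", "August", "September", "October", "November", "December"]

# Toy wide table indexed by month name, in the alphabetical order pivot_table returns
df_wide = pd.DataFrame(
    {2015: [2.5, 2.0, 3.0], 2016: [2.7, 2.1, 3.2]},
    index=["August", "July", "September"],
)
# Keep only the months that are present, arranged in calendar order
calendar_order = [m for m in MONTHS if m in df_wide.index]
df_wide = df_wide.reindex(calendar_order)
print(df_wide)
```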
A more concise approach:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
df=pd.read_csv('hotel_bookings.csv')
sns.lineplot(x="arrival_date_month",y="stays_in_week_nights",hue="arrival_date_year",data=df)
plt.show()
![](https://img-blog.csdnimg.cn/fa46bd17da52429ca30114a58a997ef5.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5bCP5biF5ZCW,size_20,color_FFFFFF,t_70,g_se,x_16)
References:
https://zhuanlan.zhihu.com/p/147847062
https://www.guyuehome.com/36179