Python-matplotlib 条形统计图
效果图如下所示。该代码用分组条形图展示多个实验的结果，每个实验可包含多组观测值，并可选择绘制误差棒。代码如下：
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import MultipleLocator
def plot_bar(experiment_name, bar_name, bar_value, error_value=None, *,
             ylabel="Episode Cost", xlabel="Experiment",
             title="Cost Comparison", ylim=(0, 7.5), y_tick_step=2.5,
             show=True):
    """Draw a grouped bar chart: one group per experiment, one bar per series.

    Args:
        experiment_name: Sequence of x-axis group labels, one per experiment.
        bar_name: Sequence of legend labels, one per bar series; needs at
            least ``len(bar_value[0])`` entries.
        bar_value: Sequence of length ``len(experiment_name)``; element ``j``
            is an array-like holding the bar heights of experiment ``j``,
            one value per series.
        error_value: Optional sequence laid out like ``bar_value`` holding
            symmetric error-bar half-lengths. ``None`` or an empty sequence
            disables error bars.
        ylabel: Y-axis label.
        xlabel: X-axis label.
        title: Figure title (drawn with ``suptitle``).
        ylim: ``(bottom, top)`` y-axis limits, or ``None`` to autoscale.
        y_tick_step: Spacing of major y ticks, or ``None`` to keep
            matplotlib's default locator.
        show: Whether to call ``plt.show()`` before returning.

    Returns:
        The ``Axes`` the chart was drawn on.
    """
    colors = ['lightsteelblue', 'cornflowerblue', 'royalblue', 'blue',
              'mediumblue', 'darkblue', 'navy', 'midnightblue', 'lavender']
    n_experiments = len(experiment_name)
    n_series = len(bar_value[0]) if len(bar_value) > 0 else 0
    # Explicit None/len check: plain truthiness is ambiguous for numpy arrays.
    has_errors = error_value is not None and len(error_value) > 0

    plt.rcParams['axes.unicode_minus'] = False
    # The bare 'seaborn' style was renamed to 'seaborn-v0_8' in matplotlib 3.6
    # and removed in 3.8; prefer the new name, fall back for old versions.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    font = {'weight': 'normal', 'size': 20}
    font_title = {'weight': 'normal', 'size': 28}
    # Bars are 0.2 wide, shrunk when there are more than four series so a
    # group (80% of the unit spacing between ticks) never overlaps the next.
    width = min(0.2, 0.8 / n_series) if n_series else 0.2
    x_bar = np.arange(n_experiments)

    plt.figure(figsize=(10, 9))
    ax = plt.subplot(111)
    for i in range(n_series):
        heights = [bar_value[j][i] for j in range(n_experiments)]
        # Offsets are symmetric about each tick for any number of series.
        offset = (i - (n_series - 1) / 2) * width
        error_kwargs = {}
        if has_errors:
            error_kwargs = {
                'yerr': [error_value[j][i] for j in range(n_experiments)],
                'ecolor': 'black',
                'capsize': 2,
                'error_kw': {'elinewidth': 1.2, 'capthick': 1.2},
            }
        # Cycle the palette if there are more series than colours.
        ax.bar(x_bar + offset, heights, width=width,
               color=colors[i % len(colors)], label=bar_name[i],
               **error_kwargs)

    plt.xticks(fontsize=15)
    plt.yticks(fontsize=18)
    ax.set_xticks(x_bar)
    ax.set_xticklabels(experiment_name, fontdict=font)
    ax.set_ylabel(ylabel, fontdict=font_title)
    ax.set_xlabel(xlabel, fontdict=font_title)
    ax.grid(False)
    if ylim is not None:
        ax.set_ylim(*ylim)
    if y_tick_step is not None:
        ax.yaxis.set_major_locator(MultipleLocator(y_tick_step))
    plt.suptitle(title, fontsize=30, horizontalalignment='center')
    plt.subplots_adjust(left=0.11, bottom=0.1, right=0.95, top=0.93,
                        wspace=0.1, hspace=0.2)
    ax.spines['bottom'].set_linewidth(2.0)
    ax.legend(loc='upper left', frameon=True, fontsize=19.5)
    if show:
        plt.show()
    return ax
if __name__ == "__main__":
    # Demo data: four experiments, each with three observation series
    # (A, B, C) and a matching array of error-bar half-lengths.
    demo_experiments = ["Test 1", "Test 2", "Test 3", "Test 4"]
    demo_series = ['A', "B", "C"]
    demo_values = [
        np.array(row)
        for row in ([1, 2, 3], [4, 5, 6], [3, 2, 4], [5, 2, 2])
    ]
    demo_errors = [
        np.array(row)
        for row in ([1, 1, 2], [0.2, 0.6, 1], [0, 0, 0], [0.5, 0.2, 0.2])
    ]
    plot_bar(demo_experiments, demo_series, demo_values, demo_errors)
|