matplotlib绘制模型对比折线图

1 minute read

Published:

直接看代码和结果一目了然:

Code

''' chatgpt : draw figures using matplotlib '''
'''Pyplot tutorial: https://www.matplotlib.org.cn/tutorials/introductory/pyplot.html#working-with-multiple-figures-and-axes'''

import matplotlib.pyplot as plt


def plot_model_comparison(data_x, data_y, model_labels, colors, markers, linestyles, x_label, y_label, title):
    """
    绘制模型性能对比的折线图, x,y对应两种性能指标, 颜色对应'同类模型结构'的不同设置

    参数:
    - data_x: 二维数组,每行代表一个模型的性能数据
    - data_y: 二维数组,每行代表一个模型的性能数据
        data_x,data_y 每个元素对应的模型要相同
    - model_labels: 模型标签列表
    - colors : Hex 编码
    - markers :
                'o': 圆圈      '1': 下指针
                'v': 倒三角形  '2': 上指针
                '^': 正三角形  '3': 左指针
                '<': 左三角形  '4': 右指针
                '>': 右三角形  '+': 加号
                's': 正方形    'x': 叉号
                'p': 五角形    'D': 菱形
                '*': 星形      'd': 瘦菱形
                'h': 六边形1
                'H': 六边形2
    - linestyle :
                '-':实线
                '--':虚线
                '-.':点划线
                ':':点线
    - x_label: x 轴标签
    - y_label: y 轴标签
    - title: 图表标题
    """
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    for i in range(len(data_y)):
        plt.plot(data_x[i], data_y[i], color=colors[i], marker=markers[i], linestyle=linestyles[i], label=model_labels[i])

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.show()

    
    plt.savefig('fig.png', dpi=1200)
    print('fig is saved!')

# 示例用法
if __name__ == '__main__':
    
    data_x = [
        [4.1, 5.2, 6.3, 7.4],  # 例如:每行代表同类model在不同设置下的 params 
        [4.5, 6.2, 7.3, 8.4],  
        [6.1, 7.2, 8.3, 7.4],

    ]

    data_y = [
        [70., 75., 80., 90.],  # 例如:每行代表同类模型在不同设置下的 test_acc
        [73., 76., 82., 93.],
        [72., 78., 85., 95.],
    ]

    model_labels = ['Model 1', 'Model 2', 'Model 3'] # 模型名称 , 要跟data_x, data_y每一行对应
    colors = ['#DBC2CF', '#9FA2B2', '#3C7A89']  # 在 https://coolors.co/ 找喜欢的配色
    markers = ['o', 's', '^']  # 符号类型,每个模型对应一个符号
    linestyle = ['-', '--', '--'] # '-':实线(Ours), '--':虚线(others)

    x_label = 'X轴标签  X-label'   # 可以是 params
    y_label = 'Y轴标签 Y-label'    # 可以是 test_acc
    title = '模型性能对比 Performance comparison'

    plot_model_comparison(data_x, data_y, model_labels, colors, markers,linestyle, x_label, y_label, title)
    
    
    # 中文字符 https://blog.csdn.net/fwj_ntu/article/details/105598145
    # 下载 SimHei.ttf 放在 类似以下目录中 '/home/user/miniconda3/envs/xx_envs/lib/python3.9/site-packages/matplotlib/mpl-data/fonts/ttf/SimHei.ttf'
    # 修改 '/home/user/miniconda3/envs/xx_envs/lib/python3.9/site-packages/matplotlib/mpl-data/matplotlibrc 文件
    #    #font.family:  sans-serif     -->  取消注释
    #    font.sans-serif: SimHei, ...  -->  在这一行加入SimHei
    # 删除缓存文件 /home/user/.cache/matplotlib/fontlist-xx.json
    # 再次运行 检测是否成功

Output: Alt text