Exemple montrant comment tracer une matrice de confusion avec matplotlib :

# -*- coding: utf-8 -*-
# Source: http://azaleasays.com/2010/04/29/matplotlib-example-color-mesh/
#
# Draws a 3x3 confusion matrix (expressed as percentages of the grand total)
# as a colour mesh, writes the value into each cell and saves the figure to
# 'ConfusionTable.png'.
import math

import matplotlib
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np

font = {'size': 16}
matplotlib.rc('font', **font)


def sqrt_sym(x):
    """Remap a colormap index with a square-root curve symmetric about 0.5.

    The points 0, 0.5 and 1 are mapped to themselves.  The original author
    describes the purpose as scaling the colormap "for better definition at
    both ends".

    x -- index in the [0, 1] segment.  Values outside that segment make
         ``math.sqrt`` raise ValueError.

    Returns the remapped index, also in [0, 1].
    """
    if x > 0.5:
        signed_root = math.sqrt(x * 2 - 1)
    else:
        signed_root = -math.sqrt(1 - x * 2)
    return (signed_root + 1) / 2


def cmap_xmap(function, cmap):
    """Return a new colormap whose anchor indices are ``function(index)``.

    Third-party function.  Source:
    http://www.scipy.org/Cookbook/Matplotlib/ColormapTransformations

    function -- callable applied to every anchor index of the 'red',
                'green' and 'blue' channels; it should map the [0, 1]
                segment to itself.
    cmap     -- colormap exposing list-style ``_segmentdata`` (a private
                matplotlib attribute; ``plt.cm.jet`` is what this script
                passes in).

    Raises ValueError when a remapped index falls outside [0, 1].
    """
    # Shallow copy: the colormap passed in is typically a shared global
    # (plt.cm.jet), so its own segment data must not be modified.
    cdict = dict(cmap._segmentdata)
    for key in ('red', 'green', 'blue'):
        # Build a real list: on Python 3 ``map`` returns an iterator, which
        # has no ``sort`` method.
        remapped = [(function(point[0]), point[1], point[2])
                    for point in cdict[key]]
        # Stable sort on the index only, so anchors sharing an index keep
        # their relative order.
        remapped.sort(key=lambda point: point[0])
        if remapped[0][0] < 0 or remapped[-1][0] > 1:
            raise ValueError(
                'Resulting indices extend out of the [0, 1] segment.')
        cdict[key] = remapped
    return matplotlib.colors.LinearSegmentedColormap('colormap', cdict, 1024)


def set_xtick(ax, labels=(u'CE 1', u'CE 2', u'CE 3')):
    """Put one black, 45-degree-rotated label at the centre of each column.

    ax     -- the Axes to decorate.
    labels -- one label per column; defaults to the three 'CE' labels.
    """
    ax.set_xticks(np.arange(len(labels)) + 0.5)
    ax.set_xticklabels(labels, rotation=45, color='k')


def set_ytick(ax, labels=(u'CR 1', u'CR 2', u'CR 3')):
    """Put one black, unrotated label at the centre of each row.

    ax     -- the Axes to decorate.
    labels -- one label per row; defaults to the three 'CR' labels.
    """
    ax.set_yticks(np.arange(len(labels)) + 0.5)
    ax.set_yticklabels(labels, rotation=0, color='k')


def autolabel(arrayA, ax=None):
    """Write each cell's value, rounded to one decimal, onto its square.

    The text is black when 20 < value < 100 and white otherwise
    (presumably so it stays readable on the darker colours at both ends of
    the colormap -- confirm if the colormap changes).

    arrayA -- 2-D numpy array of the plotted values.
    ax     -- Axes to draw on; defaults to the current Axes.
    """
    if ax is None:
        ax = plt.gca()
    for (row, col), value in np.ndenumerate(arrayA):
        text_color = 'k' if 20 < value < 100 else 'w'
        ax.text(col + 0.45, row + 0.45, '%.1f' % value,
                ha='center', va='bottom', color=text_color)


mymap = cmap_xmap(sqrt_sym, plt.cm.jet)

# Raw counts, listed row by row.
plotData = [240.0, 54.0, 13.0,
            35.0, 320.0, 45.0,
            75.0, 74.0, 220.0]
# Rescale so that the nine cells add up to 100.
Normalization = 100.0 / sum(plotData)
plotData = [count * Normalization for count in plotData]
plotArray = np.array(plotData).reshape(3, 3)

fig = plt.figure()
ax = fig.add_subplot(111)
mesh = ax.pcolormesh(plotArray, cmap=mymap, vmin=0, vmax=100)
set_xtick(ax)
set_ytick(ax)
ax.set_ylim(0.0, plotArray.shape[0])
autolabel(plotArray, ax)
# Leave room for the rotated tick labels.
fig.subplots_adjust(bottom=0.27, left=0.27)
ax.set_title('Titre')
fig.colorbar(mesh, ax=ax, orientation='vertical')
fig.savefig('ConfusionTable.png', bbox_inches='tight')
Recherches associées
| Liens | Site |
|---|---|
| Matrice de confusion | wikipedia |
| Confusion matrix | wikipedia |
| matplotlib color mesh | azaleasays |
