Exemple : comment tracer une matrice de confusion avec matplotlib.
# -*- coding: utf-8 -*-
# Source: http://azaleasays.com/2010/04/29/matplotlib-example-color-mesh/
from numpy import *
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import math
# Default text size, in points, for every tick label, cell label and title
# drawn by the script below.
font = dict(size=16)
matplotlib.rc('font', **font)
def sqrt_sym(x):
    """Remap x with a square root that is symmetric about 0.5.

    Used below as the index-remapping function passed to cmap_xmap, to
    reshape the colormap (author's stated aim: better definition at both
    ends).  The map sends [0, 1] onto itself with f(0) = 0, f(0.5) = 0.5 and
    f(1) = 1, and is steepest around 0.5.
    """
    if x <= 0.5:
        # Lower half: mirror of the upper branch, giving values in [-1, 0].
        offset = -math.sqrt(1 - x * 2)
    else:
        # Upper half: square root of the distance above the midpoint.
        offset = math.sqrt(x * 2 - 1)
    # Shift/scale the [-1, 1] offset back into [0, 1].
    return (offset + 1) / 2
def cmap_xmap(function, cmap):
    """Return a new colormap whose colour anchors are moved through function.

    Every ``(x, y0, y1)`` anchor of the 'red', 'green' and 'blue' channels of
    ``cmap._segmentdata`` becomes ``(function(x), y0, y1)``.  The anchors are
    re-sorted, and a 1024-entry LinearSegmentedColormap named 'colormap' is
    built from them.  ``function`` should map the [0, 1] segment onto itself;
    LinearSegmentedColormap additionally requires the first and last anchors
    to sit exactly at 0 and 1.

    Adapted from the SciPy Cookbook recipe:
    http://www.scipy.org/Cookbook/Matplotlib/ColormapTransformations

    Args:
        function: scalar map applied to every anchor position.
        cmap: source colormap.  Only its ``_segmentdata`` is read, and it is
            left unmodified.  NOTE(review): assumes each channel entry is a
            sequence of (x, y0, y1) triples (true for jet), not a callable as
            in some built-in maps — confirm before using other colormaps.

    Returns:
        A new matplotlib.colors.LinearSegmentedColormap.

    Raises:
        ValueError: if a remapped anchor position falls outside [0, 1].
    """
    # Shallow copy: the original rebound entries of cmap._segmentdata in
    # place, silently altering the shared source colormap (e.g. plt.cm.jet).
    cdict = dict(cmap._segmentdata)
    for key in ('red', 'green', 'blue'):
        # sorted() both materialises the remapped anchors (list.sort() on a
        # lazy map object fails on Python 3) and restores increasing x order.
        remapped = sorted((function(x), y0, y1) for x, y0, y1 in cdict[key])
        # Check the x coordinate of the extreme anchors.  The original
        # assertion compared whole tuples with ints and had its test inverted.
        if remapped[0][0] < 0 or remapped[-1][0] > 1:
            raise ValueError(
                'Resulting indices extend out of the [0, 1] segment.')
        cdict[key] = remapped
    return matplotlib.colors.LinearSegmentedColormap('colormap', cdict, 1024)
def set_xtick(ax, labels=('CE 1', 'CE 2', 'CE 3')):
    """Label the columns of the mesh drawn on ax.

    One tick is placed at the centre of each unit-wide column (0.5, 1.5, ...),
    with black labels rotated by 45 degrees.

    Args:
        ax: the matplotlib Axes holding the mesh.  It is configured directly
            rather than through pyplot's implicit "current axes".
        labels: one label per column.  The default reproduces the original
            three 'CE' labels.
    """
    positions = np.arange(len(labels)) + 0.5
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, color='k')
def set_ytick(ax, labels=('CR 1', 'CR 2', 'CR 3')):
    """Label the rows of the mesh drawn on ax.

    One tick is placed at the centre of each unit-high row (0.5, 1.5, ...),
    with black, unrotated labels.

    Args:
        ax: the matplotlib Axes holding the mesh.  It is configured directly
            rather than through pyplot's implicit "current axes".
        labels: one label per row.  The default reproduces the original
            three 'CR' labels.
    """
    positions = np.arange(len(labels)) + 0.5
    ax.set_yticks(positions)
    ax.set_yticklabels(labels, rotation=0, color='k')
def autolabel(arrayA, ax=None):
    """Write each cell's value, rounded to one decimal, on top of the mesh.

    pcolormesh without explicit X/Y draws ``arrayA[i, j]`` in the unit cell
    spanning x in [j, j+1] and y in [i, i+1].  The label is anchored near that
    cell's centre.  The text is black when 20 < value < 100 and white
    otherwise.  Presumably this keeps it readable against the colormap's
    light middle and dark ends — TODO confirm.

    Args:
        arrayA: 2-D numpy array of values (any shape, not just 3x3).
        ax: Axes to draw on.  Defaults to the current pyplot axes, which is
            what the original implementation used.
    """
    if ax is None:
        ax = plt.gca()
    n_rows, n_cols = np.shape(arrayA)
    for i in range(n_rows):
        for j in range(n_cols):
            value = arrayA[i, j]
            color = 'k' if 20 < value < 100 else 'w'
            # The 0.45 offsets (with bottom alignment) are kept from the
            # original layout.
            ax.text(j + 0.45, i + 0.45, round(value, 1),
                    ha='center', va='bottom', color=color)
# Colormap for the mesh: jet with its anchor positions pushed through sqrt_sym.
mymap = cmap_xmap(sqrt_sym, plt.cm.jet)

# Raw 3x3 counts in row-major order.  Rows end up labelled CR 1..3 (set_ytick)
# and columns CE 1..3 (set_xtick).
plotData = [240.0, 54.0, 13.0,
            35.0, 320.0, 45.0,
            75.0, 74.0, 220.0]
plotArray = np.array(plotData).reshape(3, 3)
# Express every cell as a percentage of the grand total, so that the fixed
# 0-100 colour range (vmin/vmax below) applies.
plotArray *= 100.0 / plotArray.sum()

fig, ax = plt.subplots()
# Draw on ax explicitly instead of relying on pyplot's implicit current axes.
mesh = ax.pcolormesh(plotArray, cmap=mymap, vmin=0, vmax=100)
set_xtick(ax)
set_ytick(ax)
ax.set_ylim(0.0, plotArray.shape[0])
autolabel(plotArray)  # draws on the current pyplot axes, which is ax here
fig.subplots_adjust(bottom=0.27, left=0.27)
ax.set_title('Titre')
# Pass the mesh explicitly rather than relying on pyplot's "current image".
fig.colorbar(mesh, ax=ax, orientation='vertical')
fig.savefig('ConfusionTable.png', bbox_inches='tight')
Recherches associées
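| Liens | Site |
|---|---|
| [Matrice de confusion](https://fr.wikipedia.org/wiki/Matrice_de_confusion) | Wikipédia |
| [Confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix) | Wikipedia |
| [matplotlib color mesh](http://azaleasays.com/2010/04/29/matplotlib-example-color-mesh/) | azaleasays |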