Tracer une matrice de confusion avec matplotlib

Publié le 19 février 2015

DMCA.com Protection Status

Exemple montrant comment tracer une matrice de confusion avec matplotlib :

Tracer une matrice de confusion avec matplotlib
Tracer une matrice de confusion avec matplotlib

# -*- coding: utf-8 -*-

# Source: http://azaleasays.com/2010/04/29/matplotlib-example-color-mesh/

from numpy import *

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import math

# Raise matplotlib's default font size (rc group 'font') for all text drawn below.
font = dict(size=16)
matplotlib.rc('font', **font)

def sqrt_sym(x):
    '''Stretch a colormap position symmetrically about its midpoint.

    The position ``x`` (normally in [0, 1]) is viewed as 2*x - 1 in [-1, 1].
    A sign-preserving square root is applied to it, and the result is mapped
    back to [0, 1]. The points 0, 0.5 and 1 are left unchanged. Positions
    near the midpoint are pushed outwards (e.g. 0.625 -> 0.75), which packs
    more colour variation toward both ends of the scale.
    '''
    if x > 0.5:
        # Upper half: sqrt of the distance above the midpoint.
        return (math.sqrt(2 * x - 1) + 1) / 2
    # Lower half (and the midpoint itself): mirrored square root.
    return (1 - math.sqrt(1 - 2 * x)) / 2

def cmap_xmap(function, cmap):
    '''Return a new colormap with ``function`` applied to ``cmap``'s anchor positions.

    Each ``(x, y0, y1)`` anchor of the red, green and blue channels of the
    linear-segmented colormap ``cmap`` becomes ``(function(x), y0, y1)``.
    ``function`` therefore moves where each colour sits along the [0, 1]
    axis without changing the colours themselves.

    ``function`` should map the [0, 1] segment to itself, in particular
    0 -> 0 and 1 -> 1. Adapted from the third-party SciPy Cookbook recipe:
    http://www.scipy.org/Cookbook/Matplotlib/ColormapTransformations

    Parameters
    ----------
    function : callable
        Scalar function applied to every anchor position ``x``.
    cmap : matplotlib.colors.LinearSegmentedColormap
        Source colormap. It is not modified.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        New colormap named ``'colormap'`` with 1024 entries.

    Raises
    ------
    ValueError
        If a remapped anchor position falls outside [0, 1].
    '''
    # Copy the dict so the source colormap (e.g. the shared plt.cm.jet) is
    # not mutated when its channels are replaced below.
    # NOTE(review): assumes each channel is a sequence of (x, y0, y1) anchors,
    # not a callable (some matplotlib colormaps use functions) -- confirm.
    cdict = dict(cmap._segmentdata)
    for key in ('red', 'green', 'blue'):
        # sorted() materialises the result; on Python 3 map() is a lazy
        # iterator with no .sort() method.
        remapped = sorted((function(x), y0, y1) for x, y0, y1 in cdict[key])
        # Check the x positions (first element of each anchor), not the
        # whole tuples; the condition must reject out-of-range indices.
        if remapped[0][0] < 0 or remapped[-1][0] > 1:
            raise ValueError(
                'Resulting indices extend out of the [0, 1] segment.')
        cdict[key] = remapped
    return matplotlib.colors.LinearSegmentedColormap('colormap', cdict, 1024)

def set_xtick(ax, labels=(u'CE 1', u'CE 2', u'CE 3')):
    '''Put one black, 45-degree-rotated label under each column of ``ax``.

    Ticks are placed at x = 0.5, 1.5, ..., the centres of the unit-wide
    columns drawn by ``pcolormesh`` for an array with ``len(labels)``
    columns.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to label. It does not need to be pyplot's current axes.
    labels : sequence of str, optional
        Column labels; defaults to the three ``'CE n'`` names.
    '''
    # Operate on ``ax`` itself rather than plt.xticks (which targets the
    # current axes), so the ``ax`` argument is actually honoured.
    ax.set_xticks(np.arange(len(labels)) + 0.5)
    ax.set_xticklabels(labels, rotation=45, color='k')

def set_ytick(ax, labels=(u'CR 1', u'CR 2', u'CR 3')):
    '''Put one black, unrotated label beside each row of ``ax``.

    Ticks are placed at y = 0.5, 1.5, ..., the centres of the unit-high
    rows drawn by ``pcolormesh`` for an array with ``len(labels)`` rows.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to label. It does not need to be pyplot's current axes.
    labels : sequence of str, optional
        Row labels; defaults to the three ``'CR n'`` names.
    '''
    # Operate on ``ax`` itself rather than plt.yticks (which targets the
    # current axes), so the ``ax`` argument is actually honoured.
    ax.set_yticks(np.arange(len(labels)) + 0.5)
    ax.set_yticklabels(labels, rotation=0, color='k')

def autolabel(arrayA):
    '''Write every cell's value, rounded to one decimal, onto the current axes.

    Cell ``(i, j)`` of the 2-D array ``arrayA`` is labelled at
    ``(j + 0.45, i + 0.45)``, close to the centre of the unit square that
    ``pcolormesh`` draws for it. Values strictly between 20 and 100 are
    written in black. All others (<= 20 or >= 100) are written in white.

    Labels are drawn with ``plt.text``, so they go on pyplot's current axes.
    '''
    n_rows, n_cols = np.shape(arrayA)
    for i in range(n_rows):
        for j in range(n_cols):
            value = arrayA[i, j]
            # Presumably chosen for legibility against the dark colours at
            # both ends of the 0-100 colour scale; TODO confirm thresholds.
            color = 'k' if 20 < value < 100 else 'w'
            plt.text(j + 0.45, i + 0.45, round(value, 1),
                     ha='center', va='bottom', color=color)

# Colormap: jet with its anchor positions pushed outwards by sqrt_sym.
mymap = cmap_xmap(sqrt_sym, plt.cm.jet)

# Raw counts in row-major order. Rows are drawn along y (labelled CR by
# set_ytick), columns along x (labelled CE by set_xtick).
plotData = [240.0, 54.0, 13.0,
            35.0, 320.0, 45.0,
            75.0, 74.0, 220.0]

# Turn every count into a percentage of the grand total (not per row).
Normalization = 100.0 / sum(plotData)
plotData = [count * Normalization for count in plotData]

plotArray = np.array(plotData).reshape(3, 3)

fig = plt.figure()
ax = fig.add_subplot(111)

# Pin the colour scale to 0-100 (%) rather than to the data's own range.
plt.pcolormesh(plotArray, cmap=mymap, vmin=0, vmax=100)
set_xtick(ax)
set_ytick(ax)
ax.set_ylim(0.0, 3.0)
autolabel(plotArray)

# Leave room on the bottom and left for the rotated tick labels.
fig.subplots_adjust(bottom=0.27, left=0.27)

plt.title('Titre')
plt.colorbar(orientation='vertical')

plt.savefig('ConfusionTable.png', bbox_inches='tight')

Recherches associées

Liens
Matrice de confusion (Wikipédia)
Confusion matrix (Wikipedia, en anglais)
Exemple matplotlib « color mesh » (azaleasays)
Image

of