Calculer l'information mutuelle avec python

Published: 09 février 2015

DMCA.com Protection Status

Exemple de calcul de l'information mutuelle avec Python (l'objectif ici est de développer un filtre pour sélectionner un ensemble de caractéristiques dans le cadre de l'apprentissage automatique ; voir : An Introduction to Variable and Feature Selection, Guyon et Elisseeff, 2003). Note : le module Python scikit-learn permet aussi de calculer l'information mutuelle (voir sklearn.metrics.mutual_info_score).

Exemple 1: Calculer l'information mutuelle avec python
Exemple 1: Calculer l'information mutuelle avec python

Exemple 2: Calculer l'information mutuelle avec python
Exemple 2: Calculer l'information mutuelle avec python

from random import gauss

import numpy as np
import matplotlib.pyplot as plt

import math

# RGBA colours of the class-1 and class-2 histogram bars.
c1_color = (0.69411766529083252, 0.3490196168422699, 0.15686275064945221, 1.0)
c2_color = (0.65098041296005249, 0.80784314870834351, 0.89019608497619629, 1.0)

#----- Data -----#
# Two classes of 5000 bivariate Gaussian samples (x1, x2) sharing one
# covariance matrix; only the mean of x1 differs (0 vs 20), so the marginal
# distribution of x1 depends on the class while that of x2 does not.
#
# A covariance matrix must be symmetric positive semi-definite; an asymmetric
# matrix makes multivariate_normal warn and draw from a different
# distribution than the one written down.
cov = [[10, 10], [10, 50]]

mean = [0, 0]
x1_c1, x2_c1 = np.random.multivariate_normal(mean, cov, 5000).T

mean = [20, 0]
x1_c2, x2_c2 = np.random.multivariate_normal(mean, cov, 5000).T

#----- Mutual Information Calculation for x1 -----#
# Plug-in estimate of the mutual information between feature x1 and the class
# label.  x1 is discretised into Nb_Bin equal-width bins spanning its observed
# range, and
#     MI = sum_{i,j} p(i, j) * log( p(i, j) / (p(i) * p(j)) )
# is accumulated over bins i and classes j.  np.log is the natural logarithm,
# so MI is expressed in nats.

dim_x1_c1 = x1_c1.shape[0]
dim_x1_c2 = x1_c2.shape[0]

Nb_Data = dim_x1_c1 + dim_x1_c2

print(x1_c1.shape)

# One row per sample: column 0 holds the x1 value, column 1 the class label
# (1.0 for class 1, 2.0 for class 2).
data = np.zeros((Nb_Data, 2))
data[:dim_x1_c1, 0] = x1_c1
data[:dim_x1_c1, 1] = 1.0
data[dim_x1_c1:, 0] = x1_c2
data[dim_x1_c1:, 1] = 2.0

print(data[:, 0].min())
print(data[:, 1].min())
min_x = data[:, 0].min()

print(data[:, 0].max())
print(data[:, 1].max())
max_x = data[:, 0].max()

# The bin count is fixed as an int.  Recovering it as
# (max_x - min_x) / Bin_Width gives a float that is not guaranteed to be
# exactly 40.0, and NumPy rejects float sizes and indices.
Nb_Bin = 40
Bin_Width = (max_x - min_x) / Nb_Bin
print(Nb_Bin)

if Bin_Width > 0.0:
    bin_index = ((data[:, 0] - min_x) / Bin_Width).astype(int)
else:
    # Constant feature: every sample falls into the first bin.
    bin_index = np.zeros(Nb_Data, dtype=int)
# The maximum maps to index Nb_Bin (possibly with round-off); fold it into
# the last bin so the top edge of the range is inclusive.
bin_index = np.clip(bin_index, 0, Nb_Bin - 1)
# Column of the joint table: 0 for class-1 samples, 1 for all others.
class_index = (data[:, 1] != 1.0).astype(int)

# Joint counts n(bin, class) in a single pass; the marginal counts are its
# row and column sums.
Discrete_Joint_Probability_x = np.bincount(
    bin_index * 2 + class_index, minlength=2 * Nb_Bin
).reshape(Nb_Bin, 2).astype(float)
Discrete_Probability_x = Discrete_Joint_Probability_x.sum(axis=1)
Discrete_Probability_c = Discrete_Joint_Probability_x.sum(axis=0)

Discrete_Probability_x /= Nb_Data
Discrete_Probability_c /= Nb_Data
Discrete_Joint_Probability_x /= Nb_Data

print(Discrete_Probability_x)
print(Discrete_Probability_c)

# Empty cells contribute nothing (0 * log 0 is taken as 0).  A non-empty cell
# always has non-zero marginals, so the ratio below is well defined.
Product_Of_Marginals_x = np.outer(Discrete_Probability_x, Discrete_Probability_c)
nonzero = Discrete_Joint_Probability_x > 0.0
MI = float(np.sum(
    Discrete_Joint_Probability_x[nonzero]
    * np.log(Discrete_Joint_Probability_x[nonzero] / Product_Of_Marginals_x[nonzero])
))

print('Mutual Information {}'.format(MI))

#----- Histogram 1 -----#
# Side-by-side normalised histograms of x1 for both classes, annotated with
# the current value of MI, and saved to Histogram_1.png.

f = plt.figure()
ax = f.add_subplot(111)

# Common range for both classes.  Named lo/hi rather than min/max so the
# built-in functions are not shadowed for the rest of the script.
x1_all = np.concatenate((x1_c1, x1_c2))
lo = x1_all.min()
hi = x1_all.max()

edges = np.linspace(lo, hi, 50)
width = 0.5 * (hi - lo) / 50
# `normed` was removed from np.histogram (NumPy 1.24); `density=True` is the
# equivalent normalisation (the histogram integrates to 1).
hist1 = np.histogram(x1_c1, bins=edges, density=True)
# align='edge' places each bar's left side at its x position (Matplotlib's
# default became 'center' in 2.0), so the two classes sit side by side.
plt.bar(edges[:-1], hist1[0], width=width, color=c1_color, label='Class 1', align='edge')
hist2 = np.histogram(x1_c2, bins=edges, density=True)
plt.bar(edges[:-1] + width, hist2[0], width=width, color=c2_color, label='Class 2', align='edge')

plt.xlabel('x1')
plt.legend()
plt.text(0.1, 0.7, 'I = ' + str(round(MI, 2)), transform=ax.transAxes)

plt.savefig('Histogram_1.png')
plt.close(f)

#----- Mutual Information Calculation for x2 -----#
# Plug-in estimate of the mutual information between feature x2 and the class
# label.  x2 is discretised into Nb_Bin equal-width bins spanning its observed
# range, and
#     MI = sum_{i,j} p(i, j) * log( p(i, j) / (p(i) * p(j)) )
# is accumulated over bins i and classes j.  np.log is the natural logarithm,
# so MI is expressed in nats.

dim_x2_c1 = x2_c1.shape[0]
dim_x2_c2 = x2_c2.shape[0]

Nb_Data = dim_x2_c1 + dim_x2_c2

# One row per sample: column 0 holds the x2 value, column 1 the class label
# (1.0 for class 1, 2.0 for class 2).
data = np.zeros((Nb_Data, 2))
data[:dim_x2_c1, 0] = x2_c1
data[:dim_x2_c1, 1] = 1.0
data[dim_x2_c1:, 0] = x2_c2
data[dim_x2_c1:, 1] = 2.0

print(data[:, 0].min())
print(data[:, 1].min())
min_x = data[:, 0].min()

print(data[:, 0].max())
print(data[:, 1].max())
max_x = data[:, 0].max()

# The bin count is fixed as an int.  Recovering it as
# (max_x - min_x) / Bin_Width gives a float that is not guaranteed to be
# exactly 40.0, and NumPy rejects float sizes and indices.
Nb_Bin = 40
Bin_Width = (max_x - min_x) / Nb_Bin
print(Nb_Bin)

if Bin_Width > 0.0:
    bin_index = ((data[:, 0] - min_x) / Bin_Width).astype(int)
else:
    # Constant feature: every sample falls into the first bin.
    bin_index = np.zeros(Nb_Data, dtype=int)
# The maximum maps to index Nb_Bin (possibly with round-off); fold it into
# the last bin so the top edge of the range is inclusive.
bin_index = np.clip(bin_index, 0, Nb_Bin - 1)
# Column of the joint table: 0 for class-1 samples, 1 for all others.
class_index = (data[:, 1] != 1.0).astype(int)

# Joint counts n(bin, class) in a single pass; the marginal counts are its
# row and column sums.
Discrete_Joint_Probability_x = np.bincount(
    bin_index * 2 + class_index, minlength=2 * Nb_Bin
).reshape(Nb_Bin, 2).astype(float)
Discrete_Probability_x = Discrete_Joint_Probability_x.sum(axis=1)
Discrete_Probability_c = Discrete_Joint_Probability_x.sum(axis=0)

Discrete_Probability_x /= Nb_Data
Discrete_Probability_c /= Nb_Data
Discrete_Joint_Probability_x /= Nb_Data

print(Discrete_Probability_x)
print(Discrete_Probability_c)

# Empty cells contribute nothing (0 * log 0 is taken as 0).  A non-empty cell
# always has non-zero marginals, so the ratio below is well defined.
Product_Of_Marginals_x = np.outer(Discrete_Probability_x, Discrete_Probability_c)
nonzero = Discrete_Joint_Probability_x > 0.0
MI = float(np.sum(
    Discrete_Joint_Probability_x[nonzero]
    * np.log(Discrete_Joint_Probability_x[nonzero] / Product_Of_Marginals_x[nonzero])
))

print('Mutual Information {}'.format(MI))

#----- Histogram 2 -----#
# Side-by-side normalised histograms of x2 for both classes, annotated with
# the current value of MI, and saved to Histogram_2.png.

f = plt.figure()
ax = f.add_subplot(111)

# Common range for both classes.  Named lo/hi rather than min/max so the
# built-in functions are not shadowed for the rest of the script.
x2_all = np.concatenate((x2_c1, x2_c2))
lo = x2_all.min()
hi = x2_all.max()

edges = np.linspace(lo, hi, 50)
width = 0.5 * (hi - lo) / 50
# `normed` was removed from np.histogram (NumPy 1.24); `density=True` is the
# equivalent normalisation (the histogram integrates to 1).
hist1 = np.histogram(x2_c1, bins=edges, density=True)
# align='edge' places each bar's left side at its x position (Matplotlib's
# default became 'center' in 2.0), so the two classes sit side by side.
plt.bar(edges[:-1], hist1[0], width=width, color=c1_color, label='Class 1', align='edge')
hist2 = np.histogram(x2_c2, bins=edges, density=True)
plt.bar(edges[:-1] + width, hist2[0], width=width, color=c2_color, label='Class 2', align='edge')

plt.xlabel('x2')
plt.legend()
plt.text(0.2, 0.7, 'I = ' + str(round(MI, 3)), transform=ax.transAxes)

plt.savefig('Histogram_2.png')
plt.close(f)

Recherches associées

Image

of