Modèle de mélanges gaussiens (GMM 1d) avec python

Published: 11 mars 2015

DMCA.com Protection Status

Dans cet article on va voir un simple exemple sur comment définir un modèle de mélanges gaussiens (ou GMM pour Gaussian Mixture Model) en utilisant le module scikit de python.

1: Créer des données pour l'exemple

import numpy as np

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1), \
                      10+np.random.randn(3000, 1)))

2: Tracer histogramme avec matplotilb

import matplotlib.pyplot as plt

num_bins = 50
n, bins, patches = plt.hist(obs, num_bins, normed=1, facecolor='green', alpha=0.5)
plt.savefig("Data.png")

3: Utiliser scikit learn

from sklearn import mixture

g = mixture.GMM(n_components=2)

g.fit(obs)

weights = g.weights_
means = g.means_
covars = g.covars_


print round(weights[0],2), round(weights[1],2)
print round(means[0],2), round(means[1],2)
print round(covars[0],2), round(covars[1],2)

4: Tracer les gaussiennnes avec matplotlib

import math
import scipy.stats as stats

D = obs.ravel()
xmin = D.min()
xmax = D.max()
x = np.linspace(xmin,xmax,1000)

mean = means[0]
sigma = math.sqrt(covars[0])
plt.plot(x,weights[0]*stats.norm.pdf(x,mean,sigma), c='red')

mean = means[1]
sigma = math.sqrt(covars[1])
plt.plot(x,weights[1]*stats.norm.pdf(x,mean,sigma), c='blue')

plt.savefig("DataGMM.png")
plt.show()

Script python complet:

Modèle de mélanges gaussiens (GMM 1d) avec python (Data.png)
Modèle de mélanges gaussiens (GMM 1d) avec python (Data.png)

Modèle de mélanges gaussiens (GMM 1d) avec python (DataGMM.png)
Modèle de mélanges gaussiens (GMM 1d) avec python (DataGMM.png)

import numpy as np

np.random.seed(1)
obs = np.concatenate((5*np.random.randn(1000, 1), \
                      6+2*np.random.randn(3000, 1)))

#----------#

import matplotlib.pyplot as plt

num_bins = 50
n, bins, patches = plt.hist(obs, num_bins, normed=1, facecolor='green', alpha=0.5)
plt.savefig("Data.png")

#----------#

from sklearn import mixture

g = mixture.GMM(n_components=2)

g.fit(obs)

weights = g.weights_
means = g.means_
covars = g.covars_


print round(weights[0],2), round(weights[1],2)
print round(means[0],2), round(means[1],2)
print round(covars[0],2), round(covars[1],2)

#----------#

import math
import scipy.stats as stats

D = obs.ravel()
xmin = D.min()
xmax = D.max()
x = np.linspace(xmin,xmax,1000)

mean = means[0]
sigma = math.sqrt(covars[0])
plt.plot(x,weights[0]*stats.norm.pdf(x,mean,sigma), c='red')

mean = means[1]
sigma = math.sqrt(covars[1])
plt.plot(x,weights[1]*stats.norm.pdf(x,mean,sigma), c='blue')

plt.savefig("DataGMM.png")
plt.show()

Recherches associées

Image

of