This article walks through a simple example of fitting a Gaussian mixture model (GMM) with Python's scikit-learn module.
1: Create data for the example
import numpy as np
np.random.seed(1)
# 1000 points around 0 and 3000 points around 10, both with unit standard deviation
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))
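As a quick sanity check on the data above, the sketch below rebuilds the same array and prints its shape along with the empirical mean and standard deviation of each group. The names `left` and `right` are only illustrative; they are not part of the original script.

```python
import numpy as np

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))

# One row per observation, one column per feature: the layout scikit-learn expects
print(obs.shape)

# The first 1000 rows come from N(0, 1), the remaining 3000 from N(10, 1)
left, right = obs[:1000], obs[1000:]
print("left :", left.mean(), left.std())
print("right:", right.mean(), right.std())
```

The true mixing proportions are therefore 1000/4000 = 0.25 and 3000/4000 = 0.75, which is what the fitted weights in step 3 should approximate.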
2: Plot a histogram with matplotlib
import matplotlib.pyplot as plt
num_bins = 50
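# density=True normalizes the histogram to unit area (it replaces the removed normed argument)
n, bins, patches = plt.hist(obs.ravel(), num_bins, density=True, facecolor='green', alpha=0.5)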
plt.savefig("Data.png")
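The normalization matters because the fitted Gaussians are later drawn on top of this histogram. A minimal sketch, using `numpy.histogram` instead of `plt.hist` so that it runs without a display, shows that with `density=True` the bar heights times the bin widths sum to 1:

```python
import numpy as np

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))

# Same binning as the plot: 50 bins, heights normalized to a density
heights, edges = np.histogram(obs.ravel(), bins=50, density=True)

# Area under the histogram = sum of (height * bin width)
area = np.sum(heights * np.diff(edges))
print(area)
```

Since the mixture weights also sum to 1, the weighted component curves end up on the same vertical scale as the histogram.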
3: Use scikit-learn
from sklearn import mixture
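# mixture.GMM was removed from scikit-learn; GaussianMixture is its replacement.
# covariance_type='diag' keeps one variance per component, as the old GMM default did.
g = mixture.GaussianMixture(n_components=2, covariance_type='diag')
g.fit(obs)
weights = g.weights_
means = g.means_            # shape (n_components, n_features)
covars = g.covariances_     # shape (n_components, n_features) with 'diag'
print(round(weights[0], 2), round(weights[1], 2))
print(round(means[0, 0], 2), round(means[1, 0], 2))
print(round(covars[0, 0], 2), round(covars[1, 0], 2))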
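The order in which the fitted components come back is not guaranteed, so it can help to sort them by mean before comparing them with the generating parameters. The following is a sketch only: it assumes a scikit-learn version that provides `GaussianMixture`, and `random_state=0` and the name `gmm` are choices made for this illustration.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))

gmm = GaussianMixture(n_components=2, covariance_type='diag', random_state=0)
gmm.fit(obs)

# Sort components by increasing mean so the output order is stable
order = np.argsort(gmm.means_[:, 0])
weights = gmm.weights_[order]
means = gmm.means_[order, 0]
variances = gmm.covariances_[order, 0]

for w, m, v in zip(weights, means, variances):
    print(f"weight={w:.2f}  mean={m:.2f}  variance={v:.2f}")
```

Because the two groups are far apart, the estimates should land close to the values used to generate the data: weights near 0.25 and 0.75, means near 0 and 10, and variances near 1.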
4: Plot the Gaussians with matplotlib
import math
import scipy.stats as stats
D = obs.ravel()
xmin = D.min()
xmax = D.max()
x = np.linspace(xmin,xmax,1000)
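# First component, scaled by its mixing weight
mean = means[0, 0]
sigma = math.sqrt(covars[0, 0])
plt.plot(x, weights[0]*stats.norm.pdf(x, mean, sigma), c='red')
# Second component, scaled by its mixing weight
mean = means[1, 0]
sigma = math.sqrt(covars[1, 0])
plt.plot(x, weights[1]*stats.norm.pdf(x, mean, sigma), c='blue')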
plt.savefig("DataGMM.png")
plt.show()
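The two curves above are the individual weighted components. If the overall mixture density is wanted as well, one possible approach (assuming the `GaussianMixture` API, whose `score_samples` method returns the log density of each sample) is sketched below. It also checks that this density matches the sum of the weighted components computed by hand; the names `manual` and `total` are illustrative.

```python
import numpy as np
import scipy.stats as stats
from sklearn.mixture import GaussianMixture

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))
gmm = GaussianMixture(n_components=2, covariance_type='diag',
                      random_state=0).fit(obs)

x = np.linspace(obs.min(), obs.max(), 1000)

# Sum of the weighted component densities, as in the plotting code above
manual = np.zeros_like(x)
for w, m, v in zip(gmm.weights_, gmm.means_[:, 0], gmm.covariances_[:, 0]):
    manual += w * stats.norm.pdf(x, m, np.sqrt(v))

# score_samples expects a 2D array and returns log densities
total = np.exp(gmm.score_samples(x.reshape(-1, 1)))

print(np.allclose(manual, total))
```

Adding `plt.plot(x, total, c='black')` before `plt.savefig` would then overlay the full mixture on the histogram.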
Complete Python script:
import numpy as np
np.random.seed(1)
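# Note: unlike step 1, this script draws two overlapping components
# (standard deviation 5 around 0, and standard deviation 2 around 6).
obs = np.concatenate((5*np.random.randn(1000, 1),
                      6 + 2*np.random.randn(3000, 1)))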
#----------#
import matplotlib.pyplot as plt
num_bins = 50
n, bins, patches = plt.hist(obs.ravel(), num_bins, density=True, facecolor='green', alpha=0.5)
plt.savefig("Data.png")
#----------#
from sklearn import mixture
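# GaussianMixture replaces the removed mixture.GMM class
g = mixture.GaussianMixture(n_components=2, covariance_type='diag')
g.fit(obs)
weights = g.weights_
means = g.means_
covars = g.covariances_
print(round(weights[0], 2), round(weights[1], 2))
print(round(means[0, 0], 2), round(means[1, 0], 2))
print(round(covars[0, 0], 2), round(covars[1, 0], 2))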
#----------#
import math
import scipy.stats as stats
D = obs.ravel()
xmin = D.min()
xmax = D.max()
x = np.linspace(xmin,xmax,1000)
mean = means[0, 0]
sigma = math.sqrt(covars[0, 0])
plt.plot(x, weights[0]*stats.norm.pdf(x, mean, sigma), c='red')
mean = means[1, 0]
sigma = math.sqrt(covars[1, 0])
plt.plot(x, weights[1]*stats.norm.pdf(x, mean, sigma), c='blue')
plt.savefig("DataGMM.png")
plt.show()