In this article, we look at a simple example of how to define a Gaussian mixture model (GMM) using Python's scikit-learn library.
1: Create the example data
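For reference, the model being fitted is the standard Gaussian mixture density. In one dimension with $K$ components, it reads:

$$
p(x) = \sum_{k=1}^{K} w_k \, \mathcal{N}(x \mid \mu_k, \sigma_k^2),
\qquad
\mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right),
\qquad
w_k \ge 0,\quad \sum_{k=1}^{K} w_k = 1.
$$

Fitting the model means estimating the weights $w_k$, the means $\mu_k$ and the variances $\sigma_k^2$ from the observed data.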
```python
import numpy as np

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))
```
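scikit-learn estimators expect a 2-D array of shape `(n_samples, n_features)`, which is why the data above is drawn with `randn(1000, 1)` rather than `randn(1000)`. The sketch below checks the shape and the true proportions of the two groups (1000 draws around 0, 3000 around 10). If your data is a flat 1-D array, `reshape(-1, 1)` gives the expected shape. The names `true_weights`, `flat` and `column` are illustrative only.

```python
import numpy as np

np.random.seed(1)
# Same construction as above: 1000 draws around 0 and 3000 draws around 10
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))

# One row per sample, one column per feature
print(obs.shape)  # (4000, 1)

# True mixing proportions of the two groups
true_weights = np.array([1000, 3000]) / len(obs)
print(true_weights)  # [0.25 0.75]

# Sample means of each group: close to 0 and 10
print(obs[:1000].mean(), obs[1000:].mean())

# A flat 1-D array must become a single column before fitting
flat = obs.ravel()             # shape (4000,)
column = flat.reshape(-1, 1)   # back to shape (4000, 1)
print(column.shape)
```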
2: Plot a histogram with matplotlib
```python
import matplotlib.pyplot as plt

num_bins = 50
# density=True (formerly normed=1) scales the bars so the total area is 1
n, bins, patches = plt.hist(obs, num_bins, density=True, facecolor='green', alpha=0.5)
plt.savefig("Data.png")
```
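A density-normalized histogram has bar heights scaled so that the total area is 1, which is what makes it comparable with the probability density curves drawn in step 4. `numpy.histogram` with `density=True` uses the same normalization and lets you check it numerically. This is a small sketch on the step 1 data; the variable names are illustrative.

```python
import numpy as np

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))

# 50 equal-width bins, as in the plt.hist call
heights, edges = np.histogram(obs, bins=50, density=True)
widths = np.diff(edges)

# Each height is count / (total count * bin width), so the bar areas sum to 1
area = np.sum(heights * widths)
print(area)  # 1.0 up to floating-point rounding

# Rebuild the first bar by hand from the raw counts
counts, _ = np.histogram(obs, bins=50)
print(counts[0] / (counts.sum() * widths[0]), heights[0])
```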
3: Use scikit-learn
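```python
from sklearn import mixture

# GaussianMixture replaces the old mixture.GMM class (removed in scikit-learn 0.20)
g = mixture.GaussianMixture(n_components=2)
g.fit(obs)
weights = g.weights_
means = g.means_.ravel()         # shape (2,)
covars = g.covariances_.ravel()  # variances, shape (2,) for 1-D data
print(round(weights[0], 2), round(weights[1], 2))
print(round(means[0], 2), round(means[1], 2))
print(round(covars[0], 2), round(covars[1], 2))
```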
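The `fit` call estimates the weights, means and variances with the expectation-maximization (EM) algorithm. To show what happens inside, here is a minimal one-dimensional EM sketch in plain NumPy. It is not scikit-learn's implementation: the initialization (means at the data minimum and maximum, equal weights, the global variance), the fixed number of iterations and the function name `em_gmm_1d` are assumptions made for illustration.

```python
import numpy as np

def em_gmm_1d(x, n_iter=100):
    """Fit a two-component 1-D Gaussian mixture with plain EM (illustrative sketch)."""
    x = np.asarray(x, dtype=float).ravel()
    # Assumed initialization: equal weights, means at the extremes, global variance
    w = np.array([0.5, 0.5])
    mu = np.array([x.min(), x.max()])
    var = np.array([x.var(), x.var()])

    for _ in range(n_iter):
        # E-step: responsibility of each component for each point, shape (n, 2)
        dens = np.exp(-(x[:, None] - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)
        weighted = w * dens
        resp = weighted / weighted.sum(axis=1, keepdims=True)

        # M-step: re-estimate the parameters from the soft assignments
        nk = resp.sum(axis=0)
        w = nk / len(x)
        mu = (resp * x[:, None]).sum(axis=0) / nk
        var = (resp * (x[:, None] - mu) ** 2).sum(axis=0) / nk

    return w, mu, var

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))
w, mu, var = em_gmm_1d(obs)
print(np.round(w, 2), np.round(mu, 2), np.round(var, 2))
```

On the step 1 data, the estimates should land near the true values: weights close to 0.25 and 0.75, means close to 0 and 10, and variances close to 1.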
4: Plot the Gaussians with matplotlib
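```python
import numpy as np
import scipy.stats as stats

D = obs.ravel()
xmin = D.min()
xmax = D.max()
x = np.linspace(xmin, xmax, 1000)

# First component, scaled by its weight so it sits on the density histogram
mean = means[0]
sigma = np.sqrt(covars[0])  # standard deviation = square root of the variance
plt.plot(x, weights[0] * stats.norm.pdf(x, mean, sigma), c='red')

# Second component
mean = means[1]
sigma = np.sqrt(covars[1])
plt.plot(x, weights[1] * stats.norm.pdf(x, mean, sigma), c='blue')

plt.savefig("DataGMM.png")
plt.show()
```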
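Each curve above is one weighted component; their sum is the full mixture density, which is what the histogram should match. `GaussianMixture.score_samples` returns the log of that density, so exponentiating it gives the total curve. The sketch below assumes scikit-learn 0.20 or later and checks that this total equals the sum of the two weighted components; `random_state=0` is added only to make the run repeatable.

```python
import numpy as np
import scipy.stats as stats
from sklearn.mixture import GaussianMixture

np.random.seed(1)
obs = np.concatenate((np.random.randn(1000, 1),
                      10 + np.random.randn(3000, 1)))

g = GaussianMixture(n_components=2, random_state=0).fit(obs)
weights = g.weights_
means = g.means_.ravel()
sigmas = np.sqrt(g.covariances_.ravel())

x = np.linspace(obs.min(), obs.max(), 1000)

# Total density from the model: score_samples returns log p(x) and expects a 2-D input
total = np.exp(g.score_samples(x.reshape(-1, 1)))

# Same quantity rebuilt from the individual weighted components
by_hand = sum(w * stats.norm.pdf(x, m, s) for w, m, s in zip(weights, means, sigmas))
print(np.allclose(total, by_hand))
```

To draw the total over the histogram, `plt.plot(x, total, c='black')` can be added before saving the figure.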
Complete Python script:


```python
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
from sklearn import mixture

#----------#
np.random.seed(1)
obs = np.concatenate((5 * np.random.randn(1000, 1),
                      6 + 2 * np.random.randn(3000, 1)))

#----------#
# Histogram normalized so that its total area is 1
num_bins = 50
n, bins, patches = plt.hist(obs, num_bins, density=True, facecolor='green', alpha=0.5)
plt.savefig("Data.png")

#----------#
# Fit a two-component Gaussian mixture
g = mixture.GaussianMixture(n_components=2)
g.fit(obs)
weights = g.weights_
means = g.means_.ravel()
covars = g.covariances_.ravel()
print(round(weights[0], 2), round(weights[1], 2))
print(round(means[0], 2), round(means[1], 2))
print(round(covars[0], 2), round(covars[1], 2))

#----------#
# Plot each weighted component over the histogram
D = obs.ravel()
x = np.linspace(D.min(), D.max(), 1000)
mean = means[0]
sigma = np.sqrt(covars[0])
plt.plot(x, weights[0] * stats.norm.pdf(x, mean, sigma), c='red')
mean = means[1]
sigma = np.sqrt(covars[1])
plt.plot(x, weights[1] * stats.norm.pdf(x, mean, sigma), c='blue')
plt.savefig("DataGMM.png")
plt.show()
```
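Besides its parameters, a fitted mixture can assign each observation to a component. With the overlapping data of the complete script, the soft assignments are more informative than hard labels: `predict` returns the most probable component for each point, and `predict_proba` returns the posterior probability of each component. This sketch assumes scikit-learn 0.20 or later; `random_state=0` is added only for repeatability, and the component order it finds is arbitrary.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

np.random.seed(1)
obs = np.concatenate((5 * np.random.randn(1000, 1),
                      6 + 2 * np.random.randn(3000, 1)))
g = GaussianMixture(n_components=2, random_state=0).fit(obs)

# Hard assignment: index of the most probable component for each point
labels = g.predict(obs)
# Soft assignment: posterior probability of each component (each row sums to 1)
proba = g.predict_proba(obs)

print(np.bincount(labels))      # number of points assigned to each component
print(np.round(proba[:3], 3))   # posteriors of the first three points
```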
Related resources
| Link | Site |
|---|---|
| scikit-learn | scikit-learn main page |
| Gaussian mixture model (Modèle de mélanges gaussiens) | Wikipedia |
| sklearn.mixture.GMM | scikit-learn |
| Gaussian mixtures (Les mélanges gaussiens) | bioinfo-fr |
| How can I plot the probability density function for a fitted Gaussian mixture model under scikit-learn? | Stack Overflow |
| Gaussian Mixture Models (GMM) and the K-Means Algorithm | cse |
| Separate mixture of gaussians in Python | Stack Overflow |
| numpy.histogram | SciPy docs |
| Plot Normal distribution with Matplotlib | Stack Overflow |
| python pylab plot normal distribution | Stack Overflow |
| Plotting of 1-dimensional Gaussian distribution function | Stack Overflow |
| scipy.stats.norm | SciPy docs |
