Comment calculer une erreur quadratique moyenne en python ?

Exemple de comment calculer l’erreur quadratique moyenne en python dans le cas d'un modèle de régression linéaire simple:

\begin{equation}
y = \theta_1 x + \theta_0
\end{equation}

Tracer les données

Générons un ensemble de données aléatoirment suivant:

\begin{equation}
y = 3x + 2
\end{equation}

import matplotlib.pyplot as plt
import numpy as np

X = 4 * np.random.rand(1000,1)
X_b = np.c_[np.ones((1000,1)), X]

Y = 2 + 3 * X + np.random.randn(1000,1)

plt.plot(X,Y,'.')

plt.xlim(0,4)
plt.ylim(0,15)

plt.xlabel(r'x',fontsize=8)
plt.ylabel(r'y',fontsize=8)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.savefig("mean_squared_error_01.png", bbox_inches='tight')

Comment calculer une erreur quadratique moyenne en python ?
Comment calculer une erreur quadratique moyenne en python ?

Modèle linéaire

Considérons le modèle linéaire suivant

\begin{equation}
y = \theta_1 x + \theta_0
\end{equation}

avec $\theta_0=-1.4$ et $\theta_1=5.0$

#----- Let's take one random linear model

theta = np.array([[-1.4],[5.0]])

X_new = np.array([[0],[4]])
X_new_b = np.c_[np.ones((2,1)), X_new]

plt.plot(X_new, X_new_b.dot( theta ), '-')

plt.xlim(0,4)
plt.ylim(0,15)

plt.xlabel(r'x',fontsize=8)
plt.ylabel(r'y',fontsize=8)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.savefig("mean_squared_error_02.png", bbox_inches='tight')

plt.close()

Comment calculer une erreur quadratique moyenne en python ?
Comment calculer une erreur quadratique moyenne en python ?

Calculer l'erreur quadratique moyenne

On peut alors calculer l'erreur quadratique moyenne en utilisant python:

\begin{equation}
mse = \frac{1}{m} \sum_{i=1}^{m}(\theta^T.\textbf{x}^{(i)}-y^{(i)})^2
\end{equation}

Y_predict = X_b.dot( theta )

print(Y_predict.shape, X_b.shape, theta.shape)

mse = np.sum( (Y_predict-Y)**2 ) / 1000.0

print('mse: ', mse)

On peut aussi utiliser le module python sklearn:

from sklearn.metrics import mean_squared_error

print('mse (sklearn): ', mean_squared_error(Y,Y_predict))

donne par exemple:

mse:  6.75308540424
mse (sklearn):  6.75308540424

Calculer l'erreur quadratique moyenne pour un ensemble de modèles

Exemple de comment calculer l'erreur quadratique moyenne pour une grille de $\theta_0$ et $\theta_1$

#----- Calculate the mse using a grid search

theta_0, theta_1 = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))

theta = np.vstack((theta_0.ravel(), theta_1.ravel()))

Y_predict = X_b @ theta

mse = np.sum( (Y_predict-Y)**2, axis=0 ) / 1000.0

mse = mse.reshape(100,100)

from matplotlib.colors import LogNorm
from pylab import figure, cm

plt.imshow(mse, origin='lower', norm=LogNorm(), extent=[0,10,0,10], cmap=cm.jet)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.xlabel(r'$\theta_0$',fontsize=8)
plt.ylabel(r'$\theta_1$',fontsize=8)

plt.savefig("mean_squared_error_03.png", bbox_inches='tight')

#plt.show()

plt.close()

On constate bien que l'erreur quadratique moyenne minimum est obtenue pour un modèle linéaire avec $\theta_0$ et $\theta_1$ autour de 2 et 3 respectivement.

Comment calculer une erreur quadratique moyenne en python ?
Comment calculer une erreur quadratique moyenne en python ?

On peut aussi tracer l'erreur quadratique moyenne en fonction de $\theta_1$ uniquement pour un $\theta_0$ fixé:

plt.plot(mse[:,20])

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.xlabel(r'$\theta_1$',fontsize=8)
plt.ylabel(r'mean square error',fontsize=8)

positions = [i*10 for i in range(10)]
labels = [i for i in range(10)]

plt.xticks(positions, labels)

plt.grid(linestyle='--')

plt.savefig("mean_squared_error_04.png", bbox_inches='tight')

#plt.show()

Comment calculer une erreur quadratique moyenne en python ?
Comment calculer une erreur quadratique moyenne en python ?

Code source

import matplotlib.pyplot as plt
import numpy as np

X = 4 * np.random.rand(1000,1)
X_b = np.c_[np.ones((1000,1)), X]

Y = 2 + 3 * X + np.random.randn(1000,1)

plt.plot(X,Y,'.')

plt.xlim(0,4)
plt.ylim(0,15)

plt.xlabel(r'x',fontsize=8)
plt.ylabel(r'y',fontsize=8)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.savefig("mean_squared_error_01.png", bbox_inches='tight')

#----- Let's take one random linear model

theta = np.array([[-1.4],[5.0]])

X_new = np.array([[0],[4]])
X_new_b = np.c_[np.ones((2,1)), X_new]

plt.plot(X_new, X_new_b.dot( theta ), '-')

plt.xlim(0,4)
plt.ylim(0,15)

plt.xlabel(r'x',fontsize=8)
plt.ylabel(r'y',fontsize=8)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.savefig("mean_squared_error_02.png", bbox_inches='tight')

plt.close()

#----- using python

Y_predict = X_b.dot( theta )

print(Y_predict.shape, X_b.shape, theta.shape)

mse = np.sum( (Y_predict-Y)**2 ) / 1000.0

print('mse: ', mse)

#----- using sklearn

from sklearn.metrics import mean_squared_error

print('mse (sklearn): ', mean_squared_error(Y,Y_predict))

#----- Calculate the mse using a grid search

theta_0, theta_1 = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))

theta = np.vstack((theta_0.ravel(), theta_1.ravel()))

Y_predict = X_b @ theta

mse = np.sum( (Y_predict-Y)**2, axis=0 ) / 1000.0

mse = mse.reshape(100,100)

from matplotlib.colors import LogNorm
from pylab import figure, cm

plt.imshow(mse, origin='lower', norm=LogNorm(), extent=[0,10,0,10], cmap=cm.jet)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.xlabel(r'$\theta_0$',fontsize=8)
plt.ylabel(r'$\theta_1$',fontsize=8)

plt.savefig("mean_squared_error_03.png", bbox_inches='tight')

#plt.show()

plt.close()

#----- plot theta_1 for a given theta_0

plt.plot(mse[:,20])

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.xlabel(r'$\theta_1$',fontsize=8)
plt.ylabel(r'mean square error',fontsize=8)

positions = [i*10 for i in range(10)]
labels = [i for i in range(10)]

plt.xticks(positions, labels)

plt.grid(linestyle='--')

plt.savefig("mean_squared_error_04.png", bbox_inches='tight')

#plt.show()

Références

Image

of