Beyond model metrics: confusion matrix and parity plots

There are many metrics for assessing the “quality” of a machine learning model, depending on whether one is dealing with a regression or a classification task. There are RMSE, MAPE, and R² for regression, for instance, and the ROC AUC score for classification.
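For completeness, here is a minimal sketch of how these numbers can be obtained with scikit-learn (the arrays below are made-up toy values, purely for illustration):

import numpy as np
from sklearn.metrics import (mean_squared_error, mean_absolute_percentage_error,
                             r2_score, roc_auc_score)

# Toy regression targets and predictions
y_true = np.array([3.0, 5.0, 2.5, 7.0])
y_pred = np.array([2.8, 5.3, 2.9, 6.5])

rmse = mean_squared_error(y_true, y_pred) ** 0.5
mape = mean_absolute_percentage_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)

# Toy classification labels and predicted probabilities
y_label = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
auc = roc_auc_score(y_label, y_score)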

However, I find it hard to believe that one can rely on such summary numbers alone. I have also faced the situation where a model looks great during development but not so great in practice. This is no surprise, of course: simple descriptive statistics are bound to miss features of the “shape” of the data. What we observe in practice is then closely analogous to Anscombe’s quartet: a single metric can suggest a performance that is completely different from the one the model actually has.

A way around this is to look at the complete picture: for classification problems, that means the confusion matrix. This is a one-liner in scikit-learn.
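A minimal sketch, assuming arrays of true and predicted class labels (the toy values below are made up):

from sklearn.metrics import confusion_matrix

# Toy labels for a binary classifier
y_true = [0, 1, 1, 0, 1]
y_pred = [0, 1, 0, 0, 1]

# Rows are true classes, columns are predicted classes
print(confusion_matrix(y_true, y_pred))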

The parity plot is its continuous analogue. It can be produced with:

import matplotlib.pyplot as plt

y_true_min, y_true_max = y_true.min(), y_true.max()
y_pred_min, y_pred_max = y_pred.min(), y_pred.max()
lb = min(y_true_min, y_pred_min)
ub = max(y_true_max, y_pred_max)

fig, ax = plt.subplots()
ax.scatter(y_true, y_pred)
ax.plot([lb, ub], [lb, ub], 'r-', lw=4)
ax.set_xlabel('True values')
ax.set_ylabel('Predicted')
plt.show()

Note that the somewhat odd calculation of the lower and upper bounds (lb and ub) comes from the fact that the predictions can be far off from the true values, especially during model development; the diagonal line should span the combined range of both, and it is then often worth zooming in on the bulk of the points.
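As a sketch of that zooming step, assuming the ax object from the snippet above, one can simply restrict the axes to the range of the true values:

# Zoom in on the range covered by the true values;
# predictions far outside this window are cut off from view
ax.set_xlim(y_true_min, y_true_max)
ax.set_ylim(y_true_min, y_true_max)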