Recognize Class Imbalance with Baselines and Better Metrics

Editor’s Note: Samuel is speaking at ODSC West 2019, see his talk “Help! My Classes are Imbalanced” there.

In my first machine learning course as an undergrad, I built a recommender system. Using a dataset from a social music website, I created a model to predict whether a given user would like a given artist. I was thrilled when initial experiments showed that my model predicted the correct rating for 99% of the points in my dataset – it was wrong only 1% of the time!

When I proudly shared the results with my professor, he revealed that I wasn’t, in fact, a machine learning prodigy. I’d made a mistake called the base rate fallacy. The dataset I used exhibited a high degree of class imbalance. In other words, for 99% of user-artist pairs, the user did not like the artist. This makes sense: there are many, many musicians in the world, and it’s unlikely that one person has even heard of half of them (let alone actually enjoys them).

When we’re unprepared for it, class imbalance introduces problems by producing misleading metrics. The undergrad version of me ran face-first into this problem: accuracy alone tells us almost nothing. A trivial model that predicts that no users like any artists can achieve 99% accuracy, but it’s completely worthless. Using accuracy as a metric assumes that all errors are equally costly; this is frequently not the case.
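To make the 99% figure concrete, here is a small sketch in plain Python. The labels are made up for illustration (990 "dislike" pairs and 10 "like" pairs); they are not the original music dataset.

```python
# Hypothetical labels: 1 = user likes the artist, 0 = user does not.
y_true = [0] * 990 + [1] * 10

# Trivial model: predict that no user likes any artist.
y_pred = [0] * len(y_true)

# Accuracy counts every correct prediction equally.
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Recall asks: of the pairs where the user really likes the artist,
# how many did the model find?
true_positives = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
recall = true_positives / sum(y_true)

print(f"accuracy = {accuracy:.2f}")
print(f"recall   = {recall:.2f}")
```

The accuracy looks excellent while the recall shows that the model never finds a single artist the user likes.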

Consider a medical example. If we incorrectly classify a tumor as malignant and request further screening, the cost of that error is a worry for the patient and time for the hospital workers. By contrast, if we incorrectly state that a tumor is benign when it is in fact malignant, the patient may die.

Examine the distribution of classes

Moving beyond accuracy, there are a number of metrics to think about in an imbalanced problem. Knowing the distribution of classes is the first line of defense. As a rule of thumb, Prati, Batista, and Silva find that class imbalance doesn’t significantly harm performance in cases where the minority class makes up 10% or more of the dataset. If the minority class in your dataset falls below that share, pay special attention.
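Checking the class distribution takes only a few lines. The sketch below uses made-up labels, and the helper name minority_fraction is hypothetical; the 10% cutoff is the rule of thumb mentioned above, not a hard limit.

```python
from collections import Counter


def minority_fraction(labels):
    """Return the share of the dataset taken up by the rarest class."""
    counts = Counter(labels)
    return min(counts.values()) / len(labels)


# Hypothetical target column: 990 "dislike" rows and 10 "like" rows.
labels = ["dislike"] * 990 + ["like"] * 10

frac = minority_fraction(labels)
print(Counter(labels))
print(f"minority class share: {frac:.1%}")

# Rule-of-thumb check: flag datasets whose minority class is under 10%.
if frac < 0.10:
    print("Minority class is below 10%; treat accuracy with suspicion.")
```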

I recommend starting with an incredibly simple model: always predict the most frequent class. scikit-learn implements this as DummyClassifier with strategy="most_frequent". Had I done this with my music recommendation project, I would quickly have noticed that my fancy model wasn’t really learning anything.
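A minimal sketch of such a baseline follows. The data is synthetic (random features, 1% positives) and stands in for a real dataset; only DummyClassifier and train_test_split come from scikit-learn, everything else is an assumption for illustration.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in data: 2000 rows, 5 random features, 1% positives.
rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 5))
y = np.zeros(2000, dtype=int)
y[:20] = 1

# Stratify so that both halves keep the same class proportions.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=0
)

# Baseline: ignore the features and always predict the majority class.
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X_train, y_train)

baseline_accuracy = baseline.score(X_test, y_test)
print(f"baseline accuracy: {baseline_accuracy:.3f}")
```

Any real model should be compared against this number: matching it means the model has learned nothing beyond the class frequencies.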

Evaluate the cost

In an ideal world, we could calculate the exact costs of a false negative and a false positive. When evaluating our models, we could multiply those costs by the false negative and false positive rates to come up with a number that describes the cost of our model. Unfortunately, these costs are often unknown in the real world, and the two kinds of error trade off against each other: lowering the false positive rate usually lowers the true positive rate as well.
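If the costs were known, the calculation might look like the sketch below. Every number here is invented: the costs (a false negative 50 times worse than a false positive) and the error counts for the two hypothetical models, each evaluated on the same 1,000 examples. It uses error counts rather than rates, which gives the same ranking when both models see the same data.

```python
def total_cost(n_false_pos, n_false_neg, cost_fp, cost_fn):
    """Total cost of a model's mistakes, given a cost per error type."""
    return n_false_pos * cost_fp + n_false_neg * cost_fn


# Hypothetical costs, in arbitrary units.
COST_FP = 1.0   # e.g. an unnecessary follow-up screening
COST_FN = 50.0  # e.g. a missed malignant tumor

# Hypothetical error counts on the same 1,000 examples.
# Model A makes 13 mistakes in total, model B makes 42.
cost_a = total_cost(n_false_pos=5, n_false_neg=8, cost_fp=COST_FP, cost_fn=COST_FN)
cost_b = total_cost(n_false_pos=40, n_false_neg=2, cost_fp=COST_FP, cost_fn=COST_FN)

print(f"model A: 13 errors, cost {cost_a}")
print(f"model B: 42 errors, cost {cost_b}")
```

Under these assumed costs, model B is the less accurate model but the cheaper one, which is exactly the situation accuracy hides.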

To visualize this tradeoff, we can use an ROC curve. Most classifiers can output a probability of membership in a certain class. If we choose a threshold (50%, for example), we can declare that all points with probability over the threshold are members of the positive class. Varying the threshold from a low percentage to a high percentage produces different ways of classifying points, each with its own true positive and false positive rates. Plotting the false positive rate on the x-axis and the true positive rate on the y-axis, we get an ROC curve.
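The threshold sweep described above can be sketched in plain Python. The labels and scores are invented, and roc_points is a hypothetical helper written for this illustration; in practice a library routine would do the same job.

```python
def roc_points(y_true, scores, thresholds):
    """Return one (false positive rate, true positive rate) pair per threshold."""
    positives = sum(y_true)
    negatives = len(y_true) - positives
    points = []
    for thr in thresholds:
        # A point is called positive when its score is over the threshold.
        preds = [s > thr for s in scores]
        tp = sum(p and t == 1 for p, t in zip(preds, y_true))
        fp = sum(p and t == 0 for p, t in zip(preds, y_true))
        points.append((fp / negatives, tp / positives))
    return points


# Hypothetical true labels and predicted probabilities for ten points.
y_true = [0, 0, 0, 0, 0, 0, 1, 0, 1, 1]
scores = [0.05, 0.1, 0.2, 0.3, 0.35, 0.4, 0.45, 0.6, 0.7, 0.9]

thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
points = roc_points(y_true, scores, thresholds)

for thr, (fpr, tpr) in zip(thresholds, points):
    print(f"threshold {thr:.2f}: FPR = {fpr:.2f}, TPR = {tpr:.2f}")
```

A low threshold labels everything positive (top right of the plot), a high threshold labels nothing positive (bottom left), and the thresholds in between trace out the curve.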

As an example, I trained a classifier on the yeast3 dataset from KEEL and created an ROC curve:

[Figure: ROC for LogisticRegression on yeast3]

While we could certainly write the code to draw an ROC curve ourselves, the yellowbrick library has this capability built in (and it’s compatible with scikit-learn models). These curves can suggest where to set the threshold for our model. Further, we can use the area under the curve to compare multiple models (though there are times when this isn’t a good metric).
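As a rough sketch of comparing models by area under the ROC curve, the example below uses scikit-learn's roc_auc_score rather than yellowbrick, and a synthetic imbalanced dataset from make_classification rather than yeast3. The dataset parameters are arbitrary choices for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic imbalanced data: roughly 5% of rows in the positive class.
X, y = make_classification(
    n_samples=5000,
    n_features=10,
    n_informative=5,
    n_clusters_per_class=1,
    weights=[0.95, 0.05],
    random_state=0,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0
)

models = {
    "most frequent baseline": DummyClassifier(strategy="most_frequent"),
    "logistic regression": LogisticRegression(max_iter=1000),
}

auc = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    # Score each test point with the predicted probability of the positive class.
    positive_scores = model.predict_proba(X_test)[:, 1]
    auc[name] = roc_auc_score(y_test, positive_scores)
    print(f"{name}: AUC = {auc[name]:.3f}")
```

The baseline gives every point the same score, so its AUC sits at 0.5 even though its accuracy is high; a model that has learned something should land noticeably above that.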

The next time you’re working on a machine learning problem, consider the distribution of the target variable. A huge first step towards solving class imbalance is recognizing the problem. By using better metrics and visualizations, we can start to talk about imbalanced problems much more clearly.

More on class imbalance

In my upcoming talk at ODSC West, I’ll dive deeper into the causes of class imbalance. I’ll also explore different ways to address it. I hope to see you in October!