
Finding Hard Samples in Your Image Classification Dataset

Mining hard samples is an efficient way to improve a machine learning dataset, and this guide will show you how to do it yourself



Eric Hofesmann


Say you have a repository of data with millions of unlabeled images in it. You managed to label a subset of data and trained an image classification model on it, but it’s not performing as well as you hope. How do you decide which new samples to annotate and add to your training set?

You could just randomly select new samples to annotate, but there is a better way. Hard sample mining is a tried-and-true method to distill a large amount of raw unlabeled data into smaller, high-quality labeled datasets.

A hard sample is one where your machine learning (ML) model finds it difficult to correctly predict the label.

In an image classification dataset, a hard sample could be anything from a cat that looks like a dog to a blurry, low-resolution image. If you expect your model to perform well on these hard samples, then you may need to “mine” more examples of them to add to your training dataset.

Exposing your model to more hard samples during training will allow it to perform better on those types of samples later on.

Hard samples are useful for more than just training data; they are also necessary to include in your test set. If your test data is composed primarily of easy samples, then your performance will soon reach an upper bound, causing progress to stagnate. Adding hard samples to a test set will give you a better idea of how models perform on harder edge cases and can provide more insight into which models are more reliable.

Follow along in your browser with this Colab notebook!

This walkthrough running in Colab (ipynb link here) (Image by author)

Overview

This post will walk you through how to use the new open-source ML tool that I have been working on, FiftyOne, to find hard samples in your dataset. In the spirit of making this walkthrough easy to follow, we will use an existing model and dataset: namely, a pretrained ResNet50 to find hard samples in the test split of CIFAR-10.

It will walk you through how to:

  • Load your image dataset into FiftyOne
  • Add logits from a pretrained network to your dataset
  • Use FiftyOne to calculate the hardness of every sample
  • Explore your dataset and find the hardest and easiest samples

Setup

This walkthrough will be using PyTorch and FiftyOne as well as a model from PyTorch_CIFAR10 on GitHub. The install instructions for PyTorch and FiftyOne are simple:

pip install torch torchvision
pip install fiftyone

Load your data

For this example, we will be using the test split of the image classification dataset, CIFAR-10. This dataset contains 10,000 test images labeled across 10 different classes. This is one of the dozens of datasets in the FiftyOne Dataset Zoo, so we can easily load it up.

import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("cifar10", split="test")

We can use the FiftyOne App to take a look at this dataset.

import fiftyone as fo

session = fo.launch_app(dataset)

CIFAR-10 and ground truth labels visualized in the FiftyOne App (Image by author)

Note: You can also load your own dataset into FiftyOne. It supports labels for many computer vision tasks including classification, detection, segmentation, keypoints, and more. For example, if your dataset contains images stored in per-class directories, you can use the following code to load it.

import fiftyone as fo

dataset = fo.Dataset.from_dir(
    "/path/to/dir",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)
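Here, the directory is expected to contain one subfolder per class, along these lines (a hypothetical layout):

/path/to/dir/
    <class1>/
        <image1>.jpg
        <image2>.jpg
        ...
    <class2>/
        ...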

Add logits

In order to calculate hardness on images in FiftyOne, you first need to use a model to compute logits for those images. You can use any model you want, but ideally, it would be one trained on similar data and on the same task you will be using these new images for.

In this example, we will be using code from the PyTorch CIFAR-10 repository, namely the pretrained ResNet50 classifier.

# Download the software
git clone --depth 1 --branch v2.1 https://github.com/huyvnphan/PyTorch_CIFAR10.git

# Download the pretrained model (90MB)
eta gdrive download --public \
1dGfpeFK_QG0kV-U6QDHMX2EOGXPqaNzu \
PyTorch_CIFAR10/cifar10_models/state_dicts/resnet50.pt

You can easily add a classification field with logits to your samples in a FiftyOne dataset.

import sys

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader

import fiftyone as fo
import fiftyone.utils.torch as fout

sys.path.insert(1, "PyTorch_CIFAR10")
from cifar10_models import resnet50


# Set up a data loader in accordance with PyTorch_CIFAR10
def make_cifar10_data_loader(image_paths, sample_ids, batch_size):
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean, std),
        ]
    )
    dataset = fout.TorchImageDataset(
        image_paths, sample_ids=sample_ids, transform=transforms
    )
    return DataLoader(dataset, batch_size=batch_size, num_workers=4)


# Run inference on the model to generate predictions and logits
def predict(model, imgs):
    logits = model(imgs).detach().cpu().numpy()
    predictions = np.argmax(logits, axis=1)
    odds = np.exp(logits)
    confidences = np.max(odds, axis=1) / np.sum(odds, axis=1)
    return predictions, confidences, logits


#
# Load a pretrained model
#
# Model performance numbers are available at:
# https://github.com/huyvnphan/PyTorch_CIFAR10
#
model = resnet50(pretrained=True)
model_name = "resnet50"

# If using the entire 10,000-sample test split
view = dataset

#
# NOTE: If you want this to run faster, uncomment the lines below to
# select a random subset of 1,000 samples
#
# num_samples = 1000
# view = dataset.take(num_samples, seed=51)

batch_size = 20

# Get the list of classes from the dataset information
classes = dataset.info["classes"]

image_paths, sample_ids = zip(
    *[(s.filepath, s.id) for s in view.select_fields(["filepath", "id"])]
)

# Create a PyTorch data loader
data_loader = make_cifar10_data_loader(image_paths, sample_ids, batch_size)

#
# Perform prediction and store results in dataset
#
for imgs, sample_ids in data_loader:
    predictions, confidences, logits_ = predict(model, imgs)

    # Add predictions to your FiftyOne dataset
    for sample_id, prediction, confidence, logits in zip(
        sample_ids, predictions, confidences, logits_
    ):
        sample = dataset[sample_id]
        sample.tags.append("processed")
        sample[model_name] = fo.Classification(
            label=classes[prediction],
            logits=logits,
            confidence=confidence,
        )
        sample.save()

processed_view = dataset.match_tags(["processed"])

Compute hardness

The FiftyOne Brain contains various useful methods that can provide insights into your data. At the moment, you can compute the uniqueness of your data, the hardest samples, as well as annotation mistakes. These are all different ways to generate scalar metrics on your dataset that will let you better understand the quality of existing data as well as help you select high-quality new samples of data.

Once you have loaded your dataset and added logits to your samples, you can calculate hardness in one line of code. The hardness algorithm is closed-source, but the basic idea is to leverage the relative uncertainty of the model’s predictions to assign a scalar hardness value to each sample.
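
As a sketch of that one line, assuming the logits were stored in the "resnet50" field as above (the FiftyOne Brain writes the result to a "hardness" field on each sample by default):

import fiftyone.brain as fob

# Compute a scalar hardness value for each processed sample from its logits
fob.compute_hardness(processed_view, "resnet50")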

Explore and identify the hardest samples

You can visualize your dataset and explore the samples with the highest and lowest hardness scores with the FiftyOne App.
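
For example, setting the App's view to a sorted view (a sketch using FiftyOne's sort_by stage) shows the hardest samples first:

# Show the hardest samples first in the App
session.view = processed_view.sort_by("hardness", reverse=True)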

Dataset sorted to show the hardest samples (Image by author)

While this example is using small images from CIFAR-10, FiftyOne also works with high-resolution images and videos.

We can write some queries to dig a bit deeper into these hardness calculations and how they relate to other aspects of the data. For example, we can see the distribution of hardness on correct and incorrect predictions of the model separately.
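
One way to produce this comparison is sketched below, assuming matplotlib for plotting and the "resnet50" prediction field added earlier; the correct_view defined here is also reused in the classwise snippet further down.

import matplotlib.pyplot as plt
from fiftyone import ViewField as F

# Split the processed samples into correct and incorrect predictions
correct_view = processed_view.match(F("resnet50.label") == F("ground_truth.label"))
incorrect_view = processed_view.match(F("resnet50.label") != F("ground_truth.label"))

# Plot the hardness distributions of the two views
plt.hist(correct_view.values("hardness"), bins=50, alpha=0.5, label="correct")
plt.hist(incorrect_view.values("hardness"), bins=50, alpha=0.5, label="incorrect")
plt.xlabel("Hardness")
plt.legend()
plt.show()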

Distribution of hardness for correctly and incorrectly predicted samples (Image by author)

As you might expect, the figure above shows that the distribution of hardness for correct predictions skews towards lower hardness values while incorrect predictions are spread more evenly at high hardness values. This indicates that samples that the model predicts incorrectly tend to be harder samples. Thus, adding harder samples to the training set should improve model performance.

We can also see how the hardness of samples is distributed across different classes.

from fiftyone import ViewField as F

cls_hardness = []

for label in processed_view.distinct("ground_truth.label"):
    label_view = processed_view.match(F("ground_truth.label") == label)
    avg_hardness = label_view.sum("hardness") / label_view.count()

    # correct_view contains the correctly predicted samples (defined above)
    num_correct = correct_view.match(F("ground_truth.label") == label).count()
    accuracy = num_correct / label_view.count()

    cls_hardness.append([avg_hardness, label, accuracy])

print("Average classwise hardness\n")
for avg_hardness, label, _ in sorted(cls_hardness, reverse=True):
    print("%s: %f" % (label, avg_hardness))


Average classwise hardness

cat: 0.703082
dog: 0.628436
airplane: 0.591202
bird: 0.588827
frog: 0.577954
truck: 0.573330
horse: 0.564832
deer: 0.561707
automobile: 0.554695
ship: 0.553041

It seems that cat and dog tend to be the hardest classes, so it would be worthwhile to add more examples of these before other classes.

import matplotlib.pyplot as plt

avg_hardness = [i[0] for i in cls_hardness]
acc = [i[2] for i in cls_hardness]
labels = [i[1] for i in cls_hardness]

plt.scatter(avg_hardness, acc)
plt.xlabel("Hardness")
plt.ylabel("Accuracy")

for i, label in enumerate(labels):
    plt.annotate(label, (avg_hardness[i] + .002, acc[i]))

plt.show()

Classwise accuracy versus hardness (Image by author)

We can see that there is an anti-correlation between the average hardness of the samples in a class and the accuracy of the model on that class.

Let's take a look at the incorrectly predicted samples of the hardest class, “cat”.
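
A sketch of that query, reusing the incorrect_view and F defined earlier:

# Hardest incorrectly predicted cat images
session.view = (
    incorrect_view
    .match(F("ground_truth.label") == "cat")
    .sort_by("hardness", reverse=True)
)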

Hardest incorrectly predicted samples of the hardest category “cat” (Image by author)

Now let’s take a look at the correctly predicted images of cats with the lowest hardness.
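
And the corresponding query for the easiest correct predictions (sorting by hardness in ascending order):

# Least hard correctly predicted cat images
session.view = (
    correct_view
    .match(F("ground_truth.label") == "cat")
    .sort_by("hardness")
)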

Least hard correctly predicted samples of the hardest category “cat” (Image by author)

Comparing the hardest incorrectly predicted cat images with the easiest correctly predicted cat images, we can see that the model has a much easier time classifying images of cat faces looking directly at the camera.

The images of cats that the model struggles the most with are ones of cats in poor lighting, complex backgrounds, and poses where they are not sitting and facing the camera. Now we have an idea of the types of cat images to look for to add to this dataset.

What’s next?

This example was done on a previously annotated set of data in order to show how hardness relates to other aspects of a dataset. In a real-world application, you would now apply this method to new unlabeled data.

Once you’ve identified the hardest samples you have available, it’s time to update your dataset. You can select the X samples with the highest hardness value to send off to get annotated and added to your train or test set. Alternatively, you could select samples proportionally to the per-class hardness calculated above.
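
For instance, a minimal sketch of grabbing the filepaths of the 1,000 hardest samples to hand off for annotation (the count of 1,000 is arbitrary):

# Collect the filepaths of the 1,000 hardest samples to send for annotation
hardest_view = processed_view.sort_by("hardness", reverse=True).limit(1000)
to_annotate = hardest_view.values("filepath")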

Retraining your model on this new data should now allow it to perform better on harder edge cases (I’m working on a follow-up blog post to show this). Additionally, adding these samples to your test set will let you be more confident in the ability of your model to perform well on new unseen data if it performs well on your test set.



Created by

Eric Hofesmann

Machine Learning Engineer at Voxel51 developing tools to help computer vision engineers refine their datasets and improve model performance. https://fiftyone.ai/

