Lab 12 - Neural Networks challenge continuation: data augmentation
To keep things interesting, and to prevent good performance from being solely due to overfitting, your final evaluation will be done on a perturbed subset of FashionMNIST of size 10,000. Your scores on the unperturbed data are good, since you have the full dataset at your disposal, but performance drops on a perturbed subset of FashionMNIST. In today's lab you are expected to build on your previous submission by training the neural network to handle the perturbations in a way that generalizes.
Augmenting data
When we work with structured forms of data, such as images, we know the output of our neural network (for a classification problem for example) should be invariant and robust to certain perturbations of the data which we know do not change the label. For example, a rotation of 90 degrees of the figure of a shoe should not alter whether the image contains a shoe or not.
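One way to make the idea of invariance concrete is to measure how often a model's prediction stays the same when a label-preserving perturbation is applied. The sketch below is only an illustration: `prediction_agreement` is a hypothetical helper, and the linear model and random images are stand-ins for your own network and the FashionMNIST data.

```python
import torch
from torch import nn


def prediction_agreement(model, images, perturb):
    """Fraction of images whose predicted class is unchanged by `perturb`."""
    model.eval()
    with torch.no_grad():
        original = model(images).argmax(dim=1)
        perturbed = model(perturb(images)).argmax(dim=1)
    return (original == perturbed).float().mean().item()


# Hypothetical stand-in classifier and random inputs, for illustration only.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
toy_images = torch.rand(64, 1, 28, 28)


def rotate_90(x):
    # Rotate every image by 90 degrees in the height/width plane.
    return torch.rot90(x, k=1, dims=(2, 3))


agreement = prediction_agreement(toy_model, toy_images, rotate_90)
print(f"Predictions unchanged by a 90 degree rotation: {agreement:.2%}")
```

A perfectly rotation-invariant classifier would score 1.0 here; a value well below that indicates the model's output depends on the orientation of the input.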
As you probably saw from the updated scores, although your neural networks naturally perform well on the original FashionMNIST data, once reasonable perturbations are applied to the same images, your results become significantly worse.
From a machine learning perspective this is to be expected: the perturbations introduce a distribution shift that your model has not seen during training. Nonetheless, this is a form of robustness that humans usually have and neural networks typically lack. Broadly speaking, there are three ways of handling this:
1 - Pick the architecture so that it is robust by design (e.g. an architecture that is invariant to rotations of the input).
2 - Augment your training data so that it contains enough instances of the perturbations that the model is forced to learn how to handle them in order to generalize well.
3 - Train your neural network in an unsupervised manner on a massive dataset, and hope that it will be forced to learn high-level abstractions which are not directly encoded in the training data, but which it needs in order to complete the unsupervised task well.
1 is a kind of manual feature engineering, and empirical wisdom says it tends to work worse than 2 and 3. That is because, first, it is very hard to know a priori all the invariances that are needed, and second, directly encoding them in the architecture may harm training in unpredictable ways.
2 is the one we will pursue here, and is the standard alternative to 1. You enrich the data and let the neural network figure out how to encode the necessary invariances.
3 is what drives the current trend of LLMs and foundation models. Instead of training neural network \(i\) on dataset \(i\) for task \(i\), for \(i=1,2,...,k\), you lump all the datasets together, train one big neural network to do some unsupervised task on the combined data, and then fine-tune that network for the specific tasks. The idea is that the more data you feed in, even if apparently unrelated, the more robust and powerful the learned representations will be.
```python
import matplotlib.pyplot as plt
import numpy as np
import torch

# Load the perturbed training set and rescale pixel values to [0, 1]
data = torch.load("../../datasets/fashionmnist/perturbed_train_compressed.pt")
images = data['images'].float() / 255.0
labels = data['labels']

# FashionMNIST class names
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Visualize a 4x4 grid of randomly chosen perturbed images
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
indices = np.random.choice(len(images), 16, replace=False)
for i, idx in enumerate(indices):
    ax = axes[i // 4, i % 4]
    img = images[idx].squeeze().numpy()
    ax.imshow(img, cmap='gray')
    ax.set_title(classes[labels[idx].item()])
    ax.axis('off')
plt.tight_layout()
plt.show()
```
This week’s challenge
Use the augmented FashionMNIST data containing the perturbations shown above. Take your previous model and fine-tune it on the perturbed data so that it performs well on a random subset (of size 10,000) of the original data with an independently generated set of perturbations like the ones in the routine above.
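A fine-tuning loop could look roughly like the sketch below. It is not a reference solution: `fine_tune` is a hypothetical helper, the optimizer, learning rate and number of epochs are arbitrary starting points, and the small linear model and random tensors only stand in for your previous model and the perturbed data. It assumes images of shape `(N, 1, 28, 28)` and integer class labels; depending on how the labels are stored, you may need to call `.long()` on them first.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def fine_tune(model, images, labels, epochs=3, lr=1e-4, batch_size=128):
    """Continue training `model` on (images, labels); return mean loss per epoch."""
    loader = DataLoader(TensorDataset(images, labels),
                        batch_size=batch_size, shuffle=True)
    # A small learning rate is a common choice when starting from trained weights.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    history = []
    for _ in range(epochs):
        total, count = 0.0, 0
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            total += loss.item() * len(x)
            count += len(x)
        history.append(total / count)
    return history


# Stand-ins for illustration only: replace with your previous model and
# the tensors loaded from the perturbed training file.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
toy_images = torch.rand(256, 1, 28, 28)
toy_labels = torch.randint(0, 10, (256,))

history = fine_tune(toy_model, toy_images, toy_labels, epochs=2)
print(history)
```

Holding out part of the perturbed data as a validation set, rather than training on all of it, gives you an estimate of how the model will do on the independently perturbed evaluation subset.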
You may try to manually generate more augmented data using the perturbations in the perturbed training dataset as a template. But you should be able to get a reasonable score just using this augmented training data we are providing.
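If you do want to generate extra augmented data, a minimal sketch is given below. The exact perturbations used to build the provided dataset are not listed here, so the random horizontal flip, small shift and Gaussian noise in this example are assumptions chosen for illustration; inspect the perturbed images and adapt the transformations to what you actually see. `augment_batch` and all of its parameters are hypothetical names.

```python
import torch


def augment_batch(images, max_shift=2, noise_std=0.05, flip_prob=0.5,
                  generator=None):
    """Return a randomly perturbed copy of `images`.

    Assumes `images` has shape (N, 1, H, W) with float values in [0, 1].
    The labels are unchanged by these perturbations, so reuse them as-is.
    """
    out = images.clone()
    n = out.shape[0]

    # Random horizontal flip, decided independently for each image.
    flip = torch.rand(n, generator=generator) < flip_prob
    out[flip] = out[flip].flip(dims=(-1,))

    # Random shift of up to `max_shift` pixels along each axis.
    # torch.roll wraps pixels around the border, which is harmless for
    # small shifts because the FashionMNIST background is mostly black.
    shifts = torch.randint(-max_shift, max_shift + 1, (n, 2),
                           generator=generator)
    for i in range(n):
        dy, dx = shifts[i].tolist()
        out[i] = torch.roll(out[i], shifts=(dy, dx), dims=(-2, -1))

    # Additive Gaussian pixel noise, clipped back to the valid range.
    out = out + noise_std * torch.randn(out.shape, generator=generator)
    return out.clamp(0.0, 1.0)


# Random stand-in images, for illustration only.
torch.manual_seed(0)
toy_images = torch.rand(8, 1, 28, 28)
augmented = augment_batch(toy_images)
print(augmented.shape)
```

Applying such a function to each mini-batch during training, instead of once up front, means the network sees a fresh set of perturbations in every epoch.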