In deep learning, transfer learning is the process of using a pre-trained model for a new problem.

In computer vision, we use CNN embeddings to apply transfer learning to a new problem:

  • Train a CNN on a large image dataset (someone else does this!).
  • Remove classification layers and freeze existing convolutional weights.
  • Add new layers at the end of the model suitable for our new task.

Why does this work? Because deep CNNs trained on large, diverse datasets learn a low-dimensional, broadly reusable set of visual features. All we need to do from there is retrain the classifier for our specific task. The benefit is faster convergence: instead of starting from a random point (with respect to the CNN output), we start from a good one.

The embedding encodes everything needed from the image to classify objects.

Using pre-trained CNN weights, we can output a tensor representation (an embedding) of a given image's features. Since a project-specific representation usually isn't necessary, reusing these weights saves valuable time we might otherwise have spent training.

In code

Software packages often ship pre-trained models. For instance, PyTorch's torchvision library provides several.

import torchvision

# Load AlexNet with weights pre-trained on ImageNet.
# (Newer torchvision releases deprecate `pretrained` in favor of the
#  `weights=` argument, e.g. alexnet(weights="IMAGENET1K_V1").)
alexnet = torchvision.models.alexnet(pretrained=True)