Pre-trained models are easy to use, but are you glossing over details that could impact your model performance?
How many times have you run the following snippets:
import torchvision.models as models
inception = models.inception_v3(pretrained=True)
from keras.applications.inception_v3 import InceptionV3
base_model = InceptionV3(weights='imagenet', include_top=False)
It seems like using these pre-trained models has become the new industry standard. After all, why wouldn’t you take advantage of a model that’s been trained on more data and compute than you could ever muster by yourself?
There are several substantial benefits to leveraging pre-trained models:
- super simple to incorporate
- achieve solid (same or even better) model performance quickly
- less labeled data is required
- versatile use cases, from transfer learning to prediction and feature extraction
One common technique for leveraging pretrained models is feature extraction, where you retrieve intermediate representations produced by the pretrained model and use those representations as inputs for a new model. The activations of the final fully-connected layers are generally assumed to capture information that is relevant for solving a new task.
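To make the idea concrete without depending on any framework, here is a minimal sketch of feature extraction. It assumes a hypothetical "backbone" (a fixed linear map plus ReLU) standing in for a real pre-trained network with its classifier removed, and a hypothetical toy task. Only the new logistic-regression head is trained; the backbone weights are never touched.

```python
import math
import random

# Hypothetical stand-in for a pre-trained backbone: a fixed linear map followed
# by ReLU. In practice this would be a real network with its classifier removed.
FROZEN_WEIGHTS = [[1.0, -1.0], [0.5, 0.5], [-1.0, 1.0]]


def extract_features(x):
    """Frozen backbone: map a 2-d input to 3 features. These weights are never updated."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in FROZEN_WEIGHTS]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_head(data, epochs=200, lr=0.1):
    """Fit a new logistic-regression head on the frozen features with plain SGD."""
    feats = [(extract_features(x), y) for x, y in data]  # features computed once
    w = [0.0] * len(FROZEN_WEIGHTS)
    b = 0.0
    for _ in range(epochs):
        for f, y in feats:
            p = sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b)
            err = p - y  # gradient of the log-loss w.r.t. the logit
            w = [wi - lr * err * fi for wi, fi in zip(w, f)]
            b -= lr * err
    return w, b


def predict(w, b, x):
    logit = sum(wi * fi for wi, fi in zip(w, extract_features(x))) + b
    return 1 if logit >= 0.0 else 0


def make_toy_data(n, seed):
    """Hypothetical toy task: the label is 1 when x[0] > x[1]."""
    rng = random.Random(seed)
    data = []
    while len(data) < n:
        x = [rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)]
        if abs(x[0] - x[1]) < 0.1:
            continue  # keep a margin so the toy task is cleanly separable
        data.append((x, 1 if x[0] > x[1] else 0))
    return data


if __name__ == "__main__":
    w, b = train_head(make_toy_data(100, seed=0))
    test = make_toy_data(50, seed=1)
    accuracy = sum(predict(w, b, x) == y for x, y in test) / len(test)
    print(f"head weights={w}, accuracy={accuracy:.2f}")
```

Because the features are computed once and reused, training the head is cheap; that is the main practical appeal of feature extraction over full fine-tuning.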
Everyone’s in on the game
Every major framework, such as TensorFlow, Keras, PyTorch, and MXNet, offers pre-trained models like Inception V3, ResNet, and AlexNet, along with their weights.
The article that inspired this post came from Curtis Northcutt, a computer science PhD candidate at MIT.
architectures perform better in PyTorch and Inception architectures perform better in Keras
You might be wondering:
How is that possible?
and some interesting insights into the reason for these differences:
Knowing (and trusting) these benchmarks is important because they allow you to make informed decisions about which framework to use, and they are often used as baselines for research and implementation.
So what are some things to look out for when you’re leveraging these pre-trained models?
1. How similar is your task? How similar is your data?
Are you expecting the cited 94.5% (top-5) validation accuracy for the Keras Xception model on your new dataset of X-rays? First, you need to check how similar your data is to the original dataset that the model was trained on (in this case, ImageNet).
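One crude, low-level first check is to compare your dataset's per-channel pixel statistics with the ImageNet statistics commonly used for normalization. This is only a sketch under stated assumptions: pixels are hypothetical (r, g, b) tuples scaled to [0, 1], and matching statistics say nothing about whether the image content is semantically similar. It can, however, flag obvious mismatches such as grayscale X-rays.

```python
import statistics

# ImageNet per-channel statistics (RGB, pixel values scaled to [0, 1]),
# as commonly used for normalization, e.g. by torchvision.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def channel_stats(pixels):
    """Per-channel (mean, population std) over a list of (r, g, b) tuples in [0, 1]."""
    channels = list(zip(*pixels))
    means = tuple(statistics.fmean(c) for c in channels)
    stds = tuple(statistics.pstdev(c) for c in channels)
    return means, stds


def compare_to_imagenet(pixels):
    """Crude low-level check: how far are these channel statistics from ImageNet's?"""
    means, stds = channel_stats(pixels)
    return {
        "max_mean_gap": max(abs(m - ref) for m, ref in zip(means, IMAGENET_MEAN)),
        "max_std_gap": max(abs(s - ref) for s, ref in zip(stds, IMAGENET_STD)),
        # Grayscale images (such as X-rays) replicated into 3 channels
        # have r == g == b for every pixel.
        "grayscale": all(r == g == b for r, g, b in pixels),
    }


if __name__ == "__main__":
    # Hypothetical "X-ray-like" pixels: grayscale and mostly dark.
    xray_like = [(v, v, v) for v in (0.05, 0.1, 0.1, 0.2, 0.8, 0.9)]
    print(compare_to_imagenet(xray_like))
```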
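You also need to be aware of where the features are transferred from (the bottom, middle, or top of the network), because that affects model performance depending on task similarity: lower layers tend to capture generic features such as edges and textures, while higher layers become increasingly specific to the original task.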
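Choosing "where the features come from" amounts to picking which layer's output you read. In Keras this is usually done with a real model by building a new Model whose outputs are an intermediate layer's output, for example Model(inputs=base_model.input, outputs=base_model.get_layer(name).output). The framework-free sketch below shows the same idea; the three toy layers are hypothetical stand-ins for the bottom, middle, and top of a network.

```python
from typing import Callable, List, Sequence

Layer = Callable[[List[float]], List[float]]

# Hypothetical toy "network": three layers standing in for the bottom,
# middle, and top of a pre-trained model.
TOY_LAYERS: List[Layer] = [
    lambda xs: [2.0 * v for v in xs],
    lambda xs: [v + 1.0 for v in xs],
    lambda xs: [sum(xs)],
]


def features_at_depth(x: List[float], layers: Sequence[Layer], depth: int) -> List[float]:
    """Return the activation after the first `depth` layers.

    depth=0 returns the input unchanged; depth=len(layers) returns the top output.
    """
    if not 0 <= depth <= len(layers):
        raise ValueError(f"depth must be between 0 and {len(layers)}, got {depth}")
    out = list(x)
    for layer in layers[:depth]:
        out = layer(out)
    return out


if __name__ == "__main__":
    for d in range(len(TOY_LAYERS) + 1):
        print(d, features_at_depth([1.0, 2.0], TOY_LAYERS, d))
```

For a task that is very different from the original one, reading features from a lower depth (and training more of the new model on top) is often the safer starting point.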
2. How did you preprocess the data?
For Keras models, use the preprocess_input function for the corresponding model-level module:
# VGG16
keras.applications.vgg16.preprocess_input
# InceptionV3
keras.applications.inception_v3.preprocess_input
# ResNet50
keras.applications.resnet50.preprocess_input
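These functions are not interchangeable because they implement different conventions. As far as I know, Keras's imagenet_utils defines three modes: 'caffe' (convert RGB to BGR, then subtract the ImageNet channel means; used by VGG16 and ResNet50), 'tf' (scale to [-1, 1]; used by InceptionV3), and 'torch' (scale to [0, 1], then normalize with the ImageNet mean and std, matching torchvision's usual normalization). The pure-Python sketch below reimplements those conventions for a single pixel so you can see how different the resulting inputs are; check the source of your Keras version before relying on the exact constants.

```python
# Sketch of three ImageNet preprocessing conventions, applied to a single
# RGB pixel with values in [0, 255]. Constants as commonly published.

CAFFE_MEAN_BGR = (103.939, 116.779, 123.68)
TORCH_MEAN = (0.485, 0.456, 0.406)
TORCH_STD = (0.229, 0.224, 0.225)


def preprocess_caffe(rgb):
    """'caffe' style: reorder RGB -> BGR, then subtract the per-channel mean (no scaling)."""
    bgr = rgb[::-1]
    return tuple(v - m for v, m in zip(bgr, CAFFE_MEAN_BGR))


def preprocess_tf(rgb):
    """'tf' style: scale pixels from [0, 255] to [-1, 1]."""
    return tuple(v / 127.5 - 1.0 for v in rgb)


def preprocess_torch(rgb):
    """'torch' style: scale to [0, 1], then normalize with ImageNet mean/std (RGB order)."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, TORCH_MEAN, TORCH_STD))


if __name__ == "__main__":
    white = (255, 255, 255)
    for fn in (preprocess_caffe, preprocess_tf, preprocess_torch):
        print(fn.__name__, fn(white))
```

Feeding 'tf'-scaled values (in [-1, 1]) to a model trained on 'caffe'-style inputs (roughly in [-124, 152]) raises no error at all; it just quietly degrades accuracy, which is why matching the preprocessing to the model matters.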
3. What’s your backend?
Woolf’s post is from 2017, so it’d be interesting to get an updated comparison that also includes Theano and MXNet as backends (although Theano is now deprecated).
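Whatever the comparison, results are only comparable if you know which backend and library versions produced them. A minimal sketch, assuming you simply want to store this information next to each benchmark result: it reads the KERAS_BACKEND environment variable (which multi-backend Keras consults when it is imported) and package versions via the standard library's importlib.metadata. The default package list is an illustrative choice, not a requirement.

```python
import os
import platform
import sys
from importlib import metadata


def describe_environment(packages=("keras", "tensorflow", "torch", "torchvision")):
    """Collect backend and version info to store alongside a benchmark result."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        # Multi-backend Keras reads this variable at import time to pick its backend.
        "keras_backend_env": os.environ.get("KERAS_BACKEND"),
        "packages": versions,
    }


if __name__ == "__main__":
    print(describe_environment())
```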
4. What’s your hardware?
Are you using an Amazon EC2 NVIDIA Tesla K80 or a Google Compute Engine NVIDIA Tesla P100? Maybe even a TPU? 😜 Check out these useful benchmark resources for run times for these different pretrained models.
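If you run your own timings across machines, measure them the same way everywhere. Here is a minimal timing harness, assuming any Python callable stands in for a model's forward pass: it discards warm-up calls and reports the median. One caveat it does not handle: GPU frameworks usually launch work asynchronously, so you would need to synchronize (for example torch.cuda.synchronize() in PyTorch) before reading the clock.

```python
import statistics
import time


def benchmark(fn, *args, warmup=3, repeats=10):
    """Return the median wall-clock time in seconds of fn(*args).

    Warm-up calls are run first and discarded, because first calls often pay
    one-off costs (lazy initialization, caching, graph or kernel compilation).
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):
        fn(*args)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


if __name__ == "__main__":
    # Hypothetical workload standing in for a model's forward pass.
    median_s = benchmark(sum, range(1_000_000))
    print(f"median: {median_s * 1000:.2f} ms")
```

The median is used instead of the mean because occasional slow runs (background processes, throttling) would otherwise skew the comparison.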
5. What’s your learning rate?
In practice, you should either keep the pre-trained parameters fixed (i.e., use the pre-trained model as a feature extractor) or fine-tune them with a fairly small learning rate so that you don’t unlearn everything in the original model.
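Both options can be expressed as per-group learning rates: a learning rate of zero freezes a group, and a small one fine-tunes it gently. PyTorch optimizers support this directly through parameter groups (a list of dicts, each with its own "params" and optional "lr"), and Keras layers can be frozen with layer.trainable = False. The framework-free sketch below, with hypothetical group names and values, shows only the underlying update rule.

```python
def sgd_step(groups, grads):
    """Apply one plain-SGD step, updating `groups` in place with each group's own lr.

    groups: {name: {"params": [float, ...], "lr": float}}
    grads:  {name: [float, ...]}  (same lengths as the params)
    """
    for name, group in groups.items():
        lr = group["lr"]
        group["params"] = [p - lr * g for p, g in zip(group["params"], grads[name])]
    return groups


if __name__ == "__main__":
    groups = {
        "backbone": {"params": [0.5, -0.3], "lr": 1e-4},  # pre-trained: tiny steps (0.0 = frozen)
        "head": {"params": [0.0, 0.0], "lr": 1e-2},       # new layer: larger steps
    }
    sgd_step(groups, {"backbone": [1.0, 1.0], "head": [1.0, -2.0]})
    print(groups)
```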
6. Is there a difference in how you use layers like batch normalization or dropout, especially between training mode and inference mode?
As Curtis’ post claims:
Keras models using batch normalization can be unreliable. For some models, forward-pass evaluations (with gradients supposedly off) still result in weights changing at inference time. (See 5)
But why is this the case?
The problem with the current implementation of Keras is that when a batch normalization (BN) layer is frozen, it continues to use the mini-batch statistics during training. I believe a better approach when the BN is frozen is to use the moving mean and variance that it learned during training. Why? For the same reasons why the mini-batch statistics should not be updated when the layer is frozen: it can lead to poor results because the next layers are not trained properly.
Vasilis also cited instances where this discrepancy led to significant drops in model performance (“from 100% down to 50% accuracy”) when the Keras model is switched from train mode to test mode.
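To make the train/inference distinction concrete, here is a minimal, framework-free batch normalization sketch (not Keras's implementation; the learnable scale and shift are omitted). In training mode it normalizes with the current mini-batch statistics; in inference mode it uses the running mean and variance. The update_stats=False flag is a hypothetical way to mimic the frozen layer described in the quote above: its statistics are no longer updated, yet it still normalizes with mini-batch statistics, so the next layers see different inputs in train and test mode. Note that, per the tf.keras BatchNormalization documentation, TensorFlow 2.0 onward runs a BatchNormalization layer in inference mode once trainable is set to False, so check the behavior of the version you use.

```python
import statistics


class BatchNorm1D:
    """Minimal batch normalization over a batch of scalars (a single feature)."""

    def __init__(self, momentum=0.9, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0

    def __call__(self, batch, training, update_stats=True):
        if training:
            # Training mode: normalize with this mini-batch's statistics.
            mean = statistics.fmean(batch)
            var = statistics.pvariance(batch)  # frameworks differ on biased vs unbiased here
            if update_stats:
                m = self.momentum
                self.running_mean = m * self.running_mean + (1 - m) * mean
                self.running_var = m * self.running_var + (1 - m) * var
        else:
            # Inference mode: normalize with the statistics learned during training.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]


if __name__ == "__main__":
    bn = BatchNorm1D()
    bn.running_mean, bn.running_var = 10.0, 4.0  # pretend these were learned earlier
    batch = [12.0, 14.0]
    print("frozen, training mode:", bn(batch, training=True, update_stats=False))
    print("inference mode:       ", bn(batch, training=False))
```

The same two inputs come out as roughly [-1, 1] in one mode and [1, 2] in the other, which is the kind of mismatch that can make a model look fine during training and fall apart at test time.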
Use these questions to guide how you interact with pre-trained models for your next project.