Summary

Transfer learning is a machine learning method that applies knowledge from a previously trained model to a new, related task, improving efficiency and performance in neural network applications, especially when data is scarce. It addresses a major bottleneck of traditional machine learning, the need for large amounts of labelled data, by drawing inspiration from the way humans carry knowledge learned on one task over to a related one. In practice, the technique ranges from using pre-trained models as feature extractors to fine-tuning them for specific tasks, enabling rapid adaptation to new tasks with limited data. Transfer learning thus offers a pragmatic way to overcome data scarcity and computational constraints, facilitating the broader application of deep learning models across diverse domains.

Transfer learning is a machine learning technique that reuses a model trained on one task for a different but related task. In this article, we will take a deep dive into what this means, why transfer learning has become an increasingly popular way to boost neural network performance, and how you can apply it to your own deep learning task.

Motivation

Although traditional machine learning (ML) approaches have shown remarkable performance across a vast spectrum of domains, they share a common, major bottleneck: the availability of data. For supervised learning, models must be trained on massive amounts of labelled data before they perform well enough for practical use.

This means that for each problem to be solved, we must collect and label a large amount of data, which is at best cumbersome and in some cases impossible, due to limited data availability and the need for manual annotation. Semi-supervised learning partially alleviates this bottleneck by reducing the amount of labelled data required, but there are still many domains where even collecting unlabelled data at scale is challenging.

Traditionally, machine learning problems have been studied in isolation: you study the domain of the problem you want to solve, prepare an appropriate dataset, train a suitable algorithm on it, and evaluate its performance on that problem. But is it possible to leverage already solved problems, so that you do not have to start the ML lifecycle from scratch for each new one? Think about how we as humans learn similar tasks: we reuse knowledge gained from a task we have already mastered to speed up learning on a related one.

Human analogies to transfer learning | Source: Zhuang et al.

For instance, if you have already learned to ride a bicycle, it is easier to learn to ride a motorbike than it would be if you were starting from scratch. Similarly, if you have driven a car with automatic transmission, those skills transfer when learning to drive a manual car. Transfer learning takes inspiration from this phenomenon and applies it to machine learning. The idea is to take a model already trained on a related task (with readily available data) and use the knowledge stored in its parameters to reduce the data and computation required to learn a new task. This allows machine learning algorithms to be applied to tasks with limited data availability, and it also boosts performance without the need for extensive training.

As an example, suppose we have a Natural Language Processing (NLP) task like sentiment analysis, and assume that we have limited data available for it. Instead of training a model from scratch, we could take an NLP model already trained on another NLP task, say, named-entity recognition, and tune it to fit our task. The trained model is expected to have learned low-level features such as the semantics of language, and since this knowledge is also relevant to the task at hand, the original training allows the model to pick up sentiment analysis much faster.
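To make this concrete, here is a minimal, hypothetical sketch of the idea using the Hugging Face transformers library: rather than training a sentiment classifier from scratch, we start from a general-purpose pre-trained encoder (BERT in this sketch) and only fit a small classification head on our limited labelled data. The variables `train_texts` and `train_labels` are placeholders for that data.

```python
# Minimal transfer-learning sketch (assumes the `transformers` library is installed;
# `train_texts` and `train_labels` are hypothetical placeholders for a small
# labelled sentiment dataset).
import numpy as np
from tensorflow.keras.optimizers import Adam
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Reuse a pre-trained encoder; only the small classification head is new.
model = TFAutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2  # positive / negative sentiment
)

tokenized = dict(tokenizer(train_texts, padding=True, truncation=True, return_tensors="np"))
model.compile(optimizer=Adam(3e-5))  # Transformers models supply a task-appropriate loss by default
model.fit(tokenized, np.array(train_labels), epochs=3)
```

Only a few epochs at a small learning rate are typically needed, since most of the linguistic knowledge is already encoded in the pre-trained weights.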

Definition

Formally, transfer learning is defined in terms of domains, tasks, and marginal probabilities (Pan et al.).

A domain D is defined as the tuple of a feature space and the marginal probability distribution P(X), where X = {x1, x2, … , xn} is a sample of data points drawn from that feature space.

A task T is the tuple of the label space Y and the decision function P(Y|X) that is learned from the training data.

The figure below will help visualize these definitions better with some examples.

Exemplifying domains and tasks | Source: Towards Data Science/Dipanjan Sarkar

We define transfer learning in terms of a source domain and source task coupled with a target domain and target task, as stated below.
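Putting the pieces together, the standard definition from Pan et al. can be restated compactly (using P(Y|X) for the decision function, as above):

```latex
% Transfer learning, restating the definition of Pan et al.
\textbf{Definition.} Given a source domain $D_S$ with source task $T_S$ and a target
domain $D_T$ with target task $T_T$, transfer learning aims to improve the learning of
the target decision function $P(Y_T \mid X_T)$ in $D_T$ using the knowledge in $D_S$
and $T_S$, where $D_S \neq D_T$ or $T_S \neq T_T$.
```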

Types of Transfer Learning

Based on how the source and target domains and tasks relate to each other, transfer learning can be categorized into three types:

  1. Inductive Transfer Learning: In inductive transfer learning, the target task is different from the source task, and labelled data for the target domain is available.
  2. Transductive Transfer Learning: In transductive transfer learning, the source and target tasks are the same, while the domains differ. Substantial labelled data is available for the source domain, but none for the target domain.
  3. Unsupervised Transfer Learning: Unsupervised transfer learning focuses on solving unsupervised learning tasks (such as clustering or dimensionality reduction) in the target domain. The target task is different from but related to the source task, and no labelled data is available in either domain.

The image below from Pan et al. provides a deeper understanding of when each of the above strategies is used:

Different Scenarios for Transfer Learning | Source: Pan et al.

Transfer Learning for Neural Networks

Since deep learning systems (neural networks with multiple hidden layers) typically demand extensive training and substantial data compared to traditional machine learning setups, transfer learning plays a crucial role in making it possible (and feasible) to apply deep learning techniques to a diverse set of domains and tasks. Several deep learning networks boasting state-of-the-art performance in fields like computer vision and NLP have emerged through the application of transfer learning, often rivaling or surpassing human capabilities.

Illustrations of Fine-tuning BERT (a popular pre-trained NLP model) on Different Tasks | Source: Devlin, Jacob et al.

Neural network models are a type of inductive learning (different from inductive transfer learning), aiming to derive patterns from a set of provided examples. For instance, in classification tasks, neural network models learn to link input features with corresponding class labels. To ensure these models perform well on new data, neural networks rely on certain assumptions about the distribution of the training data, referred to as inductive bias. Inductive bias influences various aspects of the learning process, including the hypothesis space and the search method within it, thereby shaping what and how the model learns for a given task and domain.

Inductive transfer methods leverage the biases learned from a source task to aid in a target task, which can involve adjusting the biases of the target task by constraining the model’s space, narrowing down hypothesis options, or modifying the search process using insights from the source task.

Applying Transfer Learning to Neural Networks

There are primarily two approaches to applying transfer learning to neural networks, and they differ essentially in how the source model, generally a pre-trained model, is modified to fit the target domain and task. Pre-trained models, such as Inception V3 (Szegedy, Christian, et al.) or BERT (Devlin, Jacob et al.), come already trained on vast datasets for specific tasks like image classification or natural language understanding. By leveraging the features these models have already learned, we can expedite the training of new models for related tasks. There are two ways this can be done:

Using Pre-trained Models as Feature Extractors

One approach is to use a pre-trained model as a feature extractor. Neural networks are layered architectures in which each layer progressively learns different features. By removing the final layer and keeping the remaining weighted layers of a pre-trained network, we can use it as a fixed feature extractor for new tasks: the pre-trained weights are not updated while training on the new data, and only a new, lightweight output layer is learned on top of the extracted features. This lets us capitalize on the hierarchical representations the model has already learned, saving significant time and computational resources, and it is a widely adopted way of performing transfer learning on deep neural networks.

Using a pre-trained model as a feature extractor | Source: https://dev.mrdbourke.com
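As an illustration, here is a minimal Keras sketch of this approach. The dataset `train_ds` and the number of target classes are hypothetical placeholders, and Inception V3 is used only because it is mentioned above.

```python
# Feature-extraction sketch with Keras: the pre-trained backbone is frozen and
# only a new classification head is trained. `train_ds` and `num_classes` are
# hypothetical placeholders for the target task's data.
import tensorflow as tf

num_classes = 5  # placeholder for the target task

# InceptionV3 pre-trained on ImageNet, without its original classification layer.
base_model = tf.keras.applications.InceptionV3(include_top=False, weights="imagenet")
base_model.trainable = False  # freeze: the backbone acts as a fixed feature extractor

inputs = tf.keras.Input(shape=(299, 299, 3))
x = tf.keras.applications.inception_v3.preprocess_input(inputs)  # scale pixels as the backbone expects
x = base_model(x, training=False)  # keep batch-norm layers in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)  # new task-specific head
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# model.fit(train_ds, epochs=5)  # only the new head's weights are updated
```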

Fine-tuning the Pre-trained Model

The other, more involved approach is to ‘fine-tune’ the pre-trained model on the target task. Here, we do not just replace the final layer; we also selectively retrain some of the preceding layers. Deep neural networks are layered architectures in which the initial layers tend to capture generic features, while later layers focus on more task-specific ones. By freezing the weights of certain layers while retraining and fine-tuning others to our needs, we capitalize on both the network’s architecture and its learned weights, resulting in improved performance with reduced training time.

Fine-tuning a pre-trained model | Source: http://d2l.ai
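Continuing the feature-extraction sketch above, fine-tuning in Keras might look roughly like this; the cut-off point and learning rate are illustrative, not tuned values.

```python
# Fine-tuning sketch: unfreeze only the top block of the pre-trained backbone
# and retrain it together with the new head, at a much lower learning rate so
# the pre-trained weights are adjusted gently rather than overwritten.
base_model.trainable = True

fine_tune_at = len(base_model.layers) - 30  # illustrative cut-off: keep earlier, generic layers frozen
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# model.fit(train_ds, epochs=5)  # now both the unfrozen layers and the head are updated
```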
Should I freeze or fine-tune?

Freezing is generally preferred when there is limited labelled target data available, since we do not want to overfit on the target training data and want the model to generalize. If we have enough labelled target data to retrain some layers, fine-tuning can yield better performance. One way to explore the tradeoff between freezing and fine-tuning is to set different learning rates for different layers of the network and analyze the impact on performance.
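As a sketch of that last idea, the optional tensorflow_addons package provides a MultiOptimizer that can assign different optimizers (and hence learning rates) to different parts of a Keras model. The split and the rates below are purely illustrative, and `base_model` and `model` are the hypothetical objects from the earlier sketches.

```python
# Discriminative learning rates: a tiny rate for the pre-trained backbone (close
# to freezing it) and a larger rate for the freshly initialized head. Assumes the
# optional `tensorflow_addons` package and that (part of) `base_model` has been
# unfrozen as in the fine-tuning sketch above.
import tensorflow_addons as tfa

backbone_opt = tf.keras.optimizers.Adam(learning_rate=1e-5)  # pre-trained layers: barely move
head_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)      # new Dense head: learn freely

optimizer = tfa.optimizers.MultiOptimizer([
    (backbone_opt, base_model),    # backbone layers use the small learning rate
    (head_opt, model.layers[-1]),  # the new classification head uses the larger one
])
model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
```

Training runs with different rate assignments can then be compared to judge how much fine-tuning (versus freezing) actually helps on the target task.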

Conclusion

In conclusion, transfer learning stands as a powerful strategy for enhancing neural network performance across a myriad of tasks and domains. By leveraging knowledge gained from pre-trained models, deep learning practitioners can effectively tackle new challenges with reduced training time and computational resources. Transfer learning empowers neural networks to harness the wealth of information encapsulated in existing representations, thereby accelerating learning and improving overall performance.

I hope this article serves as a good primer to introduce you to transfer learning. As a next step, I highly recommend walking through this very illustrative tutorial to see transfer learning in action.

(Featured Image by macrovector on Freepik.)

References

  1. Pan, Sinno Jialin, and Qiang Yang. “A Survey on Transfer Learning.” IEEE Transactions on Knowledge and Data Engineering 22.10 (2009): 1345-1359.
  2. F. Zhuang et al., “A Comprehensive Survey on Transfer Learning,” in Proceedings of the IEEE, vol. 109, no. 1, pp. 43-76, Jan. 2021, doi: 10.1109/JPROC.2020.3004555.
  3. Sebastian Ruder, “Transfer Learning – Machine Learning’s Next Frontier”. http://ruder.io/transfer-learning/, 2017.
  4. Ng, Andrew. “Nuts and bolts of building AI applications using Deep Learning.” NIPS Keynote Talk (2016).
  5. Devlin, Jacob, et al. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv preprint arXiv:1810.04805 (2018).
  6. “Transfer Learning with TensorFlow Part 1: Feature Extraction.” Zero to Mastery TensorFlow for Deep Learning, dev.mrdbourke.com/tensorflow-deep-learning/04_transfer_learning_in_tensorflow_part_1_feature_extraction/.
  7. Weiss, Karl, Taghi M. Khoshgoftaar, and DingDing Wang. “A survey of transfer learning.” Journal of Big data 3.1 (2016): 1-40.
  8. “Transfer Learning and Fine-Tuning: TensorFlow Core.” TensorFlow, www.tensorflow.org/tutorials/images/transfer_learning.
  9. Szegedy, Christian, et al. “Rethinking the inception architecture for computer vision.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
  10. Sarkar, Dipanjan (DJ). “A Comprehensive Hands-on Guide to Transfer Learning with Real-World Applications in Deep Learning.” Medium, Towards Data Science, 17 Nov. 2018, towardsdatascience.com/a-comprehensive-hands-on-guide-to-transfer-learning-with-real-world-applications-in-deep-learning-212bf3b2f27a.
  11. 14.2. Fine-Tuning – Dive into Deep Learning 1.0.3 Documentation, d2l.ai/chapter_computer-vision/fine-tuning.html. Accessed 5 Feb. 2024.