
TextBrewer is a PyTorch-based model distillation toolkit for natural language processing.
It includes distillation techniques from both the NLP and CV fields and provides an easy-to-use distillation framework that lets users quickly experiment with state-of-the-art distillation methods. Models can thus be compressed with a relatively small loss in performance, which increases inference speed and reduces memory usage.
Main features¶
Wide-support : it supports various model architectures (especially transformer-based models).
Flexibility : design your own distillation scheme by combining different techniques.
Easy-to-use : users don’t need to modify the model architectures.
Built for NLP : it is suitable for a wide variety of NLP tasks: text classification, machine reading comprehension, sequence labeling, …
Paper: TextBrewer: An Open-Source Knowledge Distillation Toolkit for Natural Language Processing
Tutorial¶
Introduction¶
TextBrewer is designed for the knowledge distillation of NLP models. It provides various distillation methods and offers a distillation framework for quickly setting up experiments.
TextBrewer currently ships with the following distillation techniques:
Mixed soft-label and hard-label training
Dynamic loss weight adjustment and temperature adjustment
Various distillation loss functions
Freely adding intermediate features matching losses
Multi-teacher distillation
…
TextBrewer includes:
Distillers: the cores of distillation. Different distillers perform different distillation modes.
Configurations and presets: Configuration classes for training and distillation, and predefined distillation loss functions and strategies.
Utilities: auxiliary tools such as model parameters analysis.
Architecture¶

Installation¶
Requirements
Python >= 3.6
PyTorch >= 1.1.0
TensorboardX or Tensorboard
NumPy
tqdm
Transformers >= 2.0 (optional, used by some examples)
Install from PyPI
pip install textbrewer
Install from the Github source
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer
Workflow¶


To start distillation, users need to provide
the models (the trained teacher model and the untrained student model).
datasets and experiment configurations.
Stage 1: Preparation:
Train the teacher model.
Define and initialize the student model.
Construct a dataloader, an optimizer, and a learning rate scheduler.
Stage 2: Distillation with TextBrewer:
Construct a TrainingConfig and a DistillationConfig, and initialize a distiller.
Define an adaptor and a callback. The adaptor is used for the adaptation of model inputs and outputs. The callback is called by the distiller during training.
Call the train method of the distiller.
Quickstart¶
Here we show the usage of TextBrewer by distilling BERT-base to a 3-layer BERT.
Before distillation, we assume users have provided:
A trained teacher model teacher_model (BERT-base) and a to-be-trained student model student_model (3-layer BERT).
A dataloader of the dataset, an optimizer, and a learning rate scheduler class scheduler_class with its arguments dict scheduler_args.
Distill with TextBrewer:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
# Show the statistics of model parameters
print("\nteacher_model's parameters:")
result, _ = textbrewer.utils.display_parameters(teacher_model, max_level=3)
print(result)

print("student_model's parameters:")
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)
# Define an adaptor for interpreting the model inputs and outputs
def simple_adaptor(batch, model_outputs):
    # The second and third elements of the model outputs are the logits and hidden states
    return {'logits': model_outputs[1],
            'hidden': model_outputs[2]}
# Training configuration
train_config = TrainingConfig()
# Distillation configuration
# Matching different layers of the student and the teacher
# We match 0-0 and 8-2 here for demonstration
distill_config = DistillationConfig(
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}])
# Build distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)
# Start!
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=1,
                    scheduler_class=scheduler_class, scheduler_args=scheduler_args, callback=None)
Examples¶
Examples can be found in the examples directory of the repo:
examples/random_token_example : a simple runnable toy example which demonstrates the usage of TextBrewer. This example performs distillation on the text classification task with random tokens as inputs.
examples/cmrc2018_example (Chinese): distillation on CMRC 2018, a Chinese MRC task, using DRCD as data augmentation.
examples/mnli_example (English): distillation on MNLI, an English sentence-pair classification task. This example also shows how to perform multi-teacher distillation.
examples/conll2003_example (English): distillation on CoNLL-2003 English NER task, which is in the form of sequence labeling.
examples/msra_ner_example (Chinese): This example distills a Chinese-ELECTRA-base model on the MSRA NER task with distributed data-parallel training (single node, multi-GPU).
FAQ¶
Q: How to initialize the student model?
A: The student model could be randomly initialized (i.e., with no prior knowledge) or be initialized with pre-trained weights. For example, when distilling a BERT-base model to a 3-layer BERT, you could initialize the student model with RBT3 (for Chinese tasks) or the first three layers of BERT (for English tasks) to avoid the cold-start problem. We recommend that users use pre-trained student models whenever possible to take full advantage of large-scale pre-training.
Q: How to set training hyperparameters for the distillation experiments?
A: Knowledge distillation usually requires more training epochs and a larger learning rate than training on the labeled dataset. For example, training SQuAD on BERT-base usually takes 3 epochs with lr=3e-5, whereas distillation takes 30~50 epochs with lr=1e-4. These conclusions are based on our experiments; you are advised to tune the hyperparameters on your own data.
Q: My teacher model and student model take different inputs (they do not share vocabularies), so how can I distill?
A: You need to feed different batches to the teacher and the student. See Feed Different batches to Student and Teacher, Feed Cached Values.
Q: I have stored the logits from my teacher model. Can I use them in the distillation to save the forward pass time?
A: Yes, see Feed Different batches to Student and Teacher, Feed Cached Values.
Known Issues¶
Multi-label classification is not supported.
Citation¶
If you find TextBrewer is helpful, please cite our paper :
@InProceedings{textbrewer-acl2020-demo,
author = "Yang, Ziqing and Cui, Yiming and Chen, Zhipeng and Che, Wanxiang and Liu, Ting and Wang, Shijin and Hu, Guoping",
title = "{T}ext{B}rewer: {A}n {O}pen-{S}ource {K}nowledge {D}istillation {T}oolkit for {N}atural {L}anguage {P}rocessing",
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
year = "2020",
publisher = "Association for Computational Linguistics"
}
Core Concepts¶
Conventions¶
Model_T : an instance of torch.nn.Module, the teacher model to be distilled.
Model_S : an instance of torch.nn.Module, the student model, usually smaller than the teacher model, for model compression and faster inference.
optimizer : an instance of torch.optim.Optimizer.
scheduler : an instance of a class under torch.optim.lr_scheduler, which allows flexible adjustment of the learning rate.
dataloader : a data iterator used to generate data batches. A batch can be a tuple or a dict:

for batch in dataloader:
    # if batch_postprocessor is not None:
    #     batch = batch_postprocessor(batch)
    # check batch datatype
    # pass batch to the model and adaptors
Batch Format(important)¶
Forward conventions: each batch to be passed to the model should be a dict or a tuple:

If the batch is a dict, the model will be called as model(**batch, **args);
if the batch is a tuple, the model will be called as model(*batch, **args).

Hence, if the batch is not a dict, users should make sure that the order of elements in the batch is the same as the order of the arguments of model.forward. args is used for passing additional parameters.
Users can additionally define a batch_postprocessor function to post-process batches if needed. batch_postprocessor should take a batch and return a batch. See the explanation of the train method of Distillers for more details.
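For illustration, a minimal sketch of the two batch formats and an optional batch_postprocessor (the field names input_ids, attention_mask and labels are placeholders, assuming a model.forward with matching arguments):

import torch

input_ids = torch.randint(0, 100, (8, 32))     # placeholder tensors
attention_mask = torch.ones(8, 32)
labels = torch.randint(0, 2, (8,))

# A tuple batch: element order must follow the argument order of model.forward,
# e.g. def forward(self, input_ids, attention_mask, labels): ...
tuple_batch = (input_ids, attention_mask, labels)        # called as model(*batch, **args)

# A dict batch: keys must match the argument names of model.forward
dict_batch = {'input_ids': input_ids,
              'attention_mask': attention_mask,
              'labels': labels}                           # called as model(**batch, **args)

# An optional batch_postprocessor: takes a batch and returns a batch
def batch_postprocessor(batch):
    batch['labels'] = batch['labels'].long()              # e.g. cast labels to the expected dtype
    return batch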
Since version 0.2.1, TextBrewer supports a more flexible input scheme: users can feed different batches to the student and the teacher, or feed cached values to save the forward-pass time. See Feed Different batches to Student and Teacher, Feed Cached Values.
Configurations¶
TrainingConfig : configurations related to general deep learning model training.
DistillationConfig : configurations related to distillation methods.
Distillers¶
Distillers are in charge of conducting the actual experiments. The following distillers are available:
BasicDistiller : single-teacher single-task distillation; provides basic distillation strategies.
GeneralDistiller (Recommended) : single-teacher single-task distillation; supports intermediate features matching. Recommended for most cases.
MultiTeacherDistiller : multi-teacher distillation, which distills multiple teacher models (of the same task) into a single student model. This class doesn't support intermediate features matching.
MultiTaskDistiller : multi-task distillation, which distills multiple teacher models (of different tasks) into a single student.
BasicTrainer : supervised training of a single model on a labeled dataset, not for distillation. It can be used to train a teacher model.
User-Defined Functions¶
In TextBrewer, there are two functions that should be implemented by users: callback() and adaptor().
- callback(model, step) → None¶
At each checkpoint, after saving the student model, the callback function will be called by the distiller. The callback can be used to evaluate the performance of the student model at each checkpoint.

Note
If users want to do an evaluation in the callback, remember to add model.eval() in the callback.

- Parameters
model (torch.nn.Module) – the student model
step (int) – the current training step
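A minimal callback sketch (evaluate and eval_dataloader below are user-supplied placeholders, not part of TextBrewer):

import torch

def my_callback(model, step):
    model.eval()                                          # switch to evaluation mode
    with torch.no_grad():
        dev_metric = evaluate(model, eval_dataloader)     # placeholder: user-defined evaluation
    print(f"step {step}: dev metric = {dev_metric}")
    model.train()                                         # switch back so training can continue

The distiller calls it as callback(model=self.model_S, step=global_step), so the parameters should be named model and step.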
- adaptor(batch, model_outputs) → dict¶
It converts the model inputs and outputs to the specified format so that they can be recognized by the distiller. At each training step, the batch and the model outputs will be passed to the adaptor; the adaptor reorganizes the data and returns a dict.
The functionality of the adaptor is shown in the figure below:
- Parameters
batch – the input batch to the model
model_outputs – the outputs returned by the model
- Return type
dict
- Returns
a dictionary that may contain the following keys and values:
'logits' : List[torch.Tensor] or torch.Tensor
The inputs to the final softmax. Each tensor should have the shape (batch_size, num_labels) or (batch_size, length, num_labels).

'logits_mask' : List[torch.Tensor] or torch.Tensor
A 0/1 matrix which masks logits at specified positions. Positions where mask==0 won't be included in the calculation of the loss on logits. Each tensor should have the shape (batch_size, length).

'labels' : List[torch.Tensor] or torch.Tensor
Ground-truth labels of the examples. Each tensor should have the shape (batch_size,) or (batch_size, length).

Note
logits_mask only works for logits with shape (batch_size, length, num_labels). It is used to mask along the length dimension and is commonly used in sequence labeling tasks.
logits, logits_mask and labels should either all be lists of tensors, or all be tensors.

'losses' : List[torch.Tensor]
It stores pre-computed losses, for example, the cross-entropy between logits and ground-truth labels. All the losses stored here will be summed, weighted by hard_label_weight, and added to the total loss. Each tensor in the list should be a scalar.

'attention' : List[torch.Tensor]
List of attention matrices, used to compute the intermediate feature matching loss. Each tensor should have the shape (batch_size, num_heads, length, length) or (batch_size, length, length), depending on which attention loss is used. Details about the various loss functions can be found at Intermediate Losses.

'hidden' : List[torch.Tensor]
List of hidden states, used to compute the intermediate feature matching loss. Each tensor should have the shape (batch_size, length, hidden_dim).

'inputs_mask' : torch.Tensor
A 0/1 matrix which masks attention and hidden; should have the shape (batch_size, length).
Note
These keys are all optional:

If there is no inputs_mask or logits_mask, it is treated as no masking.
If there is no intermediate feature matching loss, you can ignore attention and hidden.
If you don't want to add the loss of the original hard labels, you can set hard_label_weight=0 in the DistillationConfig and ignore losses.
If logits is not provided, the KD loss on the logits will be omitted.
labels is required if and only if probability_shift==True.
You shouldn't ignore all the keys, otherwise the training won't start :)

In most cases logits should be provided, unless you are doing multi-stage training or non-classification tasks, etc.
Example:
'''
Suppose the model outputs are: logits, sequence_output, total_loss:

class MyModel():
    def forward(self, input_ids, attention_mask, labels, ...):
        ...
        return logits, sequence_output, total_loss

logits: Tensor of shape (batch_size, num_classes)
sequence_output: List of tensors of (batch_size, length, hidden_dim)
total_loss: scalar tensor

The model inputs are:
    input_ids      = batch[0] : input_ids (batch_size, length)
    attention_mask = batch[1] : attention_mask (batch_size, length)
    labels         = batch[2] : labels (batch_size, num_classes)
'''
def SimpleAdaptor(batch, model_outputs):
    return {'logits': (model_outputs[0],),
            'hidden': model_outputs[1],
            'inputs_mask': batch[1]}
Feed Different batches to Student and Teacher, Feed Cached Values¶
Feed Different batches¶
In some cases, student and teacher read different inputs. For example, if you distill a RoBERTa model to a BERT model, they cannot share the inputs since they have different vocabularies.
To solve this, one can build a dataset that returns a dict as the batch, with keys 'student' and 'teacher'.
TextBrewer will unpack the dict and feed batch['student'] to the student and its adaptor, and batch['teacher'] to the teacher and its adaptor, following the forward conventions.
Here is an example.
import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

class TSDataset(Dataset):
    def __init__(self, teacher_dataset, student_dataset):
        # teacher_dataset and student_dataset are normal datasets
        # whose elements are tuples or dicts.
        assert len(teacher_dataset) == len(student_dataset), \
            f"lengths of teacher_dataset {len(teacher_dataset)} and student_dataset {len(student_dataset)} are not the same!"
        self.teacher_dataset = teacher_dataset
        self.student_dataset = student_dataset

    def __len__(self):
        return len(self.teacher_dataset)

    def __getitem__(self, i):
        return {'teacher': self.teacher_dataset[i], 'student': self.student_dataset[i]}

teacher_dataset = TensorDataset(torch.randn(32, 3), torch.randn(32, 3))
student_dataset = TensorDataset(torch.randn(32, 2), torch.randn(32, 2))
tsdataset = TSDataset(teacher_dataset=teacher_dataset, student_dataset=student_dataset)
dataloader = DataLoader(dataset=tsdataset, ...)
Feed Cached Values¶
If you can provide a dataset that returns a dict with keys 'student' and 'teacher' like the one above, you can also add another key 'teacher_cache', which stores the pre-computed outputs from the teacher. TextBrewer will then treat batch['teacher_cache'] as the output from the teacher and feed it to the teacher's adaptor. No forward pass of the teacher will be called.
Here is an example.
import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

class TSDataset(Dataset):
    def __init__(self, teacher_dataset, student_dataset, teacher_cache):
        # teacher_dataset and student_dataset are normal datasets
        # whose elements are tuples or dicts.
        # teacher_cache is a list of items; each item is the output from the teacher.
        assert len(teacher_dataset) == len(student_dataset), \
            f"lengths of teacher_dataset {len(teacher_dataset)} and student_dataset {len(student_dataset)} are not the same!"
        assert len(teacher_dataset) == len(teacher_cache), \
            f"lengths of teacher_dataset {len(teacher_dataset)} and teacher_cache {len(teacher_cache)} are not the same!"
        self.teacher_dataset = teacher_dataset
        self.student_dataset = student_dataset
        self.teacher_cache = teacher_cache

    def __len__(self):
        return len(self.teacher_dataset)

    def __getitem__(self, i):
        return {'teacher': self.teacher_dataset[i],
                'student': self.student_dataset[i],
                'teacher_cache': self.teacher_cache[i]}

teacher_dataset = TensorDataset(torch.randn(32, 3), torch.randn(32, 3))
student_dataset = TensorDataset(torch.randn(32, 2), torch.randn(32, 2))

# We make some fake data and assume the teacher model outputs are (logits, loss)
fake_logits = [torch.randn(3) for _ in range(32)]
fake_loss = [torch.randn(1)[0] for _ in range(32)]
teacher_cache = [(fake_logits[i], fake_loss[i]) for i in range(32)]

tsdataset = TSDataset(teacher_dataset=teacher_dataset, student_dataset=student_dataset, teacher_cache=teacher_cache)
dataloader = DataLoader(dataset=tsdataset, ...)
Experiments¶
We have performed distillation experiments on several typical English and Chinese NLP datasets. The setups and configurations are listed below.
Models¶
For English tasks, the teacher model is BERT-base-cased.
For Chinese tasks, the teacher models are RoBERTa-wwm-ext and Electra-base released by the Joint Laboratory of HIT and iFLYTEK Research.
We have tested different student models. To compare with public results, the student models are built with standard transformer blocks, except for BiGRU, which is a single-layer bidirectional GRU. The architectures are listed below. Note that the number of parameters includes the embedding layer but does not include the output layer of each specific task.
English models¶
Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
---|---|---|---|---|---|
BERT-base-cased (teacher) | 12 | 768 | 3072 | 108M | 100% |
T6 (student) | 6 | 768 | 3072 | 65M | 60% |
T3 (student) | 3 | 768 | 3072 | 44M | 41% |
T3-small (student) | 3 | 384 | 1536 | 17M | 16% |
T4-Tiny (student) | 4 | 312 | 1200 | 14M | 13% |
T12-nano (student) | 12 | 256 | 1024 | 17M | 16% |
BiGRU (student) | - | 768 | - | 31M | 29% |
Chinese models¶
Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
---|---|---|---|---|---|
RoBERTa-wwm-ext (teacher) | 12 | 768 | 3072 | 102M | 100% |
Electra-base (teacher) | 12 | 768 | 3072 | 102M | 100% |
T3 (student) | 3 | 768 | 3072 | 38M | 37% |
T3-small (student) | 3 | 384 | 1536 | 14M | 14% |
T4-Tiny (student) | 4 | 312 | 1200 | 11M | 11% |
Electra-small (student) | 12 | 256 | 1024 | 12M | 12% |
The T6 architecture is the same as DistilBERT[1], BERT6-PKD[2], and BERT-of-Theseus[3].
The T4-tiny architecture is the same as TinyBERT[4].
The T3 architecture is the same as BERT3-PKD[2].
Configurations¶
Distillation Configurations¶
distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)
# Others arguments take the default values
matches are different for different models:
Model | matches |
---|---|
BiGRU | None |
T6 | L6_hidden_mse + L6_hidden_smmd |
T3 | L3_hidden_mse + L3_hidden_smmd |
T3-small | L3n_hidden_mse + L3_hidden_smmd |
T4-Tiny | L4t_hidden_mse + L4_hidden_smmd |
T12-nano | small_hidden_mse + small_hidden_smmd |
Electra-small | small_hidden_mse + small_hidden_smmd |
The definitions of matches are at examples/matches/matches.py.
We use GeneralDistiller in all the distillation experiments.
Training Configurations¶
Learning rate is 1e-4 (unless otherwise specified).
We train all the models for 30~60 epochs.
Results on English Datasets¶
We experiment on the following typical English datasets:
Dataset | Task type | Metrics | #Train | #Dev | Note |
---|---|---|---|---|---|
MNLI | text classification | m/mm Acc | 393K | 20K | sentence-pair 3-class classification |
SQuAD 1.1 | reading comprehension | EM/F1 | 88K | 11K | span-extraction machine reading comprehension |
CoNLL-2003 | sequence labeling | F1 | 23K | 6K | named entity recognition |
We list the public results from DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT and our results below for comparison.
Public results:
Model (public) | MNLI | SQuAD | CoNLL-2003 |
---|---|---|---|
DistilBERT (T6) | 81.6 / 81.1 | 78.1 / 86.2 | - |
BERT6-PKD (T6) | 81.5 / 81.0 | 77.1 / 85.3 | - |
BERT-of-Theseus (T6) | 82.4 / 82.1 | - | - |
BERT3-PKD (T3) | 76.7 / 76.3 | - | - |
TinyBERT (T4-tiny) | 82.8 / 82.9 | 72.7 / 82.1 | - |
Our results (see Experimental Results for details):
Model (ours) | MNLI | SQuAD | CoNLL-2003 |
---|---|---|---|
BERT-base-cased (teacher) | 83.7 / 84.0 | 81.5 / 88.6 | 91.1 |
BiGRU | - | - | 85.3 |
T6 | 83.5 / 84.0 | 80.8 / 88.1 | 90.7 |
T3 | 81.8 / 82.7 | 76.4 / 84.9 | 87.5 |
T3-small | 81.3 / 81.7 | 72.3 / 81.4 | 78.6 |
T4-tiny | 82.0 / 82.6 | 75.2 / 84.0 | 89.1 |
T12-nano | 83.2 / 83.9 | 79.0 / 86.6 | 89.6 |
Note:
The equivalent model structures of public models are shown in the brackets after their names.
When distilling to T4-tiny, NewsQA is used for data augmentation on SQuAD and HotpotQA is used for data augmentation on CoNLL-2003.
When distilling to T12-nano, HotpotQA is used for data augmentation on CoNLL-2003.
Results on Chinese Datasets¶
We experiment on the following typical Chinese datasets:
Dataset | Task type | Metrics | #Train | #Dev | Note |
---|---|---|---|---|---|
XNLI | text classification | Acc | 393K | 2.5K | Chinese translation version of MNLI |
LCQMC | text classification | Acc | 239K | 8.8K | sentence-pair matching, binary classification |
CMRC 2018 | reading comprehension | EM/F1 | 10K | 3.4K | span-extraction machine reading comprehension |
DRCD | reading comprehension | EM/F1 | 27K | 3.5K | span-extraction machine reading comprehension (Traditional Chinese) |
MSRA NER | sequence labeling | F1 | 45K | 3.4K (test) | Chinese named entity recognition |
The results are listed below (see Experimental Results for details).
Model | XNLI | LCQMC | CMRC 2018 | DRCD |
---|---|---|---|---|
RoBERTa-wwm-ext (teacher) | 79.9 | 89.4 | 68.8 / 86.4 | 86.5 / 92.5 |
T3 | 78.4 | 89.0 | 66.4 / 84.2 | 78.2 / 86.4 |
T3-small | 76.0 | 88.1 | 58.0 / 79.3 | 75.8 / 84.8 |
T4-tiny | 76.2 | 88.4 | 61.8 / 81.8 | 77.3 / 86.1 |
Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
---|---|---|---|---|---|
Electra-base (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
Electra-small | 77.7 | 89.3 | 66.5 / 84.9 | 85.5 / 91.3 | 93.48 |
Note:
Learning rate decay is not used in distillation on CMRC 2018 and DRCD.
CMRC 2018 and DRCD take each other as the augmentation dataset in the distillation.
The settings of training Electra-base teacher model can be found at Chinese-ELECTRA.
The Electra-small student model is initialized with the pretrained weights.
Configurations¶
TrainingConfig¶
- class textbrewer.TrainingConfig(gradient_accumulation_steps=1, ckpt_frequency=1, ckpt_epoch_frequency=1, ckpt_steps=None, log_dir=None, output_dir='./saved_models', device='cuda', fp16=False, fp16_opt_level='O1', data_parallel=False, local_rank=-1)[source]¶
Configurations related to general model training.
- Parameters
gradient_accumulation_steps (int) – accumulates gradients before optimization to reduce GPU memory usage. It calls optimizer.step() every gradient_accumulation_steps backward steps.
ckpt_frequency (int) – stores model weights ckpt_frequency times every epoch.
ckpt_epoch_frequency (int) – stores model weights every ckpt_epoch_frequency epochs.
ckpt_steps (int) – if num_steps is passed to distiller.train(), saves the model every ckpt_steps steps and ignores ckpt_frequency and ckpt_epoch_frequency.
log_dir (str) – directory to save the tensorboard log file. Set it to None to disable tensorboard.
output_dir (str) – directory to save model weights.
device (str or torch.device) – training on CPU or GPU.
fp16 (bool) – if True, enables mixed precision training using Apex.
fp16_opt_level (str) – pure or mixed precision optimization level. Accepted values are "O0", "O1", "O2", and "O3". See the Apex documentation for details.
data_parallel (bool) – if True, wraps the models with torch.nn.DataParallel.
local_rank (int) – the local rank of the current process. A non-negative value means that we are in the distributed training mode with DistributedDataParallel.
Note
To perform data parallel (DP) training, you could either wrap the models with torch.nn.DataParallel outside TextBrewer by yourself, or leave the work to TextBrewer by setting data_parallel to True.
To enable both data parallel training and mixed precision training, you should set data_parallel to True and NOT wrap the models by yourself.
In some experiments, we have observed a slowdown in speed with torch.nn.DataParallel.
To perform distributed data parallel (DDP) training, you should call torch.distributed.init_process_group before initializing a TrainingConfig, and pass the raw (unwrapped) model when initializing the distiller.
DP and DDP are mutually exclusive.
Example:
# Usually just need to set log_dir and output_dir and leave others default
train_config = TrainingConfig(log_dir=my_log_dir, output_dir=my_output_dir)

# Stores the model at the end of each epoch
train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=1)
# Stores the model twice (at the middle and at the end) in each epoch
train_config = TrainingConfig(ckpt_frequency=2, ckpt_epoch_frequency=1)
# Stores the model once every two epochs
train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=2)
- classmethod from_dict(dict_object)¶
Construct configurations from a dict.

- classmethod from_json_file(json_filename)¶
Construct configurations from a json file.
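For example (assuming the dict keys match the constructor arguments; the file name below is hypothetical):

from textbrewer import TrainingConfig

# From a dict
train_config = TrainingConfig.from_dict({'log_dir': 'logs', 'output_dir': 'saved_models', 'device': 'cuda'})

# From a JSON file containing the same keys (hypothetical path)
train_config = TrainingConfig.from_json_file('train_config.json')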
DistillationConfig¶
- class textbrewer.DistillationConfig(temperature=4, temperature_scheduler='none', hard_label_weight=0, hard_label_weight_scheduler='none', kd_loss_type='ce', kd_loss_weight=1, kd_loss_weight_scheduler='none', probability_shift=False, intermediate_matches: Optional[List[Dict]] = None, is_caching_logits=False)[source]¶
Configurations related to distillation methods. It defines the total loss to be optimized:

\[\mathcal{L}_{total}= \mathcal{L}_{KD} \cdot w_{KD} + \mathcal{L}_{hl} \cdot w_{hl} + \mathrm{sum}(\textrm{intermediate\_losses})\]

where

\(\mathcal{L}_{KD}\) is the KD loss on logits and \(w_{KD}\) is its weight;
\(\mathcal{L}_{hl}\) is the sum of the losses returned by the adaptor and \(w_{hl}\) is its weight;
intermediate_losses is defined via intermediate_matches.
- Parameters
temperature (float) – temperature for the distillation. The teacher and student models' logits will be divided by the temperature when computing the KD loss. The temperature typically ranges from 1 to 10. We found that temperatures higher than 1 usually lead to better performance.
temperature_scheduler – dynamically adjusts the temperature. See TEMPERATURE_SCHEDULER for all available options.
kd_loss_type (str) – KD loss function for the logits term returned by the adaptor; can be 'ce' or 'mse'. See KD_LOSS_MAP.
kd_loss_weight (float) – the weight for the KD loss.
hard_label_weight (float) – the weight for the sum of the losses term returned by the adaptor. losses may include the losses on the ground-truth labels and other user-defined losses.
kd_loss_weight_scheduler – dynamically adjusts the KD loss weight. See WEIGHT_SCHEDULER for all available options.
hard_label_weight_scheduler – dynamically adjusts the weight of the sum of losses. See WEIGHT_SCHEDULER for all available options.
probability_shift (bool) – if True, swaps the ground-truth label's logit and the largest logit predicted by the teacher, so that the ground-truth label's logit is the largest. Requires the labels term returned by the adaptor.
is_caching_logits (bool) – if True, caches the batches and the output logits of the teacher model in memory, so that those logits are only calculated once. It will speed up the distillation process. This feature is only available for BasicDistiller and MultiTeacherDistiller, and only when the distiller's train() method is called with num_steps=None. It is suitable for small and medium datasets.
intermediate_matches (List[Dict]) – configuration for intermediate feature matching. Each element in the list is a dict representing one pair of matching configurations.
The dict in intermediate_matches contains the following keys:
'layer_T' : layer_T (int) : selects the layer_T-th layer of the teacher model.
'layer_S' : layer_S (int) : selects the layer_S-th layer of the student model.

Note
layer_T and layer_S indicate layers in the attention or hidden list of the dict returned by the adaptor, rather than the actual layers in the model.
If the loss is fsp or nst, two layers have to be chosen from the teacher and the student respectively. In this case, layer_T and layer_S are lists of two ints. See the example below.

'feature' : feature (str) : the feature of the intermediate layers. It can be:
'attention' : attention matrix, of the shape (batch_size, num_heads, length, length) or (batch_size, length, length)
'hidden' : hidden states, of the shape (batch_size, length, hidden_dim).

'loss' : loss (str) : the loss function. See MATCH_LOSS_MAP for available losses. Currently includes 'attention_mse', 'attention_ce', 'hidden_mse', 'nst', etc.
'weight' : weight (float) : the weight for the loss.
'proj' : proj (List, optional) : optional if the teacher and the student have the same feature dimension; otherwise it is required. It is the mapping function that matches the feature dimensions of the teacher and the student. It is a list with these elements:
proj[0] (str) : the mapping function, can be 'linear', 'relu', 'tanh'. See PROJ_MAP.
proj[1] (int) : feature dimension of the student model.
proj[2] (int) : feature dimension of the teacher model.
proj[3] (dict) : optional, provides configurations such as the learning rate. If not provided, the learning rate and optimizer configurations follow the default config of the optimizer; otherwise the ones specified here are used.
Example:
from textbrewer import DistillationConfig

# Simple configuration: use the default values, or try different temperatures
distill_config = DistillationConfig(temperature=8)

# Adding intermediate feature matching.
# Under this setting, the returned dicts results_T/results_S of adaptor_T/adaptor_S should contain the 'hidden' key.
# The mse loss between the teacher's results_T['hidden'][10] and the student's results_S['hidden'][3] will be computed.
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[{'layer_T': 10, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}]
)

# Multiple intermediate feature matchings. The teacher and the student have hidden_dims of 768 and 384 respectively.
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[
        {'layer_T': 0,  'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 4,  'layer_S': 1, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 8,  'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 12, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]}]
)
- classmethod from_dict(dict_object)¶
Construct configurations from a dict.

- classmethod from_json_file(json_filename)¶
Construct configurations from a json file.
Distillers¶
Distillers perform the actual experiments.
Initialize a distiller object, call its train method to start training/distillation.
BasicDistiller¶
- class textbrewer.BasicDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]¶
Performs single-teacher single-task distillation and provides basic distillation strategies.
- Parameters
train_config (TrainingConfig) – training configuration.
distill_config (DistillationConfig) – distillation configuration.
model_T (torch.nn.Module) – teacher model.
model_S (torch.nn.Module) – student model.
adaptor_T (Callable) – teacher model's adaptor.
adaptor_S (Callable) – student model's adaptor.

The roles of adaptor_T and adaptor_S are explained in adaptor().

- train(optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)[source]¶
Trains the student model.
- Parameters
optimizer – optimizer.
dataloader – dataset iterator.
num_epochs (int) – number of training epochs.
num_steps (int) – number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps, and the dataloader may have an unknown size, i.e., have no __len__ attribute. The dataloader will be cycled automatically after iterating over the whole dataset.
callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step). It can be used to evaluate the model at each checkpoint.
batch_postprocessor (Callable) – a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.
scheduler_class (class) – the class of the scheduler to be constructed.
scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object. See the example below.
scheduler (deprecated) – used to adjust the learning rate; optional, can be None; deprecated in favor of scheduler_class and scheduler_args.
max_grad_norm (float) – maximum norm for the gradients (-1 means no clipping). Default: -1.0.
**args – additional arguments fed to the model.
Note
If the batch is a list or tuple, the model is called as model(*batch, **args). Make sure the order of elements in the batch matches the order of the arguments of model.forward.
If the batch is a dict, the model is called as model(**batch, **args). Make sure the keys of the batch match the arguments of model.forward.
Note
If you want to provide a lr scheduler, DON'T USE scheduler; use scheduler_class and scheduler_args instead. Example:

from transformers import get_linear_schedule_with_warmup
distiller.train(optimizer,
                scheduler_class=get_linear_schedule_with_warmup,
                scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 1000})
GeneralDistiller¶
- class textbrewer.GeneralDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S, custom_matches: Optional[List[textbrewer.distiller_utils.CustomMatch]] = None)[source]¶
Supports intermediate features matching. Recommended for single-teacher single-task distillation.
- Parameters
train_config (TrainingConfig) – training configuration.
distill_config (DistillationConfig) – distillation configuration.
model_T (torch.nn.Module) – teacher model.
model_S (torch.nn.Module) – student model.
adaptor_T (Callable) – teacher model's adaptor.
adaptor_S (Callable) – student model's adaptor.
custom_matches (list) – supports more flexible user-defined matches (testing).

The roles of adaptor_T and adaptor_S are explained in adaptor().

- train(optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)¶
Trains the student model.
- Parameters
optimizer – optimizer.
dataloader – dataset iterator.
num_epochs (int) – number of training epochs.
num_steps (int) – number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps, and the dataloader may have an unknown size, i.e., have no __len__ attribute. The dataloader will be cycled automatically after iterating over the whole dataset.
callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step). It can be used to evaluate the model at each checkpoint.
batch_postprocessor (Callable) – a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.
scheduler_class (class) – the class of the scheduler to be constructed.
scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object. See the example below.
scheduler (deprecated) – used to adjust the learning rate; optional, can be None; deprecated in favor of scheduler_class and scheduler_args.
max_grad_norm (float) – maximum norm for the gradients (-1 means no clipping). Default: -1.0.
**args – additional arguments fed to the model.
Note
If the batch is a list or tuple, the model is called as model(*batch, **args). Make sure the order of elements in the batch matches the order of the arguments of model.forward.
If the batch is a dict, the model is called as model(**batch, **args). Make sure the keys of the batch match the arguments of model.forward.
Note
If you want to provide a lr scheduler, DON'T USE scheduler; use scheduler_class and scheduler_args instead. Example:

from transformers import get_linear_schedule_with_warmup
distiller.train(optimizer,
                scheduler_class=get_linear_schedule_with_warmup,
                scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 1000})
MultiTeacherDistiller¶
- class textbrewer.MultiTeacherDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]¶
Distills multiple teacher models (of the same task) into a student model. It doesn't support intermediate feature matching.
- Parameters
train_config (TrainingConfig) – training configuration.
distill_config (DistillationConfig) – distillation configuration.
model_T (List[torch.nn.Module]) – list of teacher models.
model_S (torch.nn.Module) – student model.
adaptor_T (Callable) – teacher model's adaptor.
adaptor_S (Callable) – student model's adaptor.

The roles of adaptor_T and adaptor_S are explained in adaptor().

- train(self, optimizer, scheduler, dataloader, num_epochs, num_steps=None, callback=None, batch_postprocessor=None, **args)¶
Trains the student model. See BasicDistiller.train().
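A minimal usage sketch (teacher_model_1/2/3, student_model, simple_adaptor, optimizer and dataloader are assumed to be prepared as in the Quickstart):

from textbrewer import MultiTeacherDistiller, TrainingConfig, DistillationConfig

distiller = MultiTeacherDistiller(
    train_config=TrainingConfig(), distill_config=DistillationConfig(temperature=8),
    model_T=[teacher_model_1, teacher_model_2, teacher_model_3],   # teachers of the same task
    model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

with distiller:
    distiller.train(optimizer=optimizer, scheduler=None, dataloader=dataloader,
                    num_epochs=30, callback=None)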
MultiTaskDistiller¶
- class textbrewer.MultiTaskDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]¶
Distills multiple teacher models (of different tasks) into a single student. It supports intermediate feature matching since version 0.2.1.
- Parameters
train_config (TrainingConfig) – training configuration.
distill_config (DistillationConfig) – distillation configuration.
model_T (dict) – dict of teacher models: {task1: model1, task2: model2, ...}. Keys are task names.
model_S (torch.nn.Module) – student model.
adaptor_T (dict) – dict of teacher adaptors: {task1: adpt1, task2: adpt2, ...}. Keys are task names.
adaptor_S (dict) – dict of student adaptors: {task1: adpt1, task2: adpt2, ...}. Keys are task names.
- train(optimizer, dataloaders, num_steps, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, tau=1, callback=None, batch_postprocessors=None, **args)[source]¶
Trains the student model.
- Parameters
optimizer – optimizer.
dataloaders (dict) – dict of dataset iterators. Keys are task names, values are the corresponding dataloaders.
num_steps (int) – number of training steps.
scheduler_class (class) – the class of the scheduler to be constructed.
scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object.
scheduler (deprecated) – used to adjust the learning rate; optional, can be None; deprecated in favor of scheduler_class and scheduler_args.
max_grad_norm (float) – maximum norm for the gradients (-1 means no clipping). Default: -1.0.
tau (float) – the probability of sampling an example from task d is proportional to |d|^{tau}, where |d| is the size of d's training set. If the size of any dataset is unknown, tau is ignored and examples are sampled uniformly from each dataset.
callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step). It can be used to evaluate the model at each checkpoint.
batch_postprocessors (dict) – a dict of batch_postprocessors. Keys are task names, values are the corresponding batch_postprocessors. Each batch_postprocessor should take a batch and return a batch.
**args – additional arguments fed to the model.
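A minimal usage sketch (the task names and the teacher/adaptor/dataloader objects below are placeholders):

from textbrewer import MultiTaskDistiller, TrainingConfig, DistillationConfig

distiller = MultiTaskDistiller(
    train_config=TrainingConfig(), distill_config=DistillationConfig(temperature=8),
    model_T={'taskA': teacher_A, 'taskB': teacher_B},      # one teacher per task
    model_S=student_model,
    adaptor_T={'taskA': adaptor_A_T, 'taskB': adaptor_B_T},
    adaptor_S={'taskA': adaptor_A_S, 'taskB': adaptor_B_S})

with distiller:
    distiller.train(optimizer=optimizer,
                    dataloaders={'taskA': dataloader_A, 'taskB': dataloader_B},
                    num_steps=10000, callback=None)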
BasicTrainer¶
- class textbrewer.BasicTrainer(train_config: textbrewer.configurations.TrainingConfig, model: torch.nn.modules.module.Module, adaptor)[source]¶
It performs supervised training, not distillation. It can be used to train the teacher model.
- Parameters
train_config (TrainingConfig) – training configuration.
model (torch.nn.Module) – model to be trained.
adaptor (Callable) – the model's adaptor.

The role of the adaptor is explained in adaptor().

- train(optimizer, dataloader, num_epochs, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)[source]¶
Trains the model. See BasicDistiller.train().
Presets¶
Presets include module variables that define pre-defined loss functions and strategies.
Module variables¶
ADAPTOR_KEYS¶
- textbrewer.presets.ADAPTOR_KEYS¶
(list) valid keys of the dict returned by the adaptor, includes:
‘logits’
‘logits_mask’
‘losses’
‘inputs_mask’
‘labels’
‘hidden’
‘attention’
KD_LOSS_MAP¶
- textbrewer.presets.KD_LOSS_MAP¶
(dict) available KD losses:
‘mse’ : mean squared error
‘ce’: cross-entropy loss
PROJ_MAP¶
- textbrewer.presets.PROJ_MAP¶
(dict) layers used to match the different dimensions of intermediate features:
‘linear’ : linear layer, no activation
‘relu’ : ReLU activation
‘tanh’: Tanh activation
MATCH_LOSS_MAP¶
- textbrewer.presets.MATCH_LOSS_MAP¶
(dict) intermediate feature matching loss functions, includes:
hidden_mse
nst, mmd
See Intermediate Losses for details.
WEIGHT_SCHEDULER¶
- textbrewer.presets.WEIGHT_SCHEDULER¶
(dict) schedulers used to dynamically adjust the KD loss weight and the hard_label_loss weight:
‘linear_decay’ : decay from 1 to 0 during the whole training process.
‘linear_growth’ : grow from 0 to 1 during the whole training process.
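For example, a sketch that decays the KD loss weight while growing the hard-label loss weight over training:

from textbrewer import DistillationConfig

distill_config = DistillationConfig(
    kd_loss_weight=1, kd_loss_weight_scheduler='linear_decay',
    hard_label_weight=1, hard_label_weight_scheduler='linear_growth')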
TEMPERATURE_SCHEDULER¶
- textbrewer.presets.TEMPERATURE_SCHEDULER¶
(custom dict) used to dynamically adjust the distillation temperature:
'constant' : constant temperature.
'flsw' : see Preparing Lessons: Improve Knowledge Distillation with Better Supervision. Needs parameters beta and gamma.
'cwsm' : see Preparing Lessons: Improve Knowledge Distillation with Better Supervision. Needs parameter beta.
Different from the other options, when using 'flsw' or 'cwsm', you need to provide extra parameters, for example:

# flsw
distill_config = DistillationConfig(
    temperature_scheduler=['flsw', 1, 2]   # beta=1, gamma=2
)

# cwsm
distill_config = DistillationConfig(
    temperature_scheduler=['cwsm', 1]      # beta=1
)
Customization¶
If the pre-defined modules do not satisfy your requirements, you can add your own modules to the dicts above.
For example:
MATCH_LOSS_MAP['my_L1_loss'] = my_L1_loss
WEIGHT_SCHEDULER['my_weight_scheduler'] = my_weight_scheduler
Then use them in DistillationConfig:
distill_config = DistillationConfig(
    kd_loss_weight_scheduler='my_weight_scheduler',
    intermediate_matches=[{'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'my_L1_loss', 'weight': 1}],
    ...)
Refer to the source code for more details on the input and output conventions (they will be explained in detail in a later version of the documentation).
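As a rough sketch only, assuming a user-defined intermediate loss follows the same (state_S, state_T, mask=None) signature as the built-in intermediate losses documented below:

import torch

def my_L1_loss(state_S, state_T, mask=None):
    # state_S / state_T: (batch_size, length, hidden_size)
    # mask (optional): (batch_size, length); positions with mask==0 are ignored
    if mask is None:
        return (state_S - state_T).abs().mean()
    mask = mask.unsqueeze(-1).to(state_S)      # (batch_size, length, 1), same dtype/device
    return ((state_S - state_T).abs() * mask).sum() / (mask.sum() * state_S.size(-1))

It can then be registered in MATCH_LOSS_MAP as shown above.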
Intermediate Losses¶
Here we list the definitions of pre-defined intermediate losses.
Usually, users do not need to call these functions directly; they are referred to by their names in MATCH_LOSS_MAP.
attention_mse¶
- textbrewer.losses.att_mse_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the mse loss between attention_S and attention_T.
If the inputs_mask is given, masks the positions where input_mask==0.

- Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
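Example in intermediate_matches (assuming the adaptors return attention matrices under the 'attention' key):

intermediate_matches = [
    {'layer_T': 10, 'layer_S': 3, 'feature': 'attention', 'loss': 'attention_mse', 'weight': 1},
    ...]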
attention_mse_sum¶
- textbrewer.losses.att_mse_sum_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the mse loss between attention_S and attention_T.
If the shape is (batch_size, num_heads, length, length), sums along the num_heads dimension and then calculates the mse loss between the two matrices.
If the inputs_mask is given, masks the positions where input_mask==0.

- Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
attention_ce¶
- textbrewer.losses.att_ce_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied on dim=-1.
If the inputs_mask is given, masks the positions where input_mask==0.

- Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
attention_ce_mean¶
- textbrewer.losses.att_ce_mean_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied on dim=-1.
If the shape is (batch_size, num_heads, length, length), averages over the num_heads dimension and then computes the cross-entropy loss between the two matrices.
If the inputs_mask is given, masks the positions where input_mask==0.

- Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
cos¶
- textbrewer.losses.cos_loss(state_S, state_T, mask=None)[source]¶
Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT; see DistilBERT.
If the inputs_mask is given, masks the positions where input_mask==0.
If the hidden sizes of the student and the teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.

- Parameters
state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)
state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of shape (batch_size, length)
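Example in intermediate_matches (with a 'proj' entry, assuming student and teacher hidden sizes of 384 and 768):

intermediate_matches = [
    {'layer_T': 10, 'layer_S': 3, 'feature': 'hidden', 'loss': 'cos', 'weight': 1, 'proj': ['linear', 384, 768]},
    ...]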
pkd¶
- textbrewer.losses.pkd_loss(state_S, state_T, mask=None)[source]¶
Computes the normalized vector mse loss at position 0 along the length dimension. This is the loss used in BERT-PKD; see Patient Knowledge Distillation for BERT Model Compression.
If the hidden sizes of the student and the teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.
If the input tensors are of shape (batch_size, hidden_size), it directly computes the loss between the tensors without taking the hidden states at position 0.

- Parameters
state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)
state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)
mask – not used.
nst (mmd)¶
- textbrewer.losses.mmd_loss(state_S, state_T, mask=None)[source]¶
Takes two lists of matrices, state_S and state_T. Each list contains two matrices of the shape (batch_size, length, hidden_size). The hidden_size of the matrices in state_S does not need to be the same as that of state_T. Computes the similarity matrix of the two matrices in state_S (with the resulting shape (batch_size, length, length)) and the similarity matrix of the two matrices in state_T (with the resulting shape (batch_size, length, length)), then computes the mse loss between the two similarity matrices:

\[loss = \mathrm{mean}\left((S_{1} \cdot S_{2}^T - T_{1} \cdot T_{2}^T)^2\right)\]

It is a variant of the NST loss in Like What You Like: Knowledge Distill via Neuron Selectivity Transfer.
If the inputs_mask is given, masks the positions where input_mask==0.

- Parameters
state_S (torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)
state_T (torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of shape (batch_size, length)
Example in intermediate_matches:
intermediate_matches = [ {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'nst', 'weight' : 1}, ...]
fsp (gram)¶
- textbrewer.losses.fsp_loss(state_S, state_T, mask=None)[source]¶
Takes two lists of matrices, state_S and state_T. Each list contains two matrices of the shape (batch_size, length, hidden_size). Computes the similarity matrix of the two matrices in state_S (with the resulting shape (batch_size, hidden_size, hidden_size)) and the similarity matrix of the two matrices in state_T (with the resulting shape (batch_size, hidden_size, hidden_size)), then computes the mse loss between the two similarity matrices:

\[loss = \mathrm{mean}\left((S_{1}^T \cdot S_{2} - T_{1}^T \cdot T_{2})^2\right)\]

It is a variant of the FSP loss in A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning.
If the inputs_mask is given, masks the positions where input_mask==0.
If the hidden sizes of the student and the teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.

- Parameters
state_S (torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)
state_T (torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of shape (batch_size, length)
Example in intermediate_matches:
intermediate_matches = [ {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'fsp', 'weight' : 1, 'proj':['linear',384,768]}, ...]
Model Utils¶
display_parameters¶
- textbrewer.utils.display_parameters(model, max_level=None)[source]¶
Displays the numbers and memory usage of module parameters.

- Parameters
model (torch.nn.Module or dict) – the model to be inspected.
max_level (int or None) – the max level to display. If max_level==None, show all the levels.

- Returns
A formatted string and a LayerNode object representing the model.
Data Utils¶
This module provides the following data augmentation methods.
masking¶
- textbrewer.data_utils.masking(tokens, p=0.1, mask='[MASK]')[source]¶
Returns a new list by replacing elements in tokens by mask with probability p.
- Parameters
tokens (list) – list of tokens or token ids.
p (float) – probability to mask each element in tokens.
- Returns
A new list by replacing elements in tokens by mask with probability p.
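A minimal usage sketch (the tokens and the output shown in the comment are illustrative; the replaced positions are random):

from textbrewer.data_utils import masking

tokens = ['the', 'quick', 'brown', 'fox', 'jumps']
masked = masking(tokens, p=0.3)
# e.g. ['the', '[MASK]', 'brown', 'fox', '[MASK]']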
deleting¶
- textbrewer.data_utils.deleting(tokens, p=0.1)[source]¶
Returns a new list by deleting elements in tokens with probability p.

- Parameters
tokens (list) – list of tokens or token ids.
p (float) – probability to delete each element in tokens.

- Returns
A new list by deleting elements in tokens with probability p.
n_gram_sampling¶
- textbrewer.data_utils.n_gram_sampling(tokens, p_ng=[0.2, 0.2, 0.2, 0.2, 0.2], l_ng=[1, 2, 3, 4, 5])[source]¶
Samples a length l from l_ng with probability distribution p_ng, then returns a random span of length l from tokens.

- Parameters
tokens (list) – list of tokens or token ids.
p_ng (list) – probability distribution of the n-grams, should sum to 1.
l_ng (list) – specifies the n-grams.

- Returns
A random n-gram span from tokens.
short_disorder¶
- textbrewer.data_utils.short_disorder(tokens, p=[0.9, 0.1, 0, 0, 0])[source]¶
Returns a new list by disordering tokens with probability distribution p at every possible position. Let abc be a 3-gram in tokens; there are five ways to disorder it, corresponding to the five probability values:

abc -> abc
abc -> bac
abc -> cba
abc -> cab
abc -> bca

- Parameters
tokens (list) – list of tokens or token ids.
p (list) – probability distribution of the 5 disorder types, should sum to 1.

- Returns
A new disordered list.
long_disorder¶
- textbrewer.data_utils.long_disorder(tokens, p=0.1, length=20)[source]¶
Performs long-range disordering. If length>1, swaps the two halves of each span of length length in tokens; if length<=1, treats length as the relative length. For example:

>>> long_disorder([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], p=1, length=0.4)
[2, 3, 0, 1, 6, 7, 4, 5, 8, 9]

- Parameters
tokens (list) – list of tokens or token ids.
p (float) – probability to swap the two halves of a span at each possible position.
length (int or float) – length of the disordered span.

- Returns
A new disordered list.
Experimental Results¶
English Datasets¶
MNLI¶
Training without Distillation:
Model(ours) | MNLI |
---|---|
BERT-base-cased | 83.7 / 84.0 |
T3 | 76.1 / 76.5 |
Single-teacher distillation with GeneralDistiller:
Model (ours) | MNLI |
---|---|
BERT-base-cased (teacher) | 83.7 / 84.0 |
T6 (student) | 83.5 / 84.0 |
T3 (student) | 81.8 / 82.7 |
T3-small (student) | 81.3 / 81.7 |
T4-tiny (student) | 82.0 / 82.6 |
T12-nano (student) | 83.2 / 83.9 |
Multi-teacher distillation with MultiTeacherDistiller:
Model (ours) | MNLI |
---|---|
BERT-base-cased (teacher #1) | 83.7 / 84.0 |
BERT-base-cased (teacher #2) | 83.6 / 84.2 |
BERT-base-cased (teacher #3) | 83.7 / 83.8 |
ensemble (average of #1, #2 and #3) | 84.3 / 84.7 |
BERT-base-cased (student) | 84.8 / 85.3 |
SQuAD¶
Training without Distillation:
Model(ours) | SQuAD |
---|---|
BERT-base-cased | 81.5 / 88.6 |
T6 | 75.0 / 83.3 |
T3 | 63.0 / 74.3 |
Single-teacher distillation with GeneralDistiller:
Model(ours) | SQuAD |
---|---|
BERT-base-cased (teacher) | 81.5 / 88.6 |
T6 (student) | 80.8 / 88.1 |
T3 (student) | 76.4 / 84.9 |
T3-small (student) | 72.3 / 81.4 |
T4-tiny (student) | 73.7 / 82.5 |
+ DA | 75.2 / 84.0 |
T12-nano (student) | 79.0 / 86.6 |
Note: When distilling to T4-tiny, NewsQA is used for data augmentation on SQuAD.
Multi-teacher distillation with MultiTeacherDistiller:
Model (ours) | SQuAD |
---|---|
BERT-base-cased (teacher #1) | 81.1 / 88.6 |
BERT-base-cased (teacher #2) | 81.2 / 88.5 |
BERT-base-cased (teacher #3) | 81.2 / 88.7 |
ensemble (average of #1, #2 and #3) | 82.3 / 89.4 |
BERT-base-cased (student) | 83.5 / 90.0 |
CoNLL-2003 English NER¶
Training without Distillation:
Model(ours) | CoNLL-2003 |
---|---|
BERT-base-cased | 91.1 |
BiGRU | 81.1 |
T3 | 85.3 |
Single-teacher distillation with GeneralDistiller:
Model(ours) | CoNLL-2003 |
---|---|
BERT-base-cased (teacher) | 91.1 |
BiGRU | 85.3 |
T6 (student) | 90.7 |
T3 (student) | 87.5 |
+ DA | 90.0 |
T3-small (student) | 78.6 |
+ DA | - |
T4-tiny (student) | 77.5 |
+ DA | 89.1 |
T12-nano (student) | 78.8 |
+ DA | 89.6 |
Note: HotpotQA is used for data augmentation on CoNLL-2003.
Chinese Datasets (RoBERTa-wwm-ext as the teacher)¶
XNLI¶
Model | XNLI |
---|---|
RoBERTa-wwm-ext (teacher) | 79.9 |
T3 (student) | 78.4 |
T3-small (student) | 76.0 |
T4-tiny (student) | 76.2 |
LCQMC¶
Model | LCQMC |
---|---|
RoBERTa-wwm-ext (teacher) | 89.4 |
T3 (student) | 89.0 |
T3-small (student) | 88.1 |
T4-tiny (student) | 88.4 |
CMRC 2018 and DRCD¶
Model | CMRC 2018 | DRCD |
---|---|---|
RoBERTa-wwm-ext (teacher) | 68.8 / 86.4 | 86.5 / 92.5 |
T3 (student) | 63.4 / 82.4 | 76.7 / 85.2 |
+ DA | 66.4 / 84.2 | 78.2 / 86.4 |
T3-small (student) | 46.1 / 71.0 | 71.4 / 82.2 |
+ DA | 58.0 / 79.3 | 75.8 / 84.8 |
T4-tiny (student) | 54.3 / 76.8 | 75.5 / 84.9 |
+ DA | 61.8 / 81.8 | 77.3 / 86.1 |
Note: CMRC 2018 and DRCD take each other as the augmentation dataset in these experiments.
Chinese Datasets (Electra-base as the teacher)¶
Training without Distillation:
Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
---|---|---|---|---|---|
Electra-base (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
Electra-small (pretrained) | 72.5 | 86.3 | 62.9 / 80.2 | 79.4 / 86.4 |
Single-teacher distillation with GeneralDistiller:
Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
---|---|---|---|---|---|
Electra-base (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
Electra-small (random) | 77.2 | 89.0 | 66.5 / 84.9 | 84.8 / 91.0 | |
Electra-small (pretrained) | 77.7 | 89.3 | 66.5 / 84.9 | 85.5 / 91.3 | 93.48 |
Note:
Random: randomly initialized
Pretrained: initialized with pretrained weights
A good initialization of the student (Electra-small) improves the performance.