
TextBrewer is a PyTorch-based model distillation toolkit for natural language processing.

It includes various distillation techniques from both the NLP and CV fields and provides an easy-to-use distillation framework that allows users to quickly experiment with state-of-the-art distillation methods, compressing models with a relatively small loss in performance while increasing inference speed and reducing memory usage.

Main features

  • Wide support: it supports various model architectures (especially transformer-based models).

  • Flexibility: design your own distillation scheme by combining different techniques.

  • Easy to use: users don’t need to modify the model architectures.

  • Built for NLP: it is suitable for a wide variety of NLP tasks: text classification, machine reading comprehension, sequence labeling, …

Paper: TextBrewer: An Open-Source Knowledge Distillation Toolkit for Natural Language Processing

Tutorial

Introduction

TextBrewer is designed for knowledge distillation of NLP models. It provides various distillation methods and offers a distillation framework for quickly setting up experiments.

TextBrewer currently ships with the following distillation techniques:

  • Mixed soft-label and hard-label training

  • Dynamic loss weight adjustment and temperature adjustment

  • Various distillation loss functions

  • Freely adding intermediate features matching losses

  • Multi-teacher distillation

TextBrewer includes:

  1. Distillers: the cores of distillation. Different distillers perform different distillation modes.

  2. Configurations and presets: Configuration classes for training and distillation, and predefined distillation loss functions and strategies.

  3. Utilities: auxiliary tools such as model parameter analysis.

Architecture

_images/arch.png

Installation

  • Requirements

    • Python >= 3.6

    • PyTorch >= 1.1.0

    • TensorboardX or Tensorboard

    • NumPy

    • tqdm

    • Transformers >= 2.0 (optional, used by some examples)

  • Install from PyPI

    pip install textbrewer
    
  • Install from the Github source

    git clone https://github.com/airaria/TextBrewer.git
    pip install ./TextBrewer
    

Workflow

_images/distillation_workflow_en.png _images/distillation_workflow2.png

To start distillation, users need to provide

  1. the models (the trained teacher model and the un-trained student model).

  2. datasets and experiment configurations.

  • Stage 1: Preparation:

    1. Train the teacher model.

    2. Define and initialize the student model.

    3. Construct a dataloader, an optimizer, and a learning rate scheduler.

  • Stage 2: Distillation with TextBrewer:

    1. Construct a TrainingConfig and a DistillationConfig, and initialize a distiller.

    2. Define an adaptor and a callback. The adaptor is used for the adaptation of model inputs and outputs. The callback is called by the distiller during training.

    3. Call the train() method of the distiller.

Quickstart

Here we show the usage of TextBrewer by distilling BERT-base to a 3-layer BERT.

Before distillation, we assume users have provided:

  • A trained teacher model teacher_model (BERT-base) and a to-be-trained student model student_model (3-layer BERT).

  • A dataloader of the dataset, an optimizer, and a learning rate scheduler class scheduler_class with its arguments dict scheduler_args (a minimal construction sketch is shown below).
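
For reference, here is a minimal sketch of how these objects might be constructed. It is only one possible setup: train_dataset is a placeholder, and AdamW with a linear warmup schedule is just a common choice.

import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

# placeholder dataset; any map-style dataset for the task works here
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)

scheduler_class = get_linear_schedule_with_warmup
# arguments (excluding the optimizer) that will be passed to scheduler_class
scheduler_args = {'num_warmup_steps': 100, 'num_training_steps': 1000}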

Distill with TextBrewer:

import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig

# Show the statistics of model parameters
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)

print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)

# Define an adaptor for interpreting the model inputs and outputs
def simple_adaptor(batch, model_outputs):
    # The second and third elements of model outputs are the logits and hidden states
    return {'logits': model_outputs[1],
            'hidden': model_outputs[2]}

# Training configuration
train_config = TrainingConfig()
# Distillation configuration
# Matching different layers of the student and the teacher
# We match 0-0 and 8-2 here for demonstration
distill_config = DistillationConfig(
    intermediate_matches=[
    {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},
    {'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])

# Build distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config = distill_config,
    model_T = teacher_model, model_S = student_model,
    adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)

# Start!
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args=scheduler_args, callback=None)

Examples

Examples can be found in the examples directory of the repo:

  • examples/random_token_example : a simple runnable toy example which demonstrates the usage of TextBrewer. This example performs distillation on the text classification task with random tokens as inputs.

  • examples/cmrc2018_example (Chinese): distillation on CMRC 2018, a Chinese MRC task, using DRCD as data augmentation.

  • examples/mnli_example (English): distillation on MNLI, an English sentence-pair classification task. This example also shows how to perform multi-teacher distillation.

  • examples/conll2003_example (English): distillation on CoNLL-2003 English NER task, which is in the form of sequence labeling.

  • examples/msra_ner_example (Chinese): This example distills a Chinese-ELECTRA-base model on the MSRA NER task with distributed data-parallel training (single node, multi-GPU).

FAQ

Q: How to initialize the student model?

A: The student model could be randomly initialized (i.e., with no prior knowledge) or initialized with pre-trained weights. For example, when distilling a BERT-base model to a 3-layer BERT, you could initialize the student model with RBT3 (for Chinese tasks) or the first three layers of BERT (for English tasks) to avoid the cold-start problem. We recommend using pre-trained student models whenever possible to fully take advantage of large-scale pre-training.

Q: How to set training hyperparameters for the distillation experiments?

A: Knowledge distillation usually requires more training epochs and a larger learning rate than training on the labeled dataset. For example, fine-tuning BERT-base on SQuAD usually takes 3 epochs with lr=3e-5, while distillation takes 30~50 epochs with lr=1e-4. These conclusions are based on our experiments; you are advised to tune the hyperparameters on your own data.

Q: My teacher model and student model take different inputs (they do not share vocabularies), so how can I distill?

A: You need to feed different batches to the teacher and the student. See Feed Different batches to Student and Teacher, Feed Cached Values.

Q: I have stored the logits from my teacher model. Can I use them in the distillation to save the forward pass time?

A: Yes, see Feed Different batches to Student and Teacher, Feed Cached Values.

Known Issues

  • Multi-label classification is not supported.

Citation

If you find TextBrewer helpful, please cite our paper:

@InProceedings{textbrewer-acl2020-demo,
  author =  "Yang, Ziqing and Cui, Yiming and Chen, Zhipeng and Che, Wanxiang and Liu, Ting and Wang, Shijin and Hu, Guoping",
  title =   "{T}ext{B}rewer: {A}n {O}pen-{S}ource {K}nowledge {D}istillation {T}oolkit for {N}atural {L}anguage {P}rocessing",
  booktitle =   "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
  year =  "2020",
  publisher =   "Association for Computational Linguistics"
}

Core Concepts

Conventions

  • Model_T: an instance of torch.nn.Module, the teacher model to be distilled.

  • Model_S: an instance of torch.nn.Module, the student model, usually smaller than the teacher model for model compression and faster inference speed.

  • optimizer: an instance of torch.optim.Optimizer.

  • scheduler: an instance of a class under torch.optim.lr_scheduler, allows flexible adjustment of learning rate.

  • dataloader: data iterator, used to generate data batches. A batch can be a tuple or a dict:

    for batch in dataloader:
        if batch_postprocessor is not None:
            batch = batch_postprocessor(batch)
        # check the batch datatype
        # pass the batch to the model and the adaptors
    

Batch Format (important)

Forward conventions: each batch passed to the model should be a dict or a tuple:

  • if the batch is a dict, the model will be called as model(**batch, **args);

  • if the batch is a tuple, the model is called as model(*batch, **args).

Hence if the batch is not a dict, users should make sure that the order of each element in the batch is the same as the order of the arguments of model.forward. args is used for passing additional parameters.

Users can additionally define a batch_postprocessor function to post-process batches if needed. A batch_postprocessor should take a batch and return a batch. See the explanation of the train method of the distillers for more details.
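
For example, a minimal sketch of a batch_postprocessor that turns a tuple batch into a dict whose keys match the model's forward arguments (the element order and argument names below are assumptions for illustration):

def my_batch_postprocessor(batch):
    # assumes the dataloader yields (input_ids, attention_mask, labels) tuples
    input_ids, attention_mask, labels = batch
    # returning a dict makes the model be called as model(**batch)
    return {'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels}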

Since version 0.2.1, TextBrewer supports a more flexible input scheme: users can feed different batches to the student and the teacher, or feed cached values to save the teacher’s forward-pass time. See Feed Different batches to Student and Teacher, Feed Cached Values.

Configurations

  • TrainingConfig and DistillationConfig: configuration classes for general training and for distillation. See the Configurations section below for details.

Distillers

Distillers are in charge of conducting the actual experiments. The following distillers are available:

  • BasicDistiller: single-teacher single-task distillation, provides basic distillation strategies.

  • GeneralDistiller (Recommended): single-teacher single-task distillation, supports intermediate features matching. Recommended most of the time.

  • MultiTeacherDistiller: multi-teacher distillation, which distills multiple teacher models (of the same task) into a single student model. This class doesn’t support intermediate feature matching.

  • MultiTaskDistiller: multi-task distillation, which distills multiple teacher models (of different tasks) into a single student.

  • BasicTrainer: performs supervised training of a single model on a labeled dataset; not for distillation. It can be used to train a teacher model.

User-Defined Functions

In TextBrewer, there are two functions that should be implemented by users: callback() and adaptor().

callback(model, step) → None

At each checkpoint, after saving the student model, the callback function will be called by the distiller. callback can be used to evaluate the performance of the student model at each checkpoint.

Note

If users want to perform evaluation in the callback, remember to call model.eval() in the callback (a sketch is shown below the parameter list).

Parameters
  • model (torch.nn.Module) – the student model

  • step (int) – the current training step
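
A minimal sketch of a callback that evaluates the student at each checkpoint; evaluate_on_dev is a hypothetical user-defined evaluation routine, not part of TextBrewer:

def my_callback(model, step):
    model.eval()                           # switch the student to evaluation mode
    dev_accuracy = evaluate_on_dev(model)  # user-defined evaluation on the dev set
    print(f"step {step}: dev accuracy = {dev_accuracy:.4f}")
    model.train()                          # switch back to training mode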

adaptor(batch, model_outputs) → dict

It converts the model inputs and outputs to the specified format so that they can be recognized by the distiller. At each training step, the batch and the model outputs will be passed to the adaptor; the adaptor reorganizes the data and returns a dict.

The functionality of the adaptor is shown in the figure below:

_images/adaptor.png
Parameters
  • batch – the input batch to the model

  • model_outputs – the outputs returned by the model

Return type

dict

Returns

a dictionary that may contain the following keys and values:

  • 'logits' : List[torch.Tensor] or torch.Tensor

    The inputs to the final softmax. Each tensor should have the shape (batch_size, num_labels) or (batch_size, length, num_labels).

  • 'logits_mask' : List[torch.Tensor] or torch.Tensor

    0/1 matrix, which masks logits at specified positions. The positions where mask==0 won’t be included in the calculation of loss on logits. Each tensor should have the shape (batch_size, length).

  • 'labels' : List[torch.Tensor] or torch.Tensor

    Ground-truth labels of the examples. Each tensor should have the shape (batch_size,) or (batch_size, length).

Note

  • logits_mask only works for logits with shape (batch_size, length, num_labels). It’s used to mask along the length dimension, commonly used in sequence labeling tasks.

  • logits, logits_mask and labels should either all be lists of tensors, or all be tensors.

  • 'losses' : List[torch.Tensor]

    It stores pre-computed losses, for example, the cross-entropy between logits and ground-truth labels. All the losses stored here would be summed and weighted by hard_label_weight and added to the total loss. Each tensor in the list should be a scalar.

  • 'attention' : List[torch.Tensor]

    List of attention matrices, used to compute intermediate feature matching loss. Each tensor should have the shape (batch_size, num_heads, length, length) or (batch_size, length, length), depending on what attention loss is used. Details about various loss functions can be found at Intermediate Losses.

  • 'hidden' : List[torch.Tensor]

    List of hidden states used to compute intermediate feature matching loss. Each tensor should have the shape (batch_size, length, hidden_dim).

  • 'inputs_mask' : torch.Tensor

    0/1 matrix, performs masking on attention and hidden, should have the shape (batch_size, length).

Note

These keys are all optional:

  • If there is no inputs_mask or logits_mask, then it’s considered as no masking.

  • If there is no intermediate feature matching loss, you can ignore attention and hidden.

  • If you don’t want to add loss of the original hard labels, you can set hard_label_weight=0 in the DistillationConfig and ignore losses.

  • If logits is not provided, the KD loss of the logits will be omitted.

  • labels is required if and only if probability_shift==True.

  • You shouldn’t ignore all the keys, otherwise the training won’t start :)

In most cases logits should be provided, unless you are doing multi-stage training or non-classification tasks, etc.

Example:

'''
Suppose the model outputs are: logits, sequence_output, total_loss
class MyModel():
  def forward(self, input_ids, attention_mask, labels, ...):
    ...
    return logits, sequence_output, total_loss

logits: Tensor of shape (batch_size, num_classes)
sequence_output: List of tensors of (batch_size, length, hidden_dim)
total_loss: scalar tensor

model inputs are:
input_ids      = batch[0] : input_ids (batch_size, length)
attention_mask = batch[1] : attention_mask (batch_size, length)
labels         = batch[2] : labels (batch_size, num_classes)
'''
def SimpleAdaptor(batch, model_outputs):
    return {'logits': (model_outputs[0],),
            'hidden': model_outputs[1],
            'inputs_mask': batch[1]}

Feed Different batches to Student and Teacher, Feed Cached Values

Feed Different batches

In some cases, student and teacher read different inputs. For example, if you distill a RoBERTa model to a BERT model, they cannot share the inputs since they have different vocabularies.

To solve this, one can build a dataset that returns a dict as the batch with keys 'student' and 'teacher'. TextBrewer will unpack the dict and feed batch['student'] to the student and its adaptor, and batch['teacher'] to the teacher and its adaptor, following the forward conventions.

Here is an example.

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

class TSDataset(Dataset):
    def __init__(self, teacher_dataset, student_dataset):
        # teacher_dataset and student_dataset are normal datasets
        # whose elements are tuples or dicts.
        assert len(teacher_dataset) == len(student_dataset), \
          f"lengths of teacher_dataset {len(teacher_dataset)} and student_dataset {len(student_dataset)} are not the same!"

        self.teacher_dataset = teacher_dataset
        self.student_dataset = student_dataset

    def __len__(self):
        return len(self.teacher_dataset)

    def __getitem__(self,i):
        return {'teacher' : self.teacher_dataset[i], 'student' : self.student_dataset[i]}

teacher_dataset = TensorDataset(torch.randn(32,3),torch.randn(32,3))
student_dataset = TensorDataset(torch.randn(32,2),torch.randn(32,2))
tsdataset = TSDataset(teacher_dataset=teacher_dataset,student_dataset=student_dataset)
dataloader = DataLoader(dataset=tsdataset, ... )

Feed Cached Values

If you can provide a dataset that returns a dict with keys 'student' and 'teacher' like the one above, you can also add another key 'teacher_cache', which stores the pre-computed outputs from the teacher. TextBrewer will then treat batch['teacher_cache'] as the teacher’s output and feed it to the teacher’s adaptor; the teacher’s forward pass will not be called.

Here is an example.

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

class TSDataset(Dataset):
    def __init__(self, teacher_dataset, student_dataset, teacher_cache):
        # teacher_dataset and student_dataset are normal datasets
        # whose elements are tuples or dicts.
        # teacher_cache is a list of items; each item is the output from the teacher.
        assert len(teacher_dataset) == len(student_dataset), \
          f"lengths of teacher_dataset {len(teacher_dataset)} and student_dataset {len(student_dataset)} are not the same!"
        assert len(teacher_dataset) == len(teacher_cache), \
          f"lengths of teacher_dataset {len(teacher_dataset)} and teacher_cache {len(teacher_cache)} are not the same!"
        self.teacher_dataset = teacher_dataset
        self.student_dataset = student_dataset
        self.teacher_cache = teacher_cache

    def __len__(self):
        return len(self.teacher_dataset)

    def __getitem__(self,i):
        return {'teacher' : self.teacher_dataset[i], 'student' : self.student_dataset[i], 'teacher_cache':self.teacher_cache[i]}

teacher_dataset = TensorDataset(torch.randn(32,3),torch.randn(32,3))
student_dataset = TensorDataset(torch.randn(32,2),torch.randn(32,2))

# We make some fake data and assume teacher model outputs are (logits, loss)
fake_logits = [torch.randn(3) for _ in range(32)]
fake_loss = [torch.randn(1)[0] for _ in range(32)]
teacher_cache = [(fake_logits[i],fake_loss[i]) for i in range(32)]

tsdataset = TSDataset(teacher_dataset=teacher_dataset,student_dataset=student_dataset, teacher_cache=teacher_cache)
dataloader = DataLoader(dataset=tsdataset, ... )

Experiments

We have performed distillation experiments on several typical English and Chinese NLP datasets. The setups and configurations are listed below.

Models

We have tested different student models. To compare with public results, the student models are built with standard transformer blocks, except for BiGRU, which is a single-layer bidirectional GRU. The architectures are listed below. Note that the number of parameters includes the embedding layer but does not include the output layer of each specific task.

English models

| Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
|---|---|---|---|---|---|
| BERT-base-cased (teacher) | 12 | 768 | 3072 | 108M | 100% |
| T6 (student) | 6 | 768 | 3072 | 65M | 60% |
| T3 (student) | 3 | 768 | 3072 | 44M | 41% |
| T3-small (student) | 3 | 384 | 1536 | 17M | 16% |
| T4-Tiny (student) | 4 | 312 | 1200 | 14M | 13% |
| T12-nano (student) | 12 | 256 | 1024 | 17M | 16% |
| BiGRU (student) | - | 768 | - | 31M | 29% |

Chinese models

| Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
|---|---|---|---|---|---|
| RoBERTa-wwm-ext (teacher) | 12 | 768 | 3072 | 102M | 100% |
| Electra-base (teacher) | 12 | 768 | 3072 | 102M | 100% |
| T3 (student) | 3 | 768 | 3072 | 38M | 37% |
| T3-small (student) | 3 | 384 | 1536 | 14M | 14% |
| T4-Tiny (student) | 4 | 312 | 1200 | 11M | 11% |
| Electra-small (student) | 12 | 256 | 1024 | 12M | 12% |

Configurations

Distillation Configurations

distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)
# Other arguments take their default values

matches are different for different models:

| Model | matches |
|---|---|
| BiGRU | None |
| T6 | L6_hidden_mse + L6_hidden_smmd |
| T3 | L3_hidden_mse + L3_hidden_smmd |
| T3-small | L3n_hidden_mse + L3_hidden_smmd |
| T4-Tiny | L4t_hidden_mse + L4_hidden_smmd |
| T12-nano | small_hidden_mse + small_hidden_smmd |
| Electra-small | small_hidden_mse + small_hidden_smmd |

The definitions of matches are at examples/matches/matches.py.

We use GeneralDistiller in all the distillation experiments.

Training Configurations

  • Learning rate is 1e-4 (unless otherwise specified).

  • We train all the models for 30~60 epochs.

Results on English Datasets

We experiment on the following typical English datasets:

| Dataset | Task type | Metrics | #Train | #Dev | Note |
|---|---|---|---|---|---|
| MNLI | text classification | m/mm Acc | 393K | 20K | sentence-pair 3-class classification |
| SQuAD 1.1 | reading comprehension | EM/F1 | 88K | 11K | span-extraction machine reading comprehension |
| CoNLL-2003 | sequence labeling | F1 | 23K | 6K | named entity recognition |

We list the public results from DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT and our results below for comparison.

Public results:

| Model (public) | MNLI | SQuAD | CoNLL-2003 |
|---|---|---|---|
| DistilBERT (T6) | 81.6 / 81.1 | 78.1 / 86.2 | - |
| BERT6-PKD (T6) | 81.5 / 81.0 | 77.1 / 85.3 | - |
| BERT-of-Theseus (T6) | 82.4 / 82.1 | - | - |
| BERT3-PKD (T3) | 76.7 / 76.3 | - | - |
| TinyBERT (T4-tiny) | 82.8 / 82.9 | 72.7 / 82.1 | - |

Our results (see Experimental Results for details):

| Model (ours) | MNLI | SQuAD | CoNLL-2003 |
|---|---|---|---|
| BERT-base-cased (teacher) | 83.7 / 84.0 | 81.5 / 88.6 | 91.1 |
| BiGRU | - | - | 85.3 |
| T6 | 83.5 / 84.0 | 80.8 / 88.1 | 90.7 |
| T3 | 81.8 / 82.7 | 76.4 / 84.9 | 87.5 |
| T3-small | 81.3 / 81.7 | 72.3 / 81.4 | 78.6 |
| T4-tiny | 82.0 / 82.6 | 75.2 / 84.0 | 89.1 |
| T12-nano | 83.2 / 83.9 | 79.0 / 86.6 | 89.6 |

Note:

  1. The equivalent model structures of public models are shown in the brackets after their names.

  2. When distilling to T4-tiny, NewsQA is used for data augmentation on SQuAD and HotpotQA is used for data augmentation on CoNLL-2003.

  3. When distilling to T12-nano, HotpotQA is used for data augmentation on CoNLL-2003.

Results on Chinese Datasets

We experiment on the following typical Chinese datasets:

| Dataset | Task type | Metrics | #Train | #Dev | Note |
|---|---|---|---|---|---|
| XNLI | text classification | Acc | 393K | 2.5K | Chinese translation version of MNLI |
| LCQMC | text classification | Acc | 239K | 8.8K | sentence-pair matching, binary classification |
| CMRC 2018 | reading comprehension | EM/F1 | 10K | 3.4K | span-extraction machine reading comprehension |
| DRCD | reading comprehension | EM/F1 | 27K | 3.5K | span-extraction machine reading comprehension (Traditional Chinese) |
| MSRA NER | sequence labeling | F1 | 45K | 3.4K (test) | Chinese named entity recognition |

The results are listed below (see Experimental Results for details).

| Model | XNLI | LCQMC | CMRC 2018 | DRCD |
|---|---|---|---|---|
| RoBERTa-wwm-ext (teacher) | 79.9 | 89.4 | 68.8 / 86.4 | 86.5 / 92.5 |
| T3 | 78.4 | 89.0 | 66.4 / 84.2 | 78.2 / 86.4 |
| T3-small | 76.0 | 88.1 | 58.0 / 79.3 | 75.8 / 84.8 |
| T4-tiny | 76.2 | 88.4 | 61.8 / 81.8 | 77.3 / 86.1 |

| Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
|---|---|---|---|---|---|
| Electra-base (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
| Electra-small | 77.7 | 89.3 | 66.5 / 84.9 | 85.5 / 91.3 | 93.48 |

Note:

  1. Learning rate decay is not used in distillation on CMRC 2018 and DRCD.

  2. CMRC 2018 and DRCD take each other as the augmentation dataset in the distillation.

  3. The settings of training Electra-base teacher model can be found at Chinese-ELECTRA.

  4. The Electra-small student model is initialized with the pretrained weights.

Configurations

TrainingConfig

class textbrewer.TrainingConfig(gradient_accumulation_steps=1, ckpt_frequency=1, ckpt_epoch_frequency=1, ckpt_steps=None, log_dir=None, output_dir='./saved_models', device='cuda', fp16=False, fp16_opt_level='O1', data_parallel=False, local_rank=-1)[source]

Configurations related to general model training.

Parameters
  • gradient_accumulation_steps (int) – accumulates gradients before optimization to reduce GPU memory usage. It calls optimizer.step() every gradient_accumulation_steps backward steps.

  • ckpt_frequency (int) – stores model weights ckpt_frequency times every epoch.

  • ckpt_epoch_frequency (int) – stores model weights every ckpt_epoch_frequency epochs.

  • ckpt_steps (int) – if num_steps is passed to distiller.train(), saves the model every ckpt_steps steps and ignores ckpt_frequency and ckpt_epoch_frequency.

  • log_dir (str) – directory to save the tensorboard log file. Set it to None to disable tensorboard.

  • output_dir (str) – directory to save model weights.

  • device (str or torch.device) – training on CPU or GPU.

  • fp16 (bool) – if True, enables mixed precision training using Apex.

  • fp16_opt_level (str) – Pure or mixed precision optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”. See the Apex documentation for details.

  • data_parallel (bool) – If True, wraps the models with torch.nn.DataParallel.

  • local_rank (int) – the local rank of the current process. A non-negative value means that we are in the distributed training mode with DistributedDataParallel.

Note

  • To perform data parallel (DP) training, you could either wrap the models with torch.nn.DataParallel outside TextBrewer by yourself, or leave the work for TextBrewer by setting data_parallel to True.

  • To enable both data parallel training and mixed precision training, you should set data_parallel to True, and DO NOT wrap the models by yourself.

  • In some experiments, we have observed slower training speed with torch.nn.DataParallel.

  • To perform distributed data parallel (DDP) training, you should call torch.distributed.init_process_group before initializing a TrainingConfig, and pass the raw (unwrapped) model when initializing the distiller.

  • DP and DDP are mutually exclusive.

Example:

# Usually you only need to set log_dir and output_dir and leave the others at their defaults
train_config = TrainingConfig(log_dir=my_log_dir, output_dir=my_output_dir)

# Stores the model at the end of each epoch
train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=1)
# Stores the model twice (at the middle and at the end) in each epoch
train_config = TrainingConfig(ckpt_frequency=2, ckpt_epoch_frequency=1)
# Stores the model once every two epochs
train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=2)
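
For data-parallel or mixed-precision training (see the note above), a hedged sketch; my_log_dir, my_output_dir, and local_rank are placeholders:

# single-node multi-GPU DataParallel with mixed precision (fp16 requires Apex)
train_config = TrainingConfig(data_parallel=True, fp16=True,
                              log_dir=my_log_dir, output_dir=my_output_dir)

# distributed data parallel (DDP): initialize the process group first,
# then pass the local rank of the current process (e.g., provided by the launcher)
# torch.distributed.init_process_group(backend='nccl')
train_config = TrainingConfig(local_rank=local_rank, output_dir=my_output_dir)
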
classmethod from_dict(dict_object)

Construct configurations from a dict.

classmethod from_json_file(json_filename)

Construct configurations from a json file.

DistillationConfig

class textbrewer.DistillationConfig(temperature=4, temperature_scheduler='none', hard_label_weight=0, hard_label_weight_scheduler='none', kd_loss_type='ce', kd_loss_weight=1, kd_loss_weight_scheduler='none', probability_shift=False, intermediate_matches: Optional[List[Dict]] = None, is_caching_logits=False)[source]

Configurations related to distillation methods. It defines the total loss to be optimized:

\[\mathcal{L}_{total}= \mathcal{L}_{KD} * w_{KD} + \mathcal{L}_{hl} * w_{hl} + \mathrm{sum}(\textrm{intermediate\_losses})\]

where

  • \(\mathcal{L}_{KD}\) is the KD loss on logits, \(w_{KD}\) is its weight;

  • \(\mathcal{L}_{hl}\) is the sum of losses returned by the adaptor and \(w_{hl}\) is its weight;

  • intermediate_losses is defined via intermediate_matches.

Parameters
  • temperature (float) – temperature for the distillation. The teacher and student models’ logits will be divided by the temperature when computing the KD loss. The temperature typically ranges from 1 to 10. We found that a temperature higher than 1 usually leads to better performance.

  • temperature_scheduler – dynamically adjusts temperature. See TEMPERATURE_SCHEDULER for all available options.

  • kd_loss_type (str) – KD loss function for the logits term returned by the adaptor, can be 'ce' or 'mse'. See KD_LOSS_MAP.

  • kd_loss_weight (float) – the weight for the KD loss.

  • hard_label_weight (float) – the weight for the sum of losses term returned by the adaptor. losses may include the losses on the ground-truth labels and other user-defined losses.

  • kd_loss_weight_scheduler – Dynamically adjusts KD loss weight. See WEIGHT_SCHEDULER for all available options.

  • hard_label_weight_scheduler – Dynamically adjusts the weight of the sum of losses. See WEIGHT_SCHEDULER for all available options.

  • probability_shift (bool) – if True, switch the ground-truth label’s logit and the largest logit predicted by the teacher, to make the ground-truth label’s logit largest. Requires labels term returned by the adaptor.

  • is_caching_logits (bool) – if True, caches the batches and the output logits of the teacher model in memory, so that those logits will only be calculated once. It will speed up the distillation process. This feature is only available for BasicDistiller and MultiTeacherDistiller, and only when the distiller’s train() method is called with num_steps=None. It is suitable for small and medium datasets.

  • intermediate_matches (List[Dict]) – Configuration for intermediate feature matching. Each element in the list is a dict, representing a pair of matching config.

The dict in intermediate_matches contains the following keys:

  • 'layer_T': layer_T (int): selects the layer_T-th layer of the teacher model.

  • 'layer_S': layer_S (int): selects the layer_S-th layer of the student model.

Note

  1. layer_T and layer_S indicate layers in attention or hidden list in the returned dict of the adaptor, rather than the actual layers in the model.

  2. If the loss is fst or nst, two layers have to be chosen from the teacher and the student respectively. In this case, layer_T and layer_S are lists of two ints. See the example below.

  • 'feature': feature (str): features of intermediate layers. It can be:

    • 'attention' : attention matrix, of the shape (batch_size, num_heads, length, length) or (batch_size, length, length)

    • 'hidden' : hidden states, of the shape (batch_size, length, hidden_dim).

  • 'loss' : loss (str) : loss function. See MATCH_LOSS_MAP for available losses. Currently includes: 'attention_mse', 'attention_ce', 'hidden_mse', 'nst', etc.

  • 'weight' : weight (float) : weight for the loss.

  • 'proj' : proj (List, optional) : if the teacher and the student have the same feature dimension, it is optional; otherwise it is required. It is the mapping function that matches the teacher and student intermediate feature dimensions. It is a list, with these elements:

    • proj[0] (str): mapping function, can be 'linear', 'relu', 'tanh'. See PROJ_MAP.

    • proj[1] (int): feature dimension of student model.

    • proj[2] (int): feature dimension of teacher model.

    • proj[3] (dict): optional, provides configurations such as learning rate. If not provided, the learning rate and optimizer configurations will follow the default config of the optimizer, otherwise it will use the ones specified here.

Example:

from textbrewer import DistillationConfig

# simple configuration: use default values, or try different temperatures
distill_config = DistillationConfig(temperature=8)

# adding intermediate feature matching
# under this setting, the returned dict results_T/S of adaptor_T/S should contain 'hidden' key.
# The mse loss between teacher's results_T['hidden'][10] and student's results_S['hidden'][3] will be computed
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches = [{'layer_T':10, 'layer_S':3, 'feature':'hidden', 'loss':'hidden_mse', 'weight':1}]
)

# multiple intermediate feature matches. The teacher and the student have hidden_dims of 768 and 384, respectively.
distill_config = DistillationConfig(
    temperature = 8,
    intermediate_matches = [ \
    {'layer_T':0,  'layer_S':0, 'feature':'hidden','loss': 'hidden_mse', 'weight' : 1,'proj':['linear',384,768]},
    {'layer_T':4,  'layer_S':1, 'feature':'hidden','loss': 'hidden_mse', 'weight' : 1,'proj':['linear',384,768]},
    {'layer_T':8,  'layer_S':2, 'feature':'hidden','loss': 'hidden_mse', 'weight' : 1,'proj':['linear',384,768]},
    {'layer_T':12, 'layer_S':3, 'feature':'hidden','loss': 'hidden_mse', 'weight' : 1,'proj':['linear',384,768]}]
)
classmethod from_dict(dict_object)

Construct configurations from a dict.

classmethod from_json_file(json_filename)

Construct configurations from a json file.

Distillers

Distillers perform the actual experiments.

Initialize a distiller object, call its train method to start training/distillation.

BasicDistiller

class textbrewer.BasicDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]

Performs single-teacher single-task distillation, provides basic distillation strategies.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • distill_config (DistillationConfig) – distillation configuration.

  • model_T (torch.nn.Module) – teacher model.

  • model_S (torch.nn.Module) – student model.

  • adaptor_T (Callable) – teacher model’s adaptor.

  • adaptor_S (Callable) – student model’s adaptor.

The roles of adaptor_T and adaptor_S are explained in adaptor().

train(optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)[source]

trains the student model.

Parameters
  • optimizer – optimizer.

  • dataloader – dataset iterator.

  • num_epochs (int) – number of training epochs.

  • num_steps (int) – number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps steps, and the dataloader can have an unknown size, i.e., has no __len__ attribute. The dataloader will be cycled automatically after iterating over the whole dataset.

  • callback (Callable) – function called after each epoch, can be None. It is called as callback(model=self.model_S, step = global_step). It can be used to evaluate the model at each checkpoint.

  • batch_postprocessor (Callable) – a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.

  • scheduler_class (class) – the class of the scheduler to be constructed.

  • scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object. See the example below.

  • scheduler (deprecated) – used to adjust learning rate, optional, can be None, is deprecated in favor of scheduler_class and scheduler_args.

  • max_grad_norm (float) – Maximum norm for the gradients (-1 means no clipping). Default: -1.0

  • **args – additional arguments fed to the model.

Note

  • If the batch is a list or tuple, model is called as: model(*batch, **args). Make sure the order of elements in the batch matches their order in model.forward.

  • If the batch is a dict, model is called as: model(**batch,**args). Make sure the keys of the batch match the arguments of the model.forward.

Note

If you want to provide a learning rate scheduler, DON’T USE scheduler; use scheduler_class and scheduler_args instead. Example:

from transformers import get_linear_schedule_with_warmup
distiller.train(optimizer, scheduler_class = get_linear_schedule_with_warmup, scheduler_args= {'num_warmup_steps': 100, 'num_training_steps': 1000})

GeneralDistiller

class textbrewer.GeneralDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S, custom_matches: Optional[List[textbrewer.distiller_utils.CustomMatch]] = None)[source]

Supports intermediate features matching. Recommended for single-teacher single-task distillation.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • distill_config (DistillationConfig) – distillation configuration.

  • model_T (torch.nn.Module) – teacher model.

  • model_S (torch.nn.Module) – student model.

  • adaptor_T (Callable) – teacher model’s adaptor.

  • adaptor_S (Callable) – student model’s adaptor.

  • custom_matches (list) – supports more flexible user-defined matches (testing).

The roles of adaptor_T and adaptor_S are explained in adaptor().

train(optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)

trains the student model.

Parameters
  • optimizer – optimizer.

  • dataloader – dataset iterator.

  • num_epochs (int) – number of training epochs.

  • num_steps (int) – number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps steps, and the dataloader can have an unknown size, i.e., has no __len__ attribute. The dataloader will be cycled automatically after iterating over the whole dataset.

  • callback (Callable) – function called after each epoch, can be None. It is called as callback(model=self.model_S, step = global_step). It can be used to evaluate the model at each checkpoint.

  • batch_postprocessor (Callable) – a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.

  • scheduler_class (class) – the class of the scheduler to be constructed.

  • scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object. See the example below.

  • scheduler (deprecated) – used to adjust learning rate, optional, can be None, is deprecated in favor of scheduler_class and scheduler_args.

  • max_grad_norm (float) – Maximum norm for the gradients (-1 means no clipping). Default: -1.0

  • **args – additional arguments fed to the model.

Note

  • If the batch is a list or tuple, model is called as: model(*batch, **args). Make sure the order of elements in the batch matches their order in model.forward.

  • If the batch is a dict, model is called as: model(**batch,**args). Make sure the keys of the batch match the arguments of the model.forward.

Note

If you want to provide a learning rate scheduler, DON’T USE scheduler; use scheduler_class and scheduler_args instead. Example:

from transformers import get_linear_schedule_with_warmup
distiller.train(optimizer, scheduler_class = get_linear_schedule_with_warmup, scheduler_args= {'num_warmup_steps': 100, 'num_training_steps': 1000})

MultiTeacherDistiller

class textbrewer.MultiTeacherDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]

Distills multiple teacher models (of the same task) into a student model. It doesn’t support intermediate feature matching.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • distill_config (DistillationConfig) – distillation configuration.

  • model_T (List[torch.nn.Module]) – list of teacher models.

  • model_S (torch.nn.Module) – student model.

  • adaptor_T (Callable) – teacher model’s adaptor.

  • adaptor_S (Callable) – student model’s adaptor.

The roles of adaptor_T and adaptor_S are explained in adaptor().

train(optimizer, scheduler, dataloader, num_epochs, num_steps=None, callback=None, batch_postprocessor=None, **args)

trains the student model. See BasicDistiller.train().
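
A minimal construction sketch, assuming three teacher models of the same task plus the configs, adaptor, optimizer, and dataloader from the Quickstart; see the train() signature above for the full argument list:

from textbrewer import MultiTeacherDistiller

distiller = MultiTeacherDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=[teacher_model_1, teacher_model_2, teacher_model_3],  # a list of teachers
    model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

with distiller:
    distiller.train(optimizer=optimizer, scheduler=None, dataloader=dataloader,
                    num_epochs=30, callback=None)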

MultiTaskDistiller

class textbrewer.MultiTaskDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]

Distills multiple teacher models (of different tasks) into a single student. Since version 0.2.1, it supports intermediate feature matching.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • distill_config (DistillationConfig) – distillation configuration.

  • model_T (dict) – dict of teacher models: {task1:model1, task2:model2, …. }. Keys are tasknames.

  • model_S (torch.nn.Module) – student model.

  • adaptor_T (dict) – dict of teacher adaptors: {task1:adpt1, task2:adpt2, …. }. Keys are tasknames.

  • adaptor_S (dict) – dict of student adaptors: {task1:adpt1, task2:adpt2, …. }. Keys are tasknames.

train(optimizer, dataloaders, num_steps, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, tau=1, callback=None, batch_postprocessors=None, **args)[source]

trains the student model.

Parameters
  • optimizer – optimizer.

  • dataloaders (dict) – dict of dataset iterator. Keys are tasknames, values are corresponding dataloaders.

  • num_steps (int) – number of training steps.

  • scheduler_class (class) – the class of the scheduler to be constructed.

  • scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object.

  • scheduler (deprecated) – used to adjust learning rate, optional, can be None, is deprecated in favor of scheduler_class and scheduler_args.

  • max_grad_norm (float) – Maximum norm for the gradients (-1 means no clipping). Default: -1.0

  • tau (float) – the probability of sampling an example from task d is proportional to |d|^{tau}, where |d| is the size of d’s training set. If the size of any dataset is unknown, ignores tau and samples examples uniformly from each dataset.

  • callback (Callable) – function called after each epoch, can be None. It is called as callback(model=self.model_S, step = global_step). It can be used to evaluate the model at each checkpoint.

  • batch_postprocessors (dict) – a dict of batch_postprocessors. Keys are tasknames, values are corresponding batch_postprocessors. Each batch_postprocessor should take a batch and return a batch.

  • **args – additional arguments fed to the model.
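
A minimal construction sketch; the teachers, adaptors, and dataloaders are organized as dicts keyed by task name, and all concrete names below (teacher_mnli, adaptor_T_mnli, dataloader_mnli, ...) are placeholders:

from textbrewer import MultiTaskDistiller

distiller = MultiTaskDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T={'mnli': teacher_mnli, 'squad': teacher_squad},
    model_S=student_model,
    adaptor_T={'mnli': adaptor_T_mnli, 'squad': adaptor_T_squad},
    adaptor_S={'mnli': adaptor_S_mnli, 'squad': adaptor_S_squad})

with distiller:
    distiller.train(optimizer,
                    dataloaders={'mnli': dataloader_mnli, 'squad': dataloader_squad},
                    num_steps=10000, callback=None)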

BasicTrainer

class textbrewer.BasicTrainer(train_config: textbrewer.configurations.TrainingConfig, model: torch.nn.modules.module.Module, adaptor)[source]

It performs supervised training, not distillation. It can be used for training the teacher model.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • model (torch.nn.Module) – model to be trained.

  • adaptor (Callable) – the model’s adaptor.

The role of adaptor is explained in adaptor().

train(optimizer, dataloader, num_epochs, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)[source]

trains the model. See BasicDistiller.train().
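
A minimal sketch of training a teacher model with BasicTrainer, mirroring the Quickstart pattern; teacher_model, optimizer, and dataloader are assumed to exist, and the adaptor is assumed to return the loss computed inside the model (as the third output, like the MyModel example above):

from textbrewer import BasicTrainer, TrainingConfig

def trainer_adaptor(batch, model_outputs):
    # 'losses' entries are summed by the trainer; each should be a scalar tensor
    return {'losses': (model_outputs[2],)}

train_config = TrainingConfig(output_dir='./saved_teacher')
trainer = BasicTrainer(train_config=train_config, model=teacher_model, adaptor=trainer_adaptor)

with trainer:
    trainer.train(optimizer, dataloader, num_epochs=3, callback=None)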

Presets

Presets include module variables that define pre-defined loss functions and strategies.

Module variables

ADAPTOR_KEYS

textbrewer.presets.ADAPTOR_KEYS

(list) valid keys of the dict returned by the adaptor. It includes:

  • logits

  • logits_mask

  • losses

  • inputs_mask

  • labels

  • hidden

  • attention

KD_LOSS_MAP

textbrewer.presets.KD_LOSS_MAP

(dict) available KD losses

  • 'mse' : mean squared error

  • 'ce' : cross-entropy loss

PROJ_MAP

textbrewer.presets.PROJ_MAP

(dict) layers used to match the different dimensions of intermediate features

  • 'linear' : linear layer, no activation

  • 'relu' : ReLU activation

  • 'tanh' : Tanh activation

MATCH_LOSS_MAP

textbrewer.presets.MATCH_LOSS_MAP

(dict) available intermediate feature matching loss functions.

See Intermediate Losses for details.

WEIGHT_SCHEDULER

textbrewer.presets.WEIGHT_SCHEDULER

(dict) Schedulers used to dynamically adjust the KD loss weight and the hard-label loss weight.

  • 'linear_decay' : decay from 1 to 0 during the whole training process.

  • 'linear_growth' : grow from 0 to 1 during the whole training process.
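
For example, to decay the KD loss weight and grow the hard-label loss weight over training (a sketch using only the options above):

distill_config = DistillationConfig(
    hard_label_weight = 1,
    kd_loss_weight_scheduler = 'linear_decay',
    hard_label_weight_scheduler = 'linear_growth'
)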

TEMPERATURE_SCHEDULER

textbrewer.presets.TEMPERATURE_SCHEDULER

(custom dict) used to dynamically adjust distillation temperature.

Unlike the other options, 'flsw' and 'cwsm' require extra parameters, for example:

#flsw
distill_config = DistillationConfig(
    temperature_scheduler = ['flsw', 1, 2]  # beta=1, gamma=2
)

#cwsm
distill_config = DistillationConfig(
    temperature_scheduler = ['cwsm', 1] # beta = 1
)

Customization

If the pre-defined modules do not satisfy your requirements, you can add your own modules to the dicts above.

For example:

MATCH_LOSS_MAP['my_L1_loss'] = my_L1_loss
WEIGHT_SCHEDULER['my_weight_scheduler'] = my_weight_scheduler

and then use them in DistillationConfig:

distill_config = DistillationConfig(
    kd_loss_weight_scheduler = 'my_weight_scheduler',
    intermediate_matches = [{'layer_T':0, 'layer_S':0, 'feature':'hidden','loss': 'my_L1_loss', 'weight' : 1}],
    ...)

Refer to the source code for more details on the input and output conventions of these modules (they will be explained in detail in a later version of the documentation).
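
As a sketch, a custom intermediate loss is expected to follow the same calling convention as the pre-defined losses in the next section, i.e. (feature_S, feature_T, mask=None); the exact conventions should be checked against the source:

import torch.nn.functional as F
from textbrewer.presets import MATCH_LOSS_MAP

def my_L1_loss(state_S, state_T, mask=None):
    # state_S/state_T: (batch_size, length, hidden_size); mask: (batch_size, length)
    if mask is not None:
        # zero out padded positions before computing the loss
        state_S = state_S * mask.unsqueeze(-1).to(state_S)
        state_T = state_T * mask.unsqueeze(-1).to(state_T)
    return F.l1_loss(state_S, state_T)

MATCH_LOSS_MAP['my_L1_loss'] = my_L1_loss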

Intermediate Losses

Here we list the definitions of the pre-defined intermediate losses. Usually, users don’t need to call these functions directly; they are referred to by their names in MATCH_LOSS_MAP.

attention_mse

textbrewer.losses.att_mse_loss(attention_S, attention_T, mask=None)[source]
  • Calculates the mse loss between attention_S and attention_T.

  • If the inputs_mask is given, masks the positions where input_mask==0.

Parameters
  • attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)

  • attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)

attention_mse_sum

textbrewer.losses.att_mse_sum_loss(attention_S, attention_T, mask=None)[source]
  • Calculates the mse loss between attention_S and attention_T.

  • If the shape is (batch_size, num_heads, length, length), sums along the num_heads dimension and then calculates the mse loss between the two matrices.

  • If the inputs_mask is given, masks the positions where input_mask==0.

Parameters
  • attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)

  • attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)

attention_ce

textbrewer.losses.att_ce_loss(attention_S, attention_T, mask=None)[source]
  • Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied along dim=-1.

  • If the inputs_mask is given, masks the positions where input_mask==0.

Parameters
  • attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)

  • attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)

attention_ce_mean

textbrewer.losses.att_ce_mean_loss(attention_S, attention_T, mask=None)[source]
  • Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied along dim=-1.

  • If the shape is (batch_size, num_heads, length, length), averages over the num_heads dimension and then computes the cross-entropy loss between the two matrices.

  • If the inputs_mask is given, masks the positions where input_mask==0.

Parameters
  • attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)

  • attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)

hidden_mse

textbrewer.losses.hid_mse_loss(state_S, state_T, mask=None)[source]
  • Calculates the mse loss between state_S and state_T, which are the hidden state of the models.

  • If the inputs_mask is given, masks the positions where input_mask==0.

  • If the hidden sizes of the student and the teacher are different, the ‘proj’ option is required in intermediate_matches to match the dimensions.

Parameters
  • state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)

  • state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)

cos

textbrewer.losses.cos_loss(state_S, state_T, mask=None)[source]
  • Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT, see DistilBERT

  • If the inputs_mask is given, masks the positions where input_mask==0.

  • If the hidden sizes of the student and the teacher are different, the ‘proj’ option is required in intermediate_matches to match the dimensions.

Parameters
  • state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)

  • state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)

pkd

textbrewer.losses.pkd_loss(state_S, state_T, mask=None)[source]
  • Computes normalized vector mse loss at position 0 along length dimension. This is the loss used in BERT-PKD, see Patient Knowledge Distillation for BERT Model Compression.

  • If the hidden sizes of the student and the teacher are different, the ‘proj’ option is required in intermediate_matches to match the dimensions.

  • If the input tensors are of shape (batch_size, hidden_size), it directly computes the loss between tensors without taking the hidden states at position 0.

Parameters
  • state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)

  • state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)

  • mask – not used.

nst (mmd)

textbrewer.losses.mmd_loss(state_S, state_T, mask=None)[source]
  • Takes in two lists of matrices state_S and state_T. Each list contains two matrices of the shape (batch_size, length, hidden_size). The hidden_size of the matrices in state_S does not need to be the same as that of state_T. Computes the similarity matrix between the two matrices in state_S (with the resulting shape (batch_size, length, length)) and between the two matrices in state_T (with the resulting shape (batch_size, length, length)), then computes the mse loss between the two similarity matrices:

\[loss = mean((S_{1} \cdot S_{2}^T - T_{1} \cdot T_{2}^T)^2)\]
Parameters
  • state_S (torch.tensor) – list of two tensors, each tensor is of the shape (batch_size, length, hidden_size)

  • state_T (torch.tensor) – list of two tensors, each tensor is of the shape (batch_size, length, hidden_size)

  • mask (torch.tensor) – tensor of the shape (batch_size, length)

Example in intermediate_matches:

intermediate_matches = [
{'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'nst', 'weight' : 1},
...]

fsp (gram)

textbrewer.losses.fsp_loss(state_S, state_T, mask=None)[source]
  • Takes in two lists of matrices state_S and state_T. Each list contains two matrices of the shape (batch_size, length, hidden_size). Computes the similarity matrix between the two matrices in state_S (with the resulting shape (batch_size, hidden_size, hidden_size)) and between the two matrices in state_T (with the resulting shape (batch_size, hidden_size, hidden_size)), then computes the mse loss between the two similarity matrices:

\[loss = mean((S_{1}^T \cdot S_{2} - T_{1}^T \cdot T_{2})^2)\]
Parameters
  • state_S (torch.tensor) – list of two tensors, each tensor is of the shape (batch_size, length, hidden_size)

  • state_T (torch.tensor) – list of two tensors, each tensor is of the shape (batch_size, length, hidden_size)

  • mask (torch.tensor) – tensor of the shape (batch_size, length)

Example in intermediate_matches:

intermediate_matches = [
{'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'fsp', 'weight' : 1, 'proj':['linear',384,768]},
...]

Model Utils

display_parameters

textbrewer.utils.display_parameters(model, max_level=None)[source]

Display the numbers and memory usage of module parameters.

Parameters
  • model (torch.nn.Module or dict) – the model to be inspected.

  • max_level (int or None) – The max level to display. If max_level==None, show all the levels.

Returns

A formatted string and a LayerNode object representing the model.

Data Utils

This module provides the following data augmentation methods.

masking

textbrewer.data_utils.masking(tokens, p=0.1, mask='[MASK]')[source]

Returns a new list by replacing elements in tokens by mask with probability p.

Parameters
  • tokens (list) – list of tokens or token ids.

  • p (float) – probability to mask each element in tokens.

Returns

A new list by replacing elements in tokens by mask with probability p.
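
A small usage sketch (the token list is made up):

from textbrewer.data_utils import masking

tokens = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
masked_tokens = masking(tokens, p=0.3, mask='[MASK]')
# e.g. ['the', '[MASK]', 'brown', 'fox', '[MASK]', 'over', 'the', 'lazy', 'dog']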

deleting

textbrewer.data_utils.deleting(tokens, p=0.1)[source]

Returns a new list by deleting elements in tokens with probability p.

Parameters
  • tokens (list) – list of tokens or token ids.

  • p (float) – probability to delete each element in tokens.

Returns

a new list by deleting elements in tokens with probability p.

n_gram_sampling

textbrewer.data_utils.n_gram_sampling(tokens, p_ng=[0.2, 0.2, 0.2, 0.2, 0.2], l_ng=[1, 2, 3, 4, 5])[source]

Samples a length l from l_ng with probability distribution p_ng, then returns a random span of length l from tokens.

Parameters
  • tokens (list) – list of tokens or token ids.

  • p_ng (list) – probability distribution of the n-grams, should sum to 1.

  • l_ng (list) – specifies the n-gram lengths.

Returns

a random n-gram span from tokens.

short_disorder

textbrewer.data_utils.short_disorder(tokens, p=[0.9, 0.1, 0, 0, 0])[source]

Returns a new list by disordering tokens with probability distribution p at every possible position. Let abc be a 3-gram in tokens; there are five ways to disorder it, corresponding to the five probability values:

abc -> abc
abc -> bac
abc -> cba
abc -> cab
abc -> bca
Parameters
  • tokens (list) – list of tokens or token ids.

  • p (list) – probability distribution of 5 disorder types, should sum to 1.

Returns

a new disordered list

long_disorder

textbrewer.data_utils.long_disorder(tokens, p=0.1, length=20)[source]

Performs a long-range disordering. If length>1, then swaps the two halves of each span of length length in tokens; if length<=1, treats length as the relative length. For example:

>>> long_disorder([0,1,2,3,4,5,6,7,8,9,10], p=1, length=0.4)
[2, 3, 0, 1, 6, 7, 4, 5, 8, 9]
Parameters
  • tokens (list) – list of tokens or token ids.

  • p (float) – probability of swapping the two halves of a span at each possible position.

  • length (int or float) – length of the disordered span.

Returns

a new disordered list

Experimental Results

English Datasets

MNLI

  • Training without Distillation:

| Model (ours) | MNLI |
|---|---|
| BERT-base-cased | 83.7 / 84.0 |
| T3 | 76.1 / 76.5 |

  • Single-teacher distillation with GeneralDistiller:

| Model (ours) | MNLI |
|---|---|
| BERT-base-cased (teacher) | 83.7 / 84.0 |
| T6 (student) | 83.5 / 84.0 |
| T3 (student) | 81.8 / 82.7 |
| T3-small (student) | 81.3 / 81.7 |
| T4-tiny (student) | 82.0 / 82.6 |
| T12-nano (student) | 83.2 / 83.9 |

  • Multi-teacher distillation with MultiTeacherDistiller:

| Model (ours) | MNLI |
|---|---|
| BERT-base-cased (teacher #1) | 83.7 / 84.0 |
| BERT-base-cased (teacher #2) | 83.6 / 84.2 |
| BERT-base-cased (teacher #3) | 83.7 / 83.8 |
| ensemble (average of #1, #2 and #3) | 84.3 / 84.7 |
| BERT-base-cased (student) | 84.8 / 85.3 |

SQuAD

  • Training without Distillation:

| Model (ours) | SQuAD |
|---|---|
| BERT-base-cased | 81.5 / 88.6 |
| T6 | 75.0 / 83.3 |
| T3 | 63.0 / 74.3 |

  • Single-teacher distillation with GeneralDistiller:

| Model (ours) | SQuAD |
|---|---|
| BERT-base-cased (teacher) | 81.5 / 88.6 |
| T6 (student) | 80.8 / 88.1 |
| T3 (student) | 76.4 / 84.9 |
| T3-small (student) | 72.3 / 81.4 |
| T4-tiny (student) | 73.7 / 82.5 |
| + DA | 75.2 / 84.0 |
| T12-nano (student) | 79.0 / 86.6 |

Note: When distilling to T4-tiny, NewsQA is used for data augmentation on SQuAD.

  • Multi-teacher distillation with MultiTeacherDistiller:

| Model (ours) | SQuAD |
|---|---|
| BERT-base-cased (teacher #1) | 81.1 / 88.6 |
| BERT-base-cased (teacher #2) | 81.2 / 88.5 |
| BERT-base-cased (teacher #3) | 81.2 / 88.7 |
| ensemble (average of #1, #2 and #3) | 82.3 / 89.4 |
| BERT-base-cased (student) | 83.5 / 90.0 |

CoNLL-2003 English NER

  • Training without Distillation:

| Model (ours) | CoNLL-2003 |
|---|---|
| BERT-base-cased | 91.1 |
| BiGRU | 81.1 |
| T3 | 85.3 |

  • Single-teacher distillation with GeneralDistiller:

| Model (ours) | CoNLL-2003 |
|---|---|
| BERT-base-cased (teacher) | 91.1 |
| BiGRU | 85.3 |
| T6 (student) | 90.7 |
| T3 (student) | 87.5 |
| + DA | 90.0 |
| T3-small (student) | 78.6 |
| + DA | - |
| T4-tiny (student) | 77.5 |
| + DA | 89.1 |
| T12-nano (student) | 78.8 |
| + DA | 89.6 |

Note: HotpotQA is used for data augmentation on CoNLL-2003.

Chinese Datasets (RoBERTa-wwm-ext as the teacher)

XNLI

| Model | XNLI |
|---|---|
| RoBERTa-wwm-ext (teacher) | 79.9 |
| T3 (student) | 78.4 |
| T3-small (student) | 76.0 |
| T4-tiny (student) | 76.2 |

LCQMC

| Model | LCQMC |
|---|---|
| RoBERTa-wwm-ext (teacher) | 89.4 |
| T3 (student) | 89.0 |
| T3-small (student) | 88.1 |
| T4-tiny (student) | 88.4 |

CMRC 2018 and DRCD

| Model | CMRC 2018 | DRCD |
|---|---|---|
| RoBERTa-wwm-ext (teacher) | 68.8 / 86.4 | 86.5 / 92.5 |
| T3 (student) | 63.4 / 82.4 | 76.7 / 85.2 |
| + DA | 66.4 / 84.2 | 78.2 / 86.4 |
| T3-small (student) | 46.1 / 71.0 | 71.4 / 82.2 |
| + DA | 58.0 / 79.3 | 75.8 / 84.8 |
| T4-tiny (student) | 54.3 / 76.8 | 75.5 / 84.9 |
| + DA | 61.8 / 81.8 | 77.3 / 86.1 |

Note: CMRC 2018 and DRCD take each other as the augmentation dataset in these experiments.

Chinese Datasets (Electra-base as the teacher)

  • Training without Distillation:

| Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
|---|---|---|---|---|---|
| Electra-base (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
| Electra-small (pretrained) | 72.5 | 86.3 | 62.9 / 80.2 | 79.4 / 86.4 | - |

  • Single-teacher distillation with GeneralDistiller:

| Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
|---|---|---|---|---|---|
| Electra-base (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
| Electra-small (random) | 77.2 | 89.0 | 66.5 / 84.9 | 84.8 / 91.0 | - |
| Electra-small (pretrained) | 77.7 | 89.3 | 66.5 / 84.9 | 85.5 / 91.3 | 93.48 |

Note:

  1. Random: randomly initialized

  2. Pretrained: initialized with pretrained weights

A good initialization of the student (Electra-small) improves the performance.
