Distillers

Distillers are the objects that carry out the actual training and distillation.

Initialize a distiller object, then call its train method to start training or distillation.

BasicDistiller

class textbrewer.BasicDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]

Performs single-teacher, single-task distillation and provides the basic distillation strategies.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • distill_config (DistillationConfig) – distillation configuration.

  • model_T (torch.nn.Module) – teacher model.

  • model_S (torch.nn.Module) – student model.

  • adaptor_T (Callable) – teacher model’s adaptor.

  • adaptor_S (Callable) – student model’s adaptor.

The roles of adaptor_T and adaptor_S are explained in adaptor().
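
A minimal initialization sketch (hedged: teacher_model, student_model, and simple_adaptor are placeholders you would define for your own task; the configurations below use default values):

from textbrewer import BasicDistiller, TrainingConfig, DistillationConfig

def simple_adaptor(batch, model_outputs):
    # Assumes the model returns logits as its first output; adjust to your model
    return {'logits': model_outputs[0]}

distiller = BasicDistiller(
    train_config=TrainingConfig(),
    distill_config=DistillationConfig(),
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)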

train(optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)[source]

Trains the student model.

Parameters
  • optimizer – optimizer.

  • dataloader – dataset iterator.

  • num_epochs (int) – number of training epochs.

  • num_steps (int) – number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps, and the dataloader can have an unknown size, i.e., have no __len__ attribute. The dataloader is cycled automatically after the whole dataset has been iterated over.

  • callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step). It can be used to evaluate the model at each checkpoint.

  • batch_postprocessor (Callable) – a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.

  • scheduler_class (class) – the class of the scheduler to be constructed.

  • scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object. See the example below.

  • scheduler (deprecated) – learning-rate scheduler; optional, can be None. Deprecated in favor of scheduler_class and scheduler_args.

  • max_grad_norm (float) – Maximum norm for the gradients (-1 means no clipping). Default: -1.0

  • **args – additional arguments fed to the model.

Note

  • If the batch is a list or tuple, model is called as: model(*batch, **args). Make sure the order of elements in the batch matches their order in model.forward.

  • If the batch is a dict, model is called as: model(**batch, **args). Make sure the keys of the batch match the arguments of model.forward.

Note

If you want to provide a learning-rate scheduler, DON'T USE scheduler; use scheduler_class and scheduler_args instead. Example:

from transformers import get_linear_schedule_with_warmup
distiller.train(optimizer,
                scheduler_class=get_linear_schedule_with_warmup,
                scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 1000})
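
A hedged end-to-end sketch of calling train (the optimizer, train_dataloader, and evaluation callback below are assumptions defined only for illustration):

import torch
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)

def evaluate_callback(model, step):
    # Hypothetical hook: invoked after each epoch as callback(model=..., step=...)
    print(f"step {step}: evaluate the student model here")

distiller.train(optimizer,
                dataloader=train_dataloader,   # assumed to yield dict or tuple batches
                num_epochs=3,
                scheduler_class=get_linear_schedule_with_warmup,
                scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 1000},
                max_grad_norm=1.0,
                callback=evaluate_callback)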

GeneralDistiller

class textbrewer.GeneralDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S, custom_matches: Optional[List[textbrewer.distiller_utils.CustomMatch]] = None)[source]

Supports intermediate feature matching. Recommended for single-teacher, single-task distillation.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • distill_config (DistillationConfig) – distillation configuration.

  • model_T (torch.nn.Module) – teacher model.

  • model_S (torch.nn.Module) – student model.

  • adaptor_T (Callable) – teacher model’s adaptor.

  • adaptor_S (Callable) – student model’s adaptor.

  • custom_matches (list) – supports more flexible, user-defined matches (experimental; under testing).

The roles of adaptor_T and adaptor_S are explained in adaptor().
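
A hedged construction sketch. It assumes the adaptors expose a 'hidden' feature and that the match dictionaries follow the intermediate_matches convention of DistillationConfig; teacher_model, student_model, and simple_adaptor are placeholders as before:

from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig

distill_config = DistillationConfig(
    temperature=8,
    # Each match pairs one teacher layer with one student layer on a named feature
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}])

distiller = GeneralDistiller(
    train_config=TrainingConfig(), distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)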

train(optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)

Trains the student model.

Parameters
  • optimizer – optimizer.

  • dataloader – dataset iterator.

  • num_epochs (int) – number of training epochs.

  • num_steps (int) – number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps, and the dataloader can have an unknown size, i.e., have no __len__ attribute. The dataloader is cycled automatically after the whole dataset has been iterated over.

  • callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step). It can be used to evaluate the model at each checkpoint.

  • batch_postprocessor (Callable) – a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.

  • scheduler_class (class) – the class of the scheduler to be constructed.

  • scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object. See the example below.

  • scheduler (deprecated) – learning-rate scheduler; optional, can be None. Deprecated in favor of scheduler_class and scheduler_args.

  • max_grad_norm (float) – Maximum norm for the gradients (-1 means no clipping). Default: -1.0

  • **args – additional arguments fed to the model.

Note

  • If the batch is a list or tuple, model is called as: model(*batch, **args). Make sure the order of elements in the batch matches their order in model.forward.

  • If the batch is a dict, model is called as: model(**batch, **args). Make sure the keys of the batch match the arguments of model.forward.

Note

If you want to provide a learning-rate scheduler, DON'T USE scheduler; use scheduler_class and scheduler_args instead. Example:

from transformers import get_linear_schedule_with_warmup
distiller.train(optimizer,
                scheduler_class=get_linear_schedule_with_warmup,
                scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 1000})

MultiTeacherDistiller

class textbrewer.MultiTeacherDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]

Distills multiple teacher models (of the same task) into a single student model. It does not support intermediate feature matching.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • distill_config (DistillationConfig) – distillation configuration.

  • model_T (List[torch.nn.Module]) – list of teacher models.

  • model_S (torch.nn.Module) – student model.

  • adaptor_T (Callable) – teacher model’s adaptor.

  • adaptor_S (Callable) – student model’s adaptor.

The roles of adaptor_T and adaptor_S are explained in adaptor().

train(optimizer, scheduler, dataloader, num_epochs, num_steps=None, callback=None, batch_postprocessor=None, **args)

Trains the student model. See BasicDistiller.train().
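
A minimal construction sketch (hedged: teacher_1, teacher_2, and teacher_3 are assumed to be models fine-tuned on the same task, and simple_adaptor, student_model, optimizer, and train_dataloader are placeholders):

from textbrewer import MultiTeacherDistiller, TrainingConfig, DistillationConfig

distiller = MultiTeacherDistiller(
    train_config=TrainingConfig(),
    distill_config=DistillationConfig(),
    model_T=[teacher_1, teacher_2, teacher_3],   # list of teachers for the same task
    model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                num_epochs=3, callback=None)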

MultiTaskDistiller

class textbrewer.MultiTaskDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)[source]

Distills multiple teacher models (of different tasks) into a single student model. It supports intermediate feature matching since version 0.2.1.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • distill_config (DistillationConfig) – distillation configuration.

  • model_T (dict) – dict of teacher models: {task1: model1, task2: model2, …}. Keys are task names.

  • model_S (torch.nn.Module) – student model.

  • adaptor_T (dict) – dict of teacher adaptors: {task1: adpt1, task2: adpt2, …}. Keys are task names.

  • adaptor_S (dict) – dict of student adaptors: {task1: adpt1, task2: adpt2, …}. Keys are task names.

train(optimizer, dataloaders, num_steps, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, tau=1, callback=None, batch_postprocessors=None, **args)[source]

Trains the student model.

Parameters
  • optimizer – optimizer.

  • dataloaders (dict) – dict of dataset iterators. Keys are task names; values are the corresponding dataloaders.

  • num_steps (int) – number of training steps.

  • scheduler_class (class) – the class of the scheduler to be constructed.

  • scheduler_args (dict) – arguments (excluding optimizer) passed to the scheduler_class to construct the scheduler object.

  • scheduler (deprecated) – learning-rate scheduler; optional, can be None. Deprecated in favor of scheduler_class and scheduler_args.

  • max_grad_norm (float) – Maximum norm for the gradients (-1 means no clipping). Default: -1.0

  • tau (float) – the probability of sampling an example from task d is proportional to |d|^{tau}, where |d| is the size of d's training set. If the size of any dataset is unknown, tau is ignored and examples are sampled uniformly from each dataset.

  • callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step). It can be used to evaluate the model at each checkpoint.

  • batch_postprocessors (dict) – a dict of batch postprocessors. Keys are task names; values are the corresponding batch postprocessors. Each batch postprocessor should take a batch and return a batch.

  • **args – additional arguments fed to the model.
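
A hedged sketch combining construction and training; the task names ('sst2', 'mnli') and the per-task models, adaptors, and dataloaders are placeholders chosen for illustration:

from textbrewer import MultiTaskDistiller, TrainingConfig, DistillationConfig

distiller = MultiTaskDistiller(
    train_config=TrainingConfig(),
    distill_config=DistillationConfig(),
    model_T={'sst2': teacher_sst2, 'mnli': teacher_mnli},   # one teacher per task
    model_S=student_model,
    adaptor_T={'sst2': adaptor_T_sst2, 'mnli': adaptor_T_mnli},
    adaptor_S={'sst2': adaptor_S_sst2, 'mnli': adaptor_S_mnli})

distiller.train(optimizer,
                dataloaders={'sst2': sst2_dataloader, 'mnli': mnli_dataloader},
                num_steps=10000,
                tau=0.5,   # sampling probability proportional to |d|**0.5
                callback=None)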

BasicTrainer

class textbrewer.BasicTrainer(train_config: textbrewer.configurations.TrainingConfig, model: torch.nn.modules.module.Module, adaptor)[source]

Performs supervised training rather than distillation. It can be used to train a teacher model.

Parameters
  • train_config (TrainingConfig) – training configuration.

  • model (torch.nn.Module) – model to be trained.

  • adaptor (Callable) – the model's adaptor.

The role of adaptor is explained in adaptor().

train(optimizer, dataloader, num_epochs, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)[source]

Trains the model. See BasicDistiller.train().
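
A minimal sketch of supervised training for a teacher model. It assumes the adaptor returns the training loss under a 'losses' key and that teacher_model, optimizer, and train_dataloader are defined elsewhere:

from textbrewer import BasicTrainer, TrainingConfig

def teacher_adaptor(batch, model_outputs):
    # Assumes the model returns its loss as the first output; 'losses' is an iterable
    return {'losses': (model_outputs[0],)}

trainer = BasicTrainer(train_config=TrainingConfig(),
                       model=teacher_model,
                       adaptor=teacher_adaptor)
trainer.train(optimizer, dataloader=train_dataloader, num_epochs=3)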