Distillers

Distillers perform the actual experiments. Initialize a distiller object and call its train method to start the training/distillation process.
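For orientation, here is a minimal sketch of that workflow with GeneralDistiller. The teacher/student models, dataloader, and optimizer are assumed to be defined elsewhere; the adaptor shown is only illustrative (see adaptor() for the full contract), and the with-statement follows the usage shown in the TextBrewer examples.

    from textbrewer import TrainingConfig, DistillationConfig, GeneralDistiller

    def simple_adaptor(batch, model_outputs):
        # Illustrative adaptor: expose the logits of a transformers-style output
        # to the distiller. Real adaptors may also return hidden states, losses, etc.
        return {'logits': model_outputs.logits}

    distiller = GeneralDistiller(
        train_config=TrainingConfig(),        # default training configuration
        distill_config=DistillationConfig(),  # default distillation configuration
        model_T=teacher_model, model_S=student_model,       # assumed to be defined
        adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

    with distiller:
        distiller.train(optimizer, dataloader, num_epochs=3, callback=None)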
BasicDistiller

class textbrewer.BasicDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)

Performs single-teacher single-task distillation and provides basic distillation strategies.
Parameters:

- train_config (TrainingConfig) – training configuration.
- distill_config (DistillationConfig) – distillation configuration.
- model_T (torch.nn.Module) – teacher model.
- model_S (torch.nn.Module) – student model.
- adaptor_T (Callable) – teacher model’s adaptor.
- adaptor_S (Callable) – student model’s adaptor.

The roles of adaptor_T and adaptor_S are explained in adaptor().

train(optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)

Trains the student model.
Parameters:

- optimizer – optimizer.
- dataloader – dataset iterator.
- num_epochs (int) – number of training epochs.
- num_steps (int) – number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps; the dataloader may then have an unknown size, i.e., no __len__ attribute. The dataloader is cycled automatically after the whole dataset has been iterated over.
- callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step) and can be used to evaluate the model at each checkpoint; a usage sketch for callback and batch_postprocessor is given after the notes below.
- batch_postprocessor (Callable) – a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.
- scheduler_class (class) – the class of the scheduler to be constructed.
- scheduler_args (dict) – arguments (excluding optimizer) passed to scheduler_class to construct the scheduler object. See the example below.
- scheduler (deprecated) – used to adjust the learning rate; optional, can be None. Deprecated in favor of scheduler_class and scheduler_args.
- max_grad_norm (float) – maximum norm for the gradients (-1 means no clipping). Default: -1.0.
- **args – additional arguments fed to the model.
Note:

- If the batch is a list or tuple, the model is called as model(*batch, **args). Make sure the order of elements in the batch matches the argument order of model.forward.
- If the batch is a dict, the model is called as model(**batch, **args). Make sure the keys of the batch match the argument names of model.forward.
Note:

If you want to provide a learning rate scheduler, DON'T USE scheduler; use scheduler_class and scheduler_args instead. Example:

    from transformers import get_linear_schedule_with_warmup

    distiller.train(optimizer, scheduler_class=get_linear_schedule_with_warmup,
                    scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 1000})
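A hedged sketch of how callback and batch_postprocessor can be combined with train; evaluate_on_dev and the use of CUDA below are hypothetical placeholders, not part of TextBrewer:

    def my_callback(model, step):
        # Called after each epoch as callback(model=self.model_S, step=global_step).
        # evaluate_on_dev() is a placeholder for your own evaluation routine.
        acc = evaluate_on_dev(model)
        print(f"step {step}: dev accuracy = {acc:.4f}")

    def my_postprocessor(batch):
        # Post-process a dict-style batch before it is fed to the models and adaptors,
        # e.g. move all tensors to the GPU (assumes the batch is a dict of tensors).
        return {k: v.to('cuda') for k, v in batch.items()}

    distiller.train(optimizer, dataloader, num_epochs=3,
                    callback=my_callback, batch_postprocessor=my_postprocessor)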
GeneralDistiller

class textbrewer.GeneralDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S, custom_matches: Optional[List[textbrewer.distiller_utils.CustomMatch]] = None)

Supports intermediate feature matching. Recommended for single-teacher single-task distillation.
Parameters:

- train_config (TrainingConfig) – training configuration.
- distill_config (DistillationConfig) – distillation configuration.
- model_T (torch.nn.Module) – teacher model.
- model_S (torch.nn.Module) – student model.
- adaptor_T (Callable) – teacher model’s adaptor.
- adaptor_S (Callable) – student model’s adaptor.
- custom_matches (list) – supports more flexible user-defined matches (experimental, still in testing).

The roles of adaptor_T and adaptor_S are explained in adaptor().

train(optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)

Trains the student model.
Parameters:

- optimizer – optimizer.
- dataloader – dataset iterator.
- num_epochs (int) – number of training epochs.
- num_steps (int) – number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps; the dataloader may then have an unknown size, i.e., no __len__ attribute. The dataloader is cycled automatically after the whole dataset has been iterated over.
- callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step) and can be used to evaluate the model at each checkpoint.
- batch_postprocessor (Callable) – a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.
- scheduler_class (class) – the class of the scheduler to be constructed.
- scheduler_args (dict) – arguments (excluding optimizer) passed to scheduler_class to construct the scheduler object. See the example below.
- scheduler (deprecated) – used to adjust the learning rate; optional, can be None. Deprecated in favor of scheduler_class and scheduler_args.
- max_grad_norm (float) – maximum norm for the gradients (-1 means no clipping). Default: -1.0.
- **args – additional arguments fed to the model.
Note:

- If the batch is a list or tuple, the model is called as model(*batch, **args). Make sure the order of elements in the batch matches the argument order of model.forward.
- If the batch is a dict, the model is called as model(**batch, **args). Make sure the keys of the batch match the argument names of model.forward.
Note:

If you want to provide a learning rate scheduler, DON'T USE scheduler; use scheduler_class and scheduler_args instead. Example:

    from transformers import get_linear_schedule_with_warmup

    distiller.train(optimizer, scheduler_class=get_linear_schedule_with_warmup,
                    scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 1000})
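To make use of intermediate feature matching, the matches are declared in DistillationConfig rather than on the distiller itself. A hedged sketch is below; the intermediate_matches dictionary keys and the 'hidden'/'hidden_mse' names follow the DistillationConfig documentation, and the layer indices are purely illustrative:

    from textbrewer import TrainingConfig, DistillationConfig, GeneralDistiller

    distill_config = DistillationConfig(
        temperature=8,
        intermediate_matches=[
            # Match student hidden layer 0 to teacher hidden layer 0, and
            # student layer 3 to teacher layer 8 (indices are model-dependent).
            {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
            {'layer_T': 8, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        ])

    distiller = GeneralDistiller(
        train_config=TrainingConfig(), distill_config=distill_config,
        model_T=teacher_model, model_S=student_model,   # assumed to be defined
        adaptor_T=adaptor_T, adaptor_S=adaptor_S)       # adaptors must expose the 'hidden' feature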
MultiTeacherDistiller

class textbrewer.MultiTeacherDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)

Distills multiple teacher models (of the same task) into a student model. It does not support intermediate feature matching.
Parameters:

- train_config (TrainingConfig) – training configuration.
- distill_config (DistillationConfig) – distillation configuration.
- model_T (List[torch.nn.Module]) – list of teacher models.
- model_S (torch.nn.Module) – student model.
- adaptor_T (Callable) – teacher model’s adaptor.
- adaptor_S (Callable) – student model’s adaptor.

The roles of adaptor_T and adaptor_S are explained in adaptor().
train(optimizer, scheduler, dataloader, num_epochs, num_steps=None, callback=None, batch_postprocessor=None, **args)

Trains the student model. See BasicDistiller.train().
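A minimal sketch, assuming two already-trained teachers for the same task and a shared adaptor (the models, adaptor, dataloader, and optimizer are assumed to be defined elsewhere):

    from textbrewer import TrainingConfig, DistillationConfig, MultiTeacherDistiller

    distiller = MultiTeacherDistiller(
        train_config=TrainingConfig(), distill_config=DistillationConfig(),
        model_T=[teacher_1, teacher_2],     # list of teachers of the same task
        model_S=student_model,
        adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

    distiller.train(optimizer, scheduler=None, dataloader=dataloader, num_epochs=3)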
MultiTaskDistiller

class textbrewer.MultiTaskDistiller(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)

Distills multiple teacher models (of different tasks) into a single student model. Since version 0.2.1, it supports intermediate feature matching.
Parameters:

- train_config (TrainingConfig) – training configuration.
- distill_config (DistillationConfig) – distillation configuration.
- model_T (dict) – dict of teacher models: {task1: model1, task2: model2, ...}. Keys are task names.
- model_S (torch.nn.Module) – student model.
- adaptor_T (dict) – dict of teacher adaptors: {task1: adaptor1, task2: adaptor2, ...}. Keys are task names.
- adaptor_S (dict) – dict of student adaptors: {task1: adaptor1, task2: adaptor2, ...}. Keys are task names.

train(optimizer, dataloaders, num_steps, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, tau=1, callback=None, batch_postprocessors=None, **args)

Trains the student model.
Parameters:

- optimizer – optimizer.
- dataloaders (dict) – dict of dataset iterators. Keys are task names, values are the corresponding dataloaders.
- num_steps (int) – number of training steps.
- scheduler_class (class) – the class of the scheduler to be constructed.
- scheduler_args (dict) – arguments (excluding optimizer) passed to scheduler_class to construct the scheduler object.
- scheduler (deprecated) – used to adjust the learning rate; optional, can be None. Deprecated in favor of scheduler_class and scheduler_args.
- max_grad_norm (float) – maximum norm for the gradients (-1 means no clipping). Default: -1.0.
- tau (float) – the probability of sampling an example from task d is proportional to |d|^tau, where |d| is the size of d’s training set (see the sketch below). If the size of any dataset is unknown, tau is ignored and examples are sampled uniformly from each dataset.
- callback (Callable) – function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step) and can be used to evaluate the model at each checkpoint.
- batch_postprocessors (dict) – a dict of batch_postprocessors. Keys are task names, values are the corresponding batch_postprocessors. Each batch_postprocessor should take a batch and return a batch.
- **args – additional arguments fed to the model.
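A hedged sketch of multi-task distillation; the task names, teacher models, adaptors, dataloaders, and optimizer below are illustrative placeholders:

    from textbrewer import TrainingConfig, DistillationConfig, MultiTaskDistiller

    # All per-task objects are organized as dicts keyed by task name.
    model_T     = {'taskA': teacher_A, 'taskB': teacher_B}
    adaptor_T   = {'taskA': adaptor_A_T, 'taskB': adaptor_B_T}
    adaptor_S   = {'taskA': adaptor_A_S, 'taskB': adaptor_B_S}
    dataloaders = {'taskA': loader_A, 'taskB': loader_B}

    distiller = MultiTaskDistiller(
        train_config=TrainingConfig(), distill_config=DistillationConfig(),
        model_T=model_T, model_S=student_model,
        adaptor_T=adaptor_T, adaptor_S=adaptor_S)

    # With tau=1, examples are sampled from each task in proportion to its dataset size.
    distiller.train(optimizer, dataloaders=dataloaders, num_steps=10000, tau=1)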
BasicTrainer

class textbrewer.BasicTrainer(train_config: textbrewer.configurations.TrainingConfig, model: torch.nn.modules.module.Module, adaptor)

Performs supervised training, not distillation. It can be used to train the teacher model.
Parameters:

- train_config (TrainingConfig) – training configuration.
- model (torch.nn.Module) – model to be trained.
- adaptor (Callable) – the model’s adaptor.

The role of adaptor is explained in adaptor().

train(optimizer, dataloader, num_epochs, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args)

Trains the model. See BasicDistiller.train().
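A minimal sketch of training a teacher model with BasicTrainer; the adaptor here returns a 'losses' entry (a list of scalar loss tensors), which is an assumption based on the adaptor() documentation, and the model, dataloader, and optimizer are assumed to be defined elsewhere:

    from textbrewer import TrainingConfig, BasicTrainer

    def train_adaptor(batch, model_outputs):
        # Assumes a transformers-style output that already contains the supervised loss.
        return {'losses': [model_outputs.loss]}

    trainer = BasicTrainer(train_config=TrainingConfig(),
                           model=teacher_model,   # model to be trained
                           adaptor=train_adaptor)

    trainer.train(optimizer, dataloader, num_epochs=3)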