Configurations

TrainingConfig
class textbrewer.TrainingConfig(gradient_accumulation_steps=1, ckpt_frequency=1, ckpt_epoch_frequency=1, ckpt_steps=None, log_dir=None, output_dir='./saved_models', device='cuda', fp16=False, fp16_opt_level='O1', data_parallel=False, local_rank=-1)[source]

Configurations related to general model training.
- Parameters
  - gradient_accumulation_steps (int) – accumulates gradients before optimization to reduce GPU memory usage. It calls optimizer.step() every gradient_accumulation_steps backward steps.
  - ckpt_frequency (int) – stores model weights ckpt_frequency times per epoch.
  - ckpt_epoch_frequency (int) – stores model weights every ckpt_epoch_frequency epochs.
  - ckpt_steps (int) – if num_steps is passed to distiller.train(), saves the model every ckpt_steps steps and ignores ckpt_frequency and ckpt_epoch_frequency.
  - log_dir (str) – directory to save the TensorBoard log file. Set it to None to disable TensorBoard.
  - output_dir (str) – directory to save model weights.
  - device (str or torch.device) – training on CPU or GPU.
  - fp16 (bool) – if True, enables mixed precision training using Apex.
  - fp16_opt_level (str) – pure or mixed precision optimization level. Accepted values are "O0", "O1", "O2", and "O3". See the Apex documentation for details.
  - data_parallel (bool) – if True, wraps the models with torch.nn.DataParallel.
  - local_rank (int) – the local rank of the current process. A non-negative value means that we are in the distributed training mode with DistributedDataParallel.
Note

- To perform data parallel (DP) training, you could either wrap the models with torch.nn.DataParallel outside TextBrewer by yourself, or leave the work to TextBrewer by setting data_parallel to True.
- To enable both data parallel training and mixed precision training, set data_parallel to True and DO NOT wrap the models by yourself.
- In some experiments, we have observed a slowdown in training speed with torch.nn.DataParallel.
- To perform distributed data parallel (DDP) training, you should call torch.distributed.init_process_group before initializing a TrainingConfig, and pass the raw (unwrapped) model when initializing the distiller (see the sketch after the example below).
- DP and DDP are mutually exclusive.
Example:

# Usually just need to set log_dir and output_dir and leave others default
train_config = TrainingConfig(log_dir=my_log_dir, output_dir=my_output_dir)

# Stores the model at the end of each epoch
train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=1)

# Stores the model twice (at the middle and at the end) in each epoch
train_config = TrainingConfig(ckpt_frequency=2, ckpt_epoch_frequency=1)

# Stores the model once every two epochs
train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=2)
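A minimal DDP setup sketch based on the note above (not part of the original example). It assumes one process per GPU started by a DDP launcher such as torch.distributed.launch, which supplies the --local_rank argument; the paths are placeholders:

import argparse
import torch
from textbrewer import TrainingConfig

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)  # filled in by the launcher
args = parser.parse_args()

# Initialize the process group BEFORE constructing the TrainingConfig
torch.distributed.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)

train_config = TrainingConfig(
    log_dir='./logs',            # placeholder path
    output_dir='./saved_models',
    local_rank=args.local_rank,  # a non-negative value enables DDP mode
)
# When building the distiller, pass the raw (unwrapped) models, as the note above requires.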
classmethod from_dict(dict_object)

Construct configurations from a dict.
classmethod from_json_file(json_filename)

Construct configurations from a json file.
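For instance (a usage sketch, assuming the dict keys mirror the constructor arguments shown above):

from textbrewer import TrainingConfig

train_config = TrainingConfig.from_dict({
    'log_dir': './logs',             # placeholder path
    'output_dir': './saved_models',
    'ckpt_frequency': 2,
})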
DistillationConfig

class textbrewer.DistillationConfig(temperature=4, temperature_scheduler='none', hard_label_weight=0, hard_label_weight_scheduler='none', kd_loss_type='ce', kd_loss_weight=1, kd_loss_weight_scheduler='none', probability_shift=False, intermediate_matches: Optional[List[Dict]] = None, is_caching_logits=False)[source]

Configurations related to distillation methods. It defines the total loss to be optimized:
\[\mathcal{L}_{total}= \mathcal{L}_{KD} * w_{KD} + \mathcal{L}_{hl} * w_{hl} + \mathrm{sum}(\textrm{intermediate\_losses})\]

where

- \(\mathcal{L}_{KD}\) is the KD loss on logits and \(w_{KD}\) is its weight;
- \(\mathcal{L}_{hl}\) is the sum of losses returned by the adaptor and \(w_{hl}\) is its weight;
- intermediate_losses is defined via intermediate_matches.
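As a quick illustration of how these weights map onto constructor arguments (a hypothetical setting, not a recommendation):

from textbrewer import DistillationConfig

distill_config = DistillationConfig(
    temperature=4,
    kd_loss_weight=1,     # w_KD: weight of the KD loss on logits
    hard_label_weight=1,  # w_hl: weight of the sum of 'losses' returned by the adaptor
)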
- Parameters
  - temperature (float) – temperature for the distillation. The teacher's and the student's logits will be divided by the temperature when computing the KD loss. The temperature typically ranges from 1 to 10. We found that temperatures higher than 1 usually lead to better performance.
  - temperature_scheduler – dynamically adjusts the temperature. See TEMPERATURE_SCHEDULER for all available options.
  - kd_loss_type (str) – KD loss function for the logits term returned by the adaptor, can be 'ce' or 'mse'. See KD_LOSS_MAP.
  - kd_loss_weight (float) – the weight for the KD loss.
  - hard_label_weight (float) – the weight for the sum of the losses term returned by the adaptor. losses may include the losses on the ground-truth labels and other user-defined losses.
  - kd_loss_weight_scheduler – dynamically adjusts the KD loss weight. See WEIGHT_SCHEDULER for all available options.
  - hard_label_weight_scheduler – dynamically adjusts the weight of the sum of losses. See WEIGHT_SCHEDULER for all available options.
  - probability_shift (bool) – if True, swaps the ground-truth label's logit with the largest logit predicted by the teacher, so that the ground-truth label's logit becomes the largest. Requires the labels term returned by the adaptor.
  - is_caching_logits (bool) – if True, caches the batches and the output logits of the teacher model in memory, so that those logits are only calculated once. It speeds up the distillation process. This feature is only available for BasicDistiller and MultiTeacherDistiller, and only when the distiller's train() method is called with num_steps=None. It is suitable for small and medium datasets.
  - intermediate_matches (List[Dict]) – configuration for intermediate feature matching. Each element in the list is a dict representing a pair of matching configs.
The dict in intermediate_matches contains the following keys:

- 'layer_T': layer_T (int): selects the layer_T-th layer of the teacher model.
- 'layer_S': layer_S (int): selects the layer_S-th layer of the student model.

  Note

  - layer_T and layer_S indicate layers in the attention or hidden list of the dict returned by the adaptor, rather than the actual layers in the model.
  - If the loss is fsp or nst, two layers have to be chosen from the teacher and the student respectively. In this case, layer_T and layer_S are lists of two ints. See the example below.

- 'feature': feature (str): features of intermediate layers. It can be:
  - 'attention': attention matrix, of the shape (batch_size, num_heads, length, length) or (batch_size, length, length)
  - 'hidden': hidden states, of the shape (batch_size, length, hidden_dim)
- 'loss': loss (str): loss function. See MATCH_LOSS_MAP for available losses. Currently includes 'attention_mse', 'attention_ce', 'hidden_mse', 'nst', etc.
- 'weight': weight (float): weight for the loss.
- 'proj': proj (List, optional): if the teacher and the student have the same feature dimension, it is optional; otherwise it is required. It is the mapping function used to match the teacher's and the student's intermediate feature dimensions. It is a list, with these elements:
  - proj[0] (str): mapping function, can be 'linear', 'relu', 'tanh'. See PROJ_MAP.
  - proj[1] (int): feature dimension of the student model.
  - proj[2] (int): feature dimension of the teacher model.
  - proj[3] (dict): optional, provides configurations such as learning rate. If not provided, the learning rate and optimizer configurations will follow the default config of the optimizer; otherwise it will use the ones specified here.
Example:

from textbrewer import DistillationConfig

# Simple configuration: use default values, or try different temperatures
distill_config = DistillationConfig(temperature=8)

# Adding intermediate feature matching.
# Under this setting, the returned dicts results_T/results_S of adaptor_T/adaptor_S should contain the 'hidden' key.
# The mse loss between the teacher's results_T['hidden'][10] and the student's results_S['hidden'][3] will be computed.
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[{'layer_T':10, 'layer_S':3, 'feature':'hidden', 'loss':'hidden_mse', 'weight':1}]
)

# Multiple intermediate feature matches. The teacher and the student have a hidden_dim of 768 and 384 respectively.
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[
        {'layer_T':0,  'layer_S':0, 'feature':'hidden', 'loss':'hidden_mse', 'weight':1, 'proj':['linear',384,768]},
        {'layer_T':4,  'layer_S':1, 'feature':'hidden', 'loss':'hidden_mse', 'weight':1, 'proj':['linear',384,768]},
        {'layer_T':8,  'layer_S':2, 'feature':'hidden', 'loss':'hidden_mse', 'weight':1, 'proj':['linear',384,768]},
        {'layer_T':12, 'layer_S':3, 'feature':'hidden', 'loss':'hidden_mse', 'weight':1, 'proj':['linear',384,768]}
    ]
)
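The note above also mentions losses such as fsp and nst that operate on pairs of layers; a hypothetical sketch of such a match (the layer indices are for illustration only):

# layer_T and layer_S are lists of two ints when the loss takes a pair of layers
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[
        {'layer_T':[0, 0], 'layer_S':[0, 0], 'feature':'hidden', 'loss':'nst', 'weight':1}
    ]
)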
classmethod from_dict(dict_object)

Construct configurations from a dict.
classmethod from_json_file(json_filename)

Construct configurations from a json file.
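For instance (a usage sketch; distill_config.json is a hypothetical file whose keys mirror the constructor arguments):

from textbrewer import DistillationConfig

# distill_config.json might contain, e.g.:
# {"temperature": 8, "hard_label_weight": 1, "kd_loss_type": "ce"}
distill_config = DistillationConfig.from_json_file('distill_config.json')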