Presets¶
Presets are module variables that hold predefined loss functions and strategies.
Module variables¶
ADAPTOR_KEYS¶
textbrewer.presets.ADAPTOR_KEYS
(list) Valid keys of the dict returned by the adaptor:
‘logits’
‘logits_mask’
‘losses’
‘inputs_mask’
‘labels’
‘hidden’
‘attention’
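To make the list concrete, here is a minimal sketch of an adaptor, assuming it is a function that receives a batch and the model's outputs and returns a dict keyed as above. The `ModelOutput` container, its field names, and the `check_adaptor_output` helper are hypothetical, and `VALID_KEYS` is a local copy of the list above so that the snippet runs without TextBrewer installed.

```python
from collections import namedtuple

# Local copy of the keys listed above, so the sketch runs without TextBrewer.
VALID_KEYS = ["logits", "logits_mask", "losses", "inputs_mask",
              "labels", "hidden", "attention"]

# Hypothetical container for a model's forward outputs.
ModelOutput = namedtuple("ModelOutput", ["loss", "logits", "hidden_states"])


def simple_adaptor(batch, model_outputs):
    """Map a batch and raw model outputs to a dict with valid adaptor keys."""
    result = {
        "logits": model_outputs.logits,
        "hidden": model_outputs.hidden_states,
        "inputs_mask": batch["attention_mask"],
    }
    # Only include the loss when the model actually computed one.
    if model_outputs.loss is not None:
        result["losses"] = model_outputs.loss
    return result


def check_adaptor_output(result):
    """Return the keys that are not in the valid list (hypothetical helper)."""
    return sorted(set(result) - set(VALID_KEYS))


batch = {"attention_mask": [1, 1, 0]}
outputs = ModelOutput(loss=None, logits=[0.2, 0.8], hidden_states=[[0.1], [0.3]])
adapted = simple_adaptor(batch, outputs)
print(sorted(adapted))                # keys produced by the adaptor
print(check_adaptor_output(adapted))  # unknown keys, if any
```

An adaptor does not need to return every key; it presumably only needs the ones that the chosen losses consume.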
KD_LOSS_MAP¶
textbrewer.presets.KD_LOSS_MAP
(dict) Available KD losses:
‘mse’ : mean squared error
‘ce’: cross-entropy loss
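The two losses can be illustrated for a single example in plain Python. This is a sketch of the general technique, not TextBrewer's implementation (which presumably operates on batched tensors); the function names and the way the temperature is applied are assumptions.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_ce(logits_s, logits_t, temperature=1.0):
    """Cross-entropy of the student's softened distribution under the teacher's."""
    p_t = softmax(logits_t, temperature)
    p_s = softmax(logits_s, temperature)
    return -sum(pt * math.log(ps) for pt, ps in zip(p_t, p_s))


def kd_mse(logits_s, logits_t, temperature=1.0):
    """Mean squared error between temperature-scaled logits."""
    diffs = [(s - t) / temperature for s, t in zip(logits_s, logits_t)]
    return sum(d * d for d in diffs) / len(diffs)


teacher = [2.0, 0.5, -1.0]
student = [1.0, 1.0, -0.5]
print(kd_ce(student, teacher, temperature=2.0))
print(kd_mse(student, teacher, temperature=2.0))
```

For a fixed teacher, the cross-entropy is smallest when the student's softened distribution equals the teacher's, while the MSE reaches zero only when the logits themselves match.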
PROJ_MAP¶
textbrewer.presets.PROJ_MAP
(dict) Layers used to match the differing dimensions of intermediate features:
‘linear’ : linear layer, no activation
‘relu’ : ReLU activation
‘tanh’: Tanh activation
MATCH_LOSS_MAP¶
textbrewer.presets.MATCH_LOSS_MAP
(dict) Intermediate feature matching loss functions, including:
hidden_mse
nst, mmd
See Intermediate Losses for details.
WEIGHT_SCHEDULER¶
textbrewer.presets.WEIGHT_SCHEDULER
(dict) Schedulers used to dynamically adjust the KD loss weight and the hard_label_loss weight.
‘linear_decay’ : decay from 1 to 0 during the whole training process.
‘linear_growth’ : grow from 0 to 1 during the whole training process.
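The two schedules can be sketched as functions of training progress. This assumes a scheduler receives the fraction of training completed, a number in [0, 1], and returns the weight; the function names and the step loop are illustrative only.

```python
def linear_decay(progress):
    """Weight falls from 1 to 0 as progress goes from 0 to 1."""
    return 1.0 - progress


def linear_growth(progress):
    """Weight rises from 0 to 1 as progress goes from 0 to 1."""
    return progress


total_steps = 4
for step in range(total_steps + 1):
    progress = step / total_steps  # assumed input: fraction of training done
    print(step, linear_decay(progress), linear_growth(progress))
```

Pairing a decaying weight on one loss with a growing weight on the other shifts the emphasis between the two losses over the course of training.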
TEMPERATURE_SCHEDULER¶
textbrewer.presets.TEMPERATURE_SCHEDULER
(custom dict) Schedulers used to dynamically adjust the distillation temperature.
‘constant’ : Constant temperature.
‘flsw’ : See Preparing Lessons: Improve Knowledge Distillation with Better Supervision. Needs parameters beta and gamma.
‘cwsm’ : See Preparing Lessons: Improve Knowledge Distillation with Better Supervision. Needs parameter beta.
Unlike the other options, 'flsw' and 'cwsm' require extra parameters, for example:
# flsw
distill_config = DistillationConfig(
    temperature_scheduler = ['flsw', 1, 2]  # beta=1, gamma=2
)
# cwsm
distill_config = DistillationConfig(
    temperature_scheduler = ['cwsm', 1]  # beta=1
)
Customization¶
If the predefined modules do not meet your requirements, you can add your own to the dicts above.
For example:
MATCH_LOSS_MAP['my_L1_loss'] = my_L1_loss
WEIGHT_SCHEDULER['my_weight_scheduler'] = my_weight_scheduler
They can then be used in DistillationConfig:
distill_config = DistillationConfig(
    kd_loss_weight_scheduler = 'my_weight_scheduler',
    intermediate_matches = [{'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'my_L1_loss', 'weight': 1}],
    ...)
Refer to the source code for more details on input and output conventions (these will be explained in detail in a later version of the documentation).
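As a rough illustration of the registration pattern, the sketch below defines a custom weight scheduler and adds it to a dict under a new name. It assumes a scheduler is a callable that takes training progress in [0, 1] and returns a weight; check that convention against the source code before relying on it. The `WEIGHT_SCHEDULER` dict here is a local stand-in for `textbrewer.presets.WEIGHT_SCHEDULER` so that the snippet runs on its own, and the cosine shape is only an example.

```python
import math

# Stand-in for textbrewer.presets.WEIGHT_SCHEDULER so the sketch runs alone.
WEIGHT_SCHEDULER = {}


def my_weight_scheduler(progress):
    """Cosine decay from 1 to 0 over training progress in [0, 1] (assumed input)."""
    return 0.5 * (1.0 + math.cos(math.pi * progress))


# Register the scheduler under the name that the config will refer to.
WEIGHT_SCHEDULER["my_weight_scheduler"] = my_weight_scheduler

scheduler = WEIGHT_SCHEDULER["my_weight_scheduler"]
print(scheduler(0.0), scheduler(0.5), scheduler(1.0))
```

With the real preset dict, the registered name is the string passed to `kd_loss_weight_scheduler` in `DistillationConfig`.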