Presets

Presets contains module variables that define the predefined loss functions and strategies.

Module variables

ADAPTOR_KEYS

textbrewer.presets.ADAPTOR_KEYS

(list) valid keys of the dict returned by the adaptor, including:

  • logits

  • logits_mask

  • losses

  • inputs_mask

  • labels

  • hidden

  • attention
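An adaptor translates a model's raw outputs into a dict keyed by the names above. The sketch below is a minimal illustration, assuming a hypothetical model whose forward pass returns a tuple `(logits, hidden_states)` and a batch dict that may carry a mask under the hypothetical key `'attention_mask'`. The names `simple_adaptor` and `check_adaptor_output` and the `(batch, model_outputs)` argument order are assumptions for this example; plain lists stand in for tensors.

```python
# Valid adaptor keys, copied from the ADAPTOR_KEYS list above.
VALID_KEYS = {'logits', 'logits_mask', 'losses', 'inputs_mask',
              'labels', 'hidden', 'attention'}


def simple_adaptor(batch, model_outputs):
    """Map a hypothetical (logits, hidden_states) output tuple to adaptor keys."""
    logits, hidden_states = model_outputs
    result = {'logits': logits, 'hidden': hidden_states}
    if 'attention_mask' in batch:
        # Expose the mask so intermediate losses can ignore padding.
        result['inputs_mask'] = batch['attention_mask']
    return result


def check_adaptor_output(output):
    """Raise KeyError if the adaptor returned a key outside VALID_KEYS."""
    unknown = set(output) - VALID_KEYS
    if unknown:
        raise KeyError(f'unknown adaptor keys: {sorted(unknown)}')
    return output


# Usage with toy values standing in for tensors.
out = check_adaptor_output(
    simple_adaptor({'attention_mask': [1, 1, 0]},
                   ([0.2, 0.8], [[0.1, 0.3], [0.5, 0.7]]))
)
print(sorted(out))  # ['hidden', 'inputs_mask', 'logits']
```

Returning only the keys a model can actually provide keeps the adaptor small; the remaining keys are simply absent from the dict.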

KD_LOSS_MAP

textbrewer.presets.KD_LOSS_MAP

(dict) available KD losses

  • 'mse': mean squared error

  • 'ce': cross-entropy loss
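Both losses compare student and teacher logits after dividing them by the distillation temperature. The following pure-Python sketch shows what each entry computes on a batch of logit rows, assuming temperature-scaled logits and a loss averaged over the batch. It is an illustration, not TextBrewer's implementation (which works on PyTorch tensors), and the function names `kd_mse` and `kd_ce` are hypothetical.

```python
import math


def _log_softmax(xs):
    """Numerically stable log-softmax of a list of floats."""
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


def kd_mse(logits_S, logits_T, temperature=1.0):
    """Mean squared error between temperature-scaled student and teacher logits."""
    diffs = [(s - t) / temperature
             for row_S, row_T in zip(logits_S, logits_T)
             for s, t in zip(row_S, row_T)]
    return sum(d * d for d in diffs) / len(diffs)


def kd_ce(logits_S, logits_T, temperature=1.0):
    """Cross-entropy of the student's predictions against the teacher's
    soft targets, averaged over the batch."""
    total = 0.0
    for row_S, row_T in zip(logits_S, logits_T):
        log_p_T = _log_softmax([t / temperature for t in row_T])
        log_p_S = _log_softmax([s / temperature for s in row_S])
        total += -sum(math.exp(lt) * ls for lt, ls in zip(log_p_T, log_p_S))
    return total / len(logits_S)


# Usage: a batch of two examples with three classes each.
teacher = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.0]]
student = [[1.5, 1.2, 0.3], [0.4, 2.0, 0.2]]
print(kd_mse(student, teacher, temperature=2.0))
print(kd_ce(student, teacher, temperature=2.0))
```

A higher temperature flattens the teacher's distribution, so 'ce' passes more information about the relative scores of non-target classes to the student.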

PROJ_MAP

textbrewer.presets.PROJ_MAP

(dict) layers used to match the different dimensions of intermediate features

  • 'linear': linear layer, no activation

  • 'relu': ReLU activation

  • 'tanh': Tanh activation
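When the student's hidden size differs from the teacher's, a projection maps student features into the teacher's dimension before a matching loss is applied. This pure-Python sketch shows what the three entries compute, assuming a weight matrix `W` of shape `(dim_out, dim_in)` and a bias vector `b`. The names `project` and `ACTIVATIONS` are hypothetical, and in practice the projection weights are learned during distillation.

```python
import math

# Activation applied after the linear map, keyed like the PROJ_MAP entries.
ACTIVATIONS = {
    'linear': lambda x: x,          # no activation
    'relu': lambda x: max(0.0, x),  # ReLU
    'tanh': math.tanh,              # Tanh
}


def project(features, W, b, kind='linear'):
    """Map a feature vector of size dim_in to size dim_out as act(W @ x + b)."""
    act = ACTIVATIONS[kind]
    return [act(sum(w_ij * x_j for w_ij, x_j in zip(row, features)) + b_i)
            for row, b_i in zip(W, b)]


# Usage: project a 2-dim student feature into a 3-dim teacher space.
W = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
b = [0.0, 0.0, 0.0]
x = [0.5, 2.0]
print(project(x, W, b, 'linear'))  # [0.5, 2.0, -1.5]
print(project(x, W, b, 'relu'))    # [0.5, 2.0, 0.0]
```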

MATCH_LOSS_MAP

textbrewer.presets.MATCH_LOSS_MAP

(dict) intermediate feature matching loss functions.

See Intermediate Losses for details.
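As an illustration of what an intermediate matching loss does, the sketch below computes a mean squared error between student and teacher hidden states, optionally skipping padded positions using a mask such as the adaptor's `inputs_mask`. It is a hypothetical pure-Python example (`masked_hidden_mse` is not a TextBrewer function) and assumes both hidden states already have the same size, for example after a projection.

```python
def masked_hidden_mse(state_S, state_T, mask=None):
    """MSE between two [seq_len][hidden] nested lists.

    Positions where mask[i] == 0 (e.g. padding) are excluded from both the
    squared-error sum and the normalizing count.
    """
    if mask is None:
        mask = [1] * len(state_S)
    total, count = 0.0, 0
    for vec_S, vec_T, keep in zip(state_S, state_T, mask):
        if not keep:
            continue
        total += sum((s - t) ** 2 for s, t in zip(vec_S, vec_T))
        count += len(vec_S)
    return total / count if count else 0.0


# Usage: three positions, the last one is padding.
student = [[0.0, 1.0], [2.0, 2.0], [9.0, 9.0]]
teacher = [[0.0, 0.0], [2.0, 4.0], [0.0, 0.0]]
print(masked_hidden_mse(student, teacher, mask=[1, 1, 0]))  # 1.25
```

Without the mask, the large error at the padded position would dominate the loss even though it carries no real input.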

WEIGHT_SCHEDULER

textbrewer.presets.WEIGHT_SCHEDULER

(dict) schedulers used to dynamically adjust the KD loss weight and the hard_label_loss weight.

  • 'linear_decay': decay from 1 to 0 during the whole training process.

  • 'linear_growth': grow from 0 to 1 during the whole training process.
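Both schedules scale a base weight linearly with training progress. The sketch below assumes a scheduler is called with the base weight, the current global step, and the total number of training steps; that call signature and the function names are assumptions for illustration, so check the source code for the exact convention.

```python
def linear_decay(base_weight, global_step, total_steps):
    """Weight falls linearly from base_weight (step 0) to 0 (last step)."""
    return base_weight * (1 - global_step / total_steps)


def linear_growth(base_weight, global_step, total_steps):
    """Weight rises linearly from 0 (step 0) to base_weight (last step)."""
    return base_weight * global_step / total_steps


# Usage: one possible pairing is a decaying KD loss weight with a growing
# hard-label loss weight, shifting emphasis toward the labels over time.
total = 1000
for step in (0, 500, 1000):
    print(step, linear_decay(1.0, step, total), linear_growth(1.0, step, total))
```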

TEMPERATURE_SCHEDULER

textbrewer.presets.TEMPERATURE_SCHEDULER

(custom dict) used to dynamically adjust the distillation temperature.

Unlike the other options, 'flsw' and 'cwsm' require extra parameters, for example:

from textbrewer import DistillationConfig

# flsw
distill_config = DistillationConfig(
    temperature_scheduler=['flsw', 1, 2]  # beta=1, gamma=2
)

# cwsm
distill_config = DistillationConfig(
    temperature_scheduler=['cwsm', 1]  # beta=1
)
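Both schedulers produce one temperature per sample instead of a single global value. The pure-Python sketch below assumes each sample's temperature is the base temperature `T` shifted by `beta * (mean(w) - w_i)`. For 'flsw' it assumes `w_i = (1 - cos(s_i, t_i)) ** gamma`, where `s_i` and `t_i` are the student and teacher logits; for 'cwsm' it assumes `w_i` is the inverse of the student's top softmax probability. These formulas and all function names are assumptions for illustration, not a verified copy of TextBrewer's code; consult the source for the exact definitions.

```python
import math


def _cosine(u, v, eps=1e-4):
    """Cosine similarity with a small epsilon to avoid division by zero."""
    nu = math.sqrt(sum(x * x for x in u)) + eps
    nv = math.sqrt(sum(x * x for x in v)) + eps
    return sum(a * b for a, b in zip(u, v)) / (nu * nv)


def _spread(base_temperature, weights, beta):
    """Shift each sample's temperature by beta * (mean weight - its weight)."""
    mean_w = sum(weights) / len(weights)
    return [base_temperature + beta * (mean_w - w) for w in weights]


def flsw_temperatures(logits_S, logits_T, base_temperature, beta, gamma):
    """For beta > 0, samples where student and teacher disagree more get a
    lower temperature."""
    weights = [(1 - _cosine(s, t)) ** gamma for s, t in zip(logits_S, logits_T)]
    return _spread(base_temperature, weights, beta)


def cwsm_temperatures(logits_S, base_temperature, beta):
    """For beta > 0, samples where the student is less confident get a
    lower temperature."""
    weights = []
    for row in logits_S:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        p_max = max(exps) / sum(exps)
        weights.append(1 / (p_max + 1e-3))
    return _spread(base_temperature, weights, beta)


# Usage with the parameters from the configs above: ['flsw', 1, 2] and ['cwsm', 1].
student = [[2.0, 0.0], [0.0, 2.0]]
teacher = [[2.0, 0.1], [2.0, 0.0]]
print(flsw_temperatures(student, teacher, base_temperature=4.0, beta=1, gamma=2))
print(cwsm_temperatures(student, base_temperature=4.0, beta=1))
```

Because each shift is measured against the batch mean, the per-sample temperatures average back to the base temperature.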

Customization

If the predefined modules do not meet your requirements, you can add your own modules to the dicts above.

For example:

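from textbrewer.presets import MATCH_LOSS_MAP, WEIGHT_SCHEDULER
MATCH_LOSS_MAP['my_L1_loss'] = my_L1_loss  # user-defined loss function
WEIGHT_SCHEDULER['my_weight_scheduler'] = my_weight_scheduler  # user-defined scheduler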

Then use them in DistillationConfig:

distill_config = DistillationConfig(
    kd_loss_weight_scheduler='my_weight_scheduler',
    intermediate_matches=[{'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'my_L1_loss', 'weight': 1}],
    # ... other options
)
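A custom scheduler is an ordinary Python function. The sketch below defines a hypothetical cosine-decay `my_weight_scheduler`, registers it under a string name, and then looks it up by that name, mirroring how the name given to `kd_loss_weight_scheduler` presumably maps to the function. It assumes the `(base_weight, global_step, total_steps)` call convention, which is an assumption to verify against the source code. It falls back to a plain dict when TextBrewer is not installed, so it runs standalone.

```python
import math

try:
    from textbrewer.presets import WEIGHT_SCHEDULER
except ImportError:  # standalone fallback so the sketch runs without TextBrewer
    WEIGHT_SCHEDULER = {}


def my_weight_scheduler(base_weight, global_step, total_steps):
    """Hypothetical cosine decay: base_weight at step 0, 0 at the last step."""
    progress = min(global_step / total_steps, 1.0)
    return base_weight * 0.5 * (1 + math.cos(math.pi * progress))


# Register under the name used in kd_loss_weight_scheduler above.
WEIGHT_SCHEDULER['my_weight_scheduler'] = my_weight_scheduler

# Look the function up by its string name, as a config entry would.
scheduler = WEIGHT_SCHEDULER['my_weight_scheduler']
for step in (0, 250, 500, 1000):
    print(step, round(scheduler(1.0, step, 1000), 4))
```

Compared with 'linear_decay', a cosine schedule keeps the KD weight near its base value longer early in training and drops it more quickly in the middle.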

Refer to the source code for details on the input and output conventions (these will be explained in detail in a later version of the documentation).