Intermediate Losses¶
This section lists the definitions of the predefined intermediate losses. Users usually do not need to call these functions directly; instead, they refer to them by their names in MATCH_LOSS_MAP.
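The losses can also be called directly as ordinary functions. A minimal sketch (the shapes and random inputs are placeholders, following the parameter descriptions below):

import torch
from textbrewer.losses import att_mse_loss

# attention matrices of shape (batch_size, num_heads, length, length)
attention_S = torch.rand(2, 12, 32, 32)
attention_T = torch.rand(2, 12, 32, 32)
# mask of shape (batch_size, length); positions with 0 are masked out
mask = torch.ones(2, 32)
loss = att_mse_loss(attention_S, attention_T, mask=mask)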
attention_mse¶

textbrewer.losses.att_mse_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the MSE loss between attention_S and attention_T. If mask is given, masks the positions where mask==0.
Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
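Example in intermediate_matches (the layer indices here are placeholders):
intermediate_matches = [ {'layer_T':0, 'layer_S':0, 'feature':'attention','loss': 'attention_mse', 'weight' : 1}, ...]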
attention_mse_sum¶

textbrewer.losses.att_mse_sum_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the MSE loss between attention_S and attention_T. If the shape is (batch_size, num_heads, length, length), sums along the num_heads dimension and then calculates the MSE loss between the two matrices. If mask is given, masks the positions where mask==0.
Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
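Example in intermediate_matches (the layer indices here are placeholders):
intermediate_matches = [ {'layer_T':0, 'layer_S':0, 'feature':'attention','loss': 'attention_mse_sum', 'weight' : 1}, ...]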
attention_ce¶

textbrewer.losses.att_ce_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied on dim=-1. If mask is given, masks the positions where mask==0.
Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
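Example in intermediate_matches (the layer indices here are placeholders):
intermediate_matches = [ {'layer_T':0, 'layer_S':0, 'feature':'attention','loss': 'attention_ce', 'weight' : 1}, ...]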
attention_ce_mean¶

textbrewer.losses.att_ce_mean_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied on dim=-1. If the shape is (batch_size, num_heads, length, length), averages over the num_heads dimension and then computes the cross-entropy loss between the two matrices. If mask is given, masks the positions where mask==0.
Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
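Example in intermediate_matches (the layer indices here are placeholders):
intermediate_matches = [ {'layer_T':0, 'layer_S':0, 'feature':'attention','loss': 'attention_ce_mean', 'weight' : 1}, ...]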
cos¶

textbrewer.losses.cos_loss(state_S, state_T, mask=None)[source]¶
Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT; see DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. If mask is given, masks the positions where mask==0. If the hidden sizes of the student and the teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.
Parameters
state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)
state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of shape (batch_size, length)
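Example in intermediate_matches (the layer indices and projection sizes are placeholders; 'proj' is only needed when the hidden sizes differ):
intermediate_matches = [ {'layer_T':0, 'layer_S':0, 'feature':'hidden','loss': 'cos', 'weight' : 1, 'proj':['linear',384,768]}, ...]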
pkd¶

textbrewer.losses.pkd_loss(state_S, state_T, mask=None)[source]¶
Computes the normalized-vector MSE loss at position 0 along the length dimension. This is the loss used in BERT-PKD; see Patient Knowledge Distillation for BERT Model Compression. If the hidden sizes of the student and the teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions. If the input tensors are of shape (batch_size, hidden_size), it directly computes the loss between the tensors without taking the hidden states at position 0.
Parameters
state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)
state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)
mask – not used.
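Example in intermediate_matches (the layer indices here are placeholders):
intermediate_matches = [ {'layer_T':0, 'layer_S':0, 'feature':'hidden','loss': 'pkd', 'weight' : 1}, ...]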
nst (mmd)¶

textbrewer.losses.mmd_loss(state_S, state_T, mask=None)[source]¶
Takes in two lists of matrices state_S and state_T. Each list contains two matrices of the shape (batch_size, length, hidden_size). The hidden_size of the matrices in state_S does not need to be the same as that of state_T. Computes the similarity matrix between the two matrices in state_S (with the resulting shape (batch_size, length, length)) and between the two matrices in state_T (with the resulting shape (batch_size, length, length)), then computes the MSE loss between the two similarity matrices:

\[loss = mean((S_{1} \cdot S_{2}^T - T_{1} \cdot T_{2}^T)^2)\]

It is a variant of the NST loss in Like What You Like: Knowledge Distill via Neuron Selectivity Transfer. If mask is given, masks the positions where mask==0.
Parameters
state_S (torch.Tensor) – list of two tensors, each tensor is of the shape (batch_size, length, hidden_size)
state_T (torch.Tensor) – list of two tensors, each tensor is of the shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of the shape (batch_size, length)
Example in intermediate_matches:
intermediate_matches = [ {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'nst', 'weight' : 1}, ...]
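A direct-call sketch (the shapes are placeholders; note that the student and teacher hidden sizes may differ for this loss):

import torch
from textbrewer.losses import mmd_loss

state_S = [torch.rand(2, 32, 384), torch.rand(2, 32, 384)]  # two student hidden states
state_T = [torch.rand(2, 32, 768), torch.rand(2, 32, 768)]  # two teacher hidden states
mask = torch.ones(2, 32)  # (batch_size, length)
loss = mmd_loss(state_S, state_T, mask=mask)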
fsp (gram)¶

textbrewer.losses.fsp_loss(state_S, state_T, mask=None)[source]¶
Takes in two lists of matrices state_S and state_T. Each list contains two matrices of the shape (batch_size, length, hidden_size). Computes the similarity matrix between the two matrices in state_S (with the resulting shape (batch_size, hidden_size, hidden_size)) and between the two matrices in state_T (with the resulting shape (batch_size, hidden_size, hidden_size)), then computes the MSE loss between the two similarity matrices:

\[loss = mean((S_{1}^T \cdot S_{2} - T_{1}^T \cdot T_{2})^2)\]

It is a variant of the FSP loss in A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning. If mask is given, masks the positions where mask==0. If the hidden sizes of the student and the teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.
Parameters
state_S (torch.Tensor) – list of two tensors, each tensor is of the shape (batch_size, length, hidden_size)
state_T (torch.Tensor) – list of two tensors, each tensor is of the shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of the shape (batch_size, length)
Example in intermediate_matches:
intermediate_matches = [ {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'fsp', 'weight' : 1, 'proj':['linear',384,768]}, ...]
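A direct-call sketch (the shapes are placeholders; the student hidden states are assumed to be already projected to the teacher's hidden size, as the 'proj' option would do):

import torch
from textbrewer.losses import fsp_loss

state_S = [torch.rand(2, 32, 768), torch.rand(2, 32, 768)]  # two (projected) student hidden states
state_T = [torch.rand(2, 32, 768), torch.rand(2, 32, 768)]  # two teacher hidden states
mask = torch.ones(2, 32)  # (batch_size, length)
loss = fsp_loss(state_S, state_T, mask=mask)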