Intermediate Losses¶
This page lists the definitions of the pre-defined intermediate losses. Usually, users do not need to call these functions directly; instead, they refer to them by their names in MATCH_LOSS_MAP.
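For example, to use one of these losses, give its name in an intermediate_matches entry of the distillation configuration (a minimal illustration following the intermediate_matches examples shown later on this page):
intermediate_matches = [ {'layer_T':0, 'layer_S':0, 'feature':'attention', 'loss': 'attention_mse', 'weight': 1}, ...]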
attention_mse¶
- textbrewer.losses.att_mse_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the MSE loss between attention_S and attention_T. If mask is given, masks the positions where mask==0.
- Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
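A minimal sketch of the computation described above (not the library's exact implementation; the reduction and masking details may differ):
import torch
import torch.nn.functional as F

def att_mse_loss_sketch(attention_S, attention_T, mask=None):
    # attention_S, attention_T: (batch_size, num_heads, length, length)
    if mask is None:
        return F.mse_loss(attention_S, attention_T)
    mask = mask.to(attention_S)                          # (batch_size, length)
    # zero out rows and columns at positions where mask==0
    valid = mask.unsqueeze(1).unsqueeze(-1) * mask.unsqueeze(1).unsqueeze(-2)
    se = (attention_S - attention_T).pow(2) * valid      # broadcast over heads
    return se.sum() / (valid.sum() * attention_S.size(1))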
attention_mse_sum¶
- textbrewer.losses.att_mse_sum_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the MSE loss between attention_S and attention_T. If the shape is (batch_size, num_heads, length, length), sums along the num_heads dimension and then calculates the MSE loss between the two matrices. If mask is given, masks the positions where mask==0.
- Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
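A minimal sketch of the described computation (the library's exact reduction and masking may differ):
import torch
import torch.nn.functional as F

def att_mse_sum_loss_sketch(attention_S, attention_T, mask=None):
    # sum over the num_heads dimension if the inputs are 4-D
    if attention_S.dim() == 4:
        attention_S = attention_S.sum(dim=1)             # (batch_size, length, length)
    if attention_T.dim() == 4:
        attention_T = attention_T.sum(dim=1)
    if mask is None:
        return F.mse_loss(attention_S, attention_T)
    mask = mask.to(attention_S)
    valid = mask.unsqueeze(-1) * mask.unsqueeze(-2)      # (batch_size, length, length)
    return ((attention_S - attention_T).pow(2) * valid).sum() / valid.sum()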
attention_ce¶
- textbrewer.losses.att_ce_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied on dim=-1. If mask is given, masks the positions where mask==0.
- Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
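A minimal sketch of the described computation (for simplicity it masks only query positions; the library may also handle masked key positions inside the softmax):
import torch
import torch.nn.functional as F

def att_ce_loss_sketch(attention_S, attention_T, mask=None):
    # teacher probabilities and student log-probabilities over dim=-1
    probs_T = F.softmax(attention_T, dim=-1)
    logprobs = F.log_softmax(attention_S, dim=-1)
    ce = -(probs_T * logprobs).sum(dim=-1)               # (batch_size, num_heads, length)
    if mask is None:
        return ce.mean()
    mask = mask.to(attention_S).unsqueeze(1)             # (batch_size, 1, length)
    return (ce * mask).sum() / (mask.sum() * attention_S.size(1))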
attention_ce_mean¶
- textbrewer.losses.att_ce_mean_loss(attention_S, attention_T, mask=None)[source]¶
Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied on dim=-1. If the shape is (batch_size, num_heads, length, length), averages over the num_heads dimension and then computes the cross-entropy loss between the two matrices. If mask is given, masks the positions where mask==0.
- Parameters
attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)
mask (torch.Tensor) – tensor of shape (batch_size, length)
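A minimal sketch, assuming the scores are averaged over num_heads before the softmax (the library's order of operations may differ):
import torch
import torch.nn.functional as F

def att_ce_mean_loss_sketch(attention_S, attention_T, mask=None):
    if attention_S.dim() == 4:
        attention_S = attention_S.mean(dim=1)            # (batch_size, length, length)
    if attention_T.dim() == 4:
        attention_T = attention_T.mean(dim=1)
    probs_T = F.softmax(attention_T, dim=-1)
    logprobs = F.log_softmax(attention_S, dim=-1)
    ce = -(probs_T * logprobs).sum(dim=-1)               # (batch_size, length)
    if mask is None:
        return ce.mean()
    mask = mask.to(attention_S)
    return (ce * mask).sum() / mask.sum()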
cos¶
- textbrewer.losses.cos_loss(state_S, state_T, mask=None)[source]¶
Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT (see DistilBERT). If mask is given, masks the positions where mask==0. If the hidden sizes of the student and teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.
- Parameters
state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)
state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of shape (batch_size, length)
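A minimal sketch using PyTorch's cosine-embedding loss with target 1 (the library's implementation details may differ):
import torch
import torch.nn.functional as F

def cos_loss_sketch(state_S, state_T, mask=None):
    # select valid positions, then flatten to (num_valid, hidden_size)
    if mask is None:
        state_S = state_S.reshape(-1, state_S.size(-1))
        state_T = state_T.reshape(-1, state_T.size(-1))
    else:
        sel = mask.to(torch.bool)                        # (batch_size, length)
        state_S = state_S[sel]
        state_T = state_T[sel]
    target = state_S.new_ones(state_S.size(0))           # pull cosine similarity toward 1
    return F.cosine_embedding_loss(state_S, state_T, target)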
pkd¶
- textbrewer.losses.pkd_loss(state_S, state_T, mask=None)[source]¶
Computes the MSE loss between the normalized hidden states at position 0 along the length dimension. This is the loss used in BERT-PKD; see Patient Knowledge Distillation for BERT Model Compression. If the hidden sizes of the student and teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions. If the input tensors are of shape (batch_size, hidden_size), the loss is computed on the tensors directly, without taking the hidden states at position 0.
- Parameters
state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)
state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)
mask – not used.
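A minimal sketch of the described computation (not the library's exact code):
import torch
import torch.nn.functional as F

def pkd_loss_sketch(state_S, state_T, mask=None):
    # take the hidden state at position 0 (e.g. [CLS]) if a length dimension exists
    if state_S.dim() == 3:
        state_S = state_S[:, 0]                          # (batch_size, hidden_size)
        state_T = state_T[:, 0]
    # L2-normalize each vector, then take the MSE between the normalized vectors
    state_S = F.normalize(state_S, p=2, dim=-1)
    state_T = F.normalize(state_T, p=2, dim=-1)
    return F.mse_loss(state_S, state_T)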
nst (mmd)¶
- textbrewer.losses.mmd_loss(state_S, state_T, mask=None)[source]¶
Takes in two lists of matrices, state_S and state_T. Each list contains two matrices of shape (batch_size, length, hidden_size). The hidden_size of the matrices in state_S does not need to be the same as that of state_T. Computes the similarity matrix of the two matrices in state_S (with the resulting shape (batch_size, length, length)) and that of the two matrices in state_T (with the resulting shape (batch_size, length, length)), then computes the MSE loss between the two similarity matrices:
\[loss = mean((S_{1} \cdot S_{2}^T - T_{1} \cdot T_{2}^T)^2)\]
It is a variant of the NST loss in Like What You Like: Knowledge Distill via Neuron Selectivity Transfer. If mask is given, masks the positions where mask==0.
- Parameters
state_S (list of torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)
state_T (list of torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of shape (batch_size, length)
Example in intermediate_matches:
intermediate_matches = [ {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'nst', 'weight' : 1}, ...]
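A minimal sketch of the formula above (any normalization the library applies is omitted):
import torch
import torch.nn.functional as F

def mmd_loss_sketch(state_S, state_T, mask=None):
    S1, S2 = state_S                                     # each (batch_size, length, hidden_size_S)
    T1, T2 = state_T                                     # each (batch_size, length, hidden_size_T)
    sim_S = torch.bmm(S1, S2.transpose(1, 2))            # (batch_size, length, length)
    sim_T = torch.bmm(T1, T2.transpose(1, 2))
    if mask is None:
        return F.mse_loss(sim_S, sim_T)
    mask = mask.to(sim_S)
    valid = mask.unsqueeze(-1) * mask.unsqueeze(-2)      # (batch_size, length, length)
    return ((sim_S - sim_T).pow(2) * valid).sum() / valid.sum()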
fsp (gram)¶
- textbrewer.losses.fsp_loss(state_S, state_T, mask=None)[source]¶
Takes in two lists of matrices, state_S and state_T. Each list contains two matrices of shape (batch_size, length, hidden_size). Computes the similarity matrix of the two matrices in state_S (with the resulting shape (batch_size, hidden_size, hidden_size)) and that of the two matrices in state_T (with the resulting shape (batch_size, hidden_size, hidden_size)), then computes the MSE loss between the two similarity matrices:
\[loss = mean((S_{1}^T \cdot S_{2} - T_{1}^T \cdot T_{2})^2)\]
It is a variant of the FSP loss in A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning. If mask is given, masks the positions where mask==0. If the hidden sizes of the student and teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.
- Parameters
state_S (list of torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)
state_T (list of torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)
mask (torch.Tensor) – tensor of shape (batch_size, length)
Example in intermediate_matches:
intermediate_matches = [ {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'fsp', 'weight' : 1, 'proj':['linear',384,768]}, ...]
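A minimal sketch of the formula above (masking is applied by zeroing padded positions; the library's normalization may differ):
import torch
import torch.nn.functional as F

def fsp_loss_sketch(state_S, state_T, mask=None):
    S1, S2 = state_S                                     # each (batch_size, length, hidden_size)
    T1, T2 = state_T                                     # dimensions matched to S via 'proj'
    if mask is not None:
        m = mask.to(S1).unsqueeze(-1)                    # (batch_size, length, 1)
        S1, S2, T1, T2 = S1 * m, S2 * m, T1 * m, T2 * m
    gram_S = torch.bmm(S1.transpose(1, 2), S2)           # (batch_size, hidden_size, hidden_size)
    gram_T = torch.bmm(T1.transpose(1, 2), T2)
    return F.mse_loss(gram_S, gram_T)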