Intermediate Losses

Here we list the definitions of the pre-defined intermediate losses. Usually, users do not need to call these functions directly; instead, they are referred to by their names in MATCH_LOSS_MAP.

attention_mse

textbrewer.losses.att_mse_loss(attention_S, attention_T, mask=None)[source]
  • Calculates the MSE loss between attention_S and attention_T.

  • If the inputs_mask is given, masks the positions where inputs_mask==0.

Parameters
  • attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)

  • attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)
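
A minimal PyTorch sketch of the computation described above (illustrative only; the exact masking and normalization used by the library may differ):

import torch.nn.functional as F

def att_mse_loss_sketch(attention_S, attention_T, mask=None):
    # attention_S, attention_T: (batch_size, num_heads, length, length)
    if mask is None:
        return F.mse_loss(attention_S, attention_T)
    # Mask out attention entries whose query or key position is padding,
    # then average the squared error over the remaining entries.
    mask = mask.to(attention_S)                                    # (batch_size, length)
    valid = (mask.unsqueeze(1) * mask.unsqueeze(2)).unsqueeze(1)   # (batch_size, 1, length, length)
    sq_err = (attention_S - attention_T) ** 2 * valid
    return sq_err.sum() / (valid.sum() * attention_S.size(1))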

attention_mse_sum

textbrewer.losses.att_mse_sum_loss(attention_S, attention_T, mask=None)[source]
  • Calculates the MSE loss between attention_S and attention_T.

  • If the shape is (batch_size, num_heads, length, length), sums along the num_heads dimension and then calculates the MSE loss between the two matrices.

  • If the inputs_mask is given, masks the positions where inputs_mask==0.

Parameters
  • attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)

  • attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)
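
A minimal sketch of the sum-over-heads variant (illustrative only; masking details may differ from the actual implementation):

import torch.nn.functional as F

def att_mse_sum_loss_sketch(attention_S, attention_T, mask=None):
    # Sum away the num_heads dimension if it is present.
    if attention_S.dim() == 4:
        attention_S = attention_S.sum(dim=1)   # (batch_size, length, length)
        attention_T = attention_T.sum(dim=1)
    if mask is None:
        return F.mse_loss(attention_S, attention_T)
    mask = mask.to(attention_S)
    valid = mask.unsqueeze(1) * mask.unsqueeze(2)   # (batch_size, length, length)
    sq_err = (attention_S - attention_T) ** 2 * valid
    return sq_err.sum() / valid.sum()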

attention_ce

textbrewer.losses.att_ce_loss(attention_S, attention_T, mask=None)[source]
  • Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied along dim=-1.

  • If the inputs_mask is given, masks the positions where inputs_mask==0.

Parameters
  • attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)

  • attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)
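
A minimal sketch of the cross-entropy computation (illustrative; here only the query positions are masked, whereas the actual implementation may also handle the key positions inside the softmax):

import torch.nn.functional as F

def att_ce_loss_sketch(attention_S, attention_T, mask=None):
    # attention_S, attention_T: (batch_size, num_heads, length, length) attention scores
    probs_T = F.softmax(attention_T, dim=-1)
    log_probs_S = F.log_softmax(attention_S, dim=-1)
    ce = -(probs_T * log_probs_S).sum(dim=-1)   # (batch_size, num_heads, length)
    if mask is None:
        return ce.mean()
    mask = mask.to(ce).unsqueeze(1)             # broadcast over num_heads
    return (ce * mask).sum() / (mask.sum() * ce.size(1))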

attention_ce_mean

textbrewer.losses.att_ce_mean_loss(attention_S, attention_T, mask=None)[source]
  • Calculates the cross-entropy loss between attention_S and attention_T, where softmax is applied along dim=-1.

  • If the shape is (batch_size, num_heads, length, length), averages over the num_heads dimension and then computes the cross-entropy loss between the two matrices.

  • If the inputs_mask is given, masks the positions where inputs_mask==0.

Parameters
  • attention_S (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)

  • attention_T (torch.Tensor) – tensor of shape (batch_size, num_heads, length, length) or (batch_size, length, length)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)
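
A minimal sketch of the mean-over-heads variant (illustrative only):

import torch.nn.functional as F

def att_ce_mean_loss_sketch(attention_S, attention_T, mask=None):
    # Average away the num_heads dimension if it is present.
    if attention_S.dim() == 4:
        attention_S = attention_S.mean(dim=1)   # (batch_size, length, length)
        attention_T = attention_T.mean(dim=1)
    probs_T = F.softmax(attention_T, dim=-1)
    log_probs_S = F.log_softmax(attention_S, dim=-1)
    ce = -(probs_T * log_probs_S).sum(dim=-1)   # (batch_size, length)
    if mask is None:
        return ce.mean()
    mask = mask.to(ce)
    return (ce * mask).sum() / mask.sum()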

hidden_mse

textbrewer.losses.hid_mse_loss(state_S, state_T, mask=None)[source]
  • Calculates the MSE loss between state_S and state_T, which are the hidden states of the models.

  • If the inputs_mask is given, masks the positions where inputs_mask==0.

  • If the hidden sizes of the student and teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.

Parameters
  • state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)

  • state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)
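
A minimal sketch of the masked MSE over hidden states (illustrative; assumes state_S and state_T already have matching hidden sizes, e.g. after any 'proj'):

import torch.nn.functional as F

def hid_mse_loss_sketch(state_S, state_T, mask=None):
    # state_S, state_T: (batch_size, length, hidden_size)
    if mask is None:
        return F.mse_loss(state_S, state_T)
    mask = mask.to(state_S).unsqueeze(-1)       # (batch_size, length, 1)
    sq_err = (state_S - state_T) ** 2 * mask
    return sq_err.sum() / (mask.sum() * state_S.size(-1))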

cos

textbrewer.losses.cos_loss(state_S, state_T, mask=None)[source]
  • Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT; see the DistilBERT paper.

  • If the inputs_mask is given, masks the positions where inputs_mask==0.

  • If the hidden sizes of the student and teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.

Parameters
  • state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)

  • state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)
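
A minimal sketch of the cosine similarity loss (illustrative; it pushes the cosine similarity of each pair of token vectors toward 1, as in DistilBERT):

import torch.nn.functional as F

def cos_loss_sketch(state_S, state_T, mask=None):
    # state_S, state_T: (batch_size, length, hidden_size)
    hidden_size = state_S.size(-1)
    s = state_S.reshape(-1, hidden_size)
    t = state_T.reshape(-1, hidden_size)
    if mask is not None:
        keep = mask.reshape(-1).bool()          # drop padded positions
        s, t = s[keep], t[keep]
    target = s.new_ones(s.size(0))              # target cosine similarity of 1
    return F.cosine_embedding_loss(s, t, target)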

pkd

textbrewer.losses.pkd_loss(state_S, state_T, mask=None)[source]
  • Computes the normalized vector MSE loss on the hidden states at position 0 along the length dimension. This is the loss used in BERT-PKD; see Patient Knowledge Distillation for BERT Model Compression.

  • If the hidden sizes of the student and teacher are different, the 'proj' option is required in intermediate_matches to match the dimensions.

  • If the input tensors are of shape (batch_size, hidden_size), the loss is computed on the tensors directly, without taking the hidden states at position 0.

Parameters
  • state_S (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)

  • state_T (torch.Tensor) – tensor of shape (batch_size, length, hidden_size) or (batch_size, hidden_size)

  • mask – not used.
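
A minimal sketch of the PKD loss (illustrative; the exact reduction over the hidden dimension may differ from the actual implementation):

import torch.nn.functional as F

def pkd_loss_sketch(state_S, state_T, mask=None):
    # Take the hidden states at position 0 (e.g. the [CLS] token) unless the
    # inputs are already of shape (batch_size, hidden_size).
    if state_S.dim() == 3:
        state_S = state_S[:, 0]
        state_T = state_T[:, 0]
    # L2-normalize each vector, then compute the squared distance between them.
    norm_S = F.normalize(state_S, p=2, dim=-1)
    norm_T = F.normalize(state_T, p=2, dim=-1)
    return ((norm_S - norm_T) ** 2).sum(dim=-1).mean()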

nst (mmd)

textbrewer.losses.mmd_loss(state_S, state_T, mask=None)[source]
  • Takes in two lists of matrices state_S and state_T. Each list contains two matrices of shape (batch_size, length, hidden_size). The hidden_size of the matrices in state_S does not need to be the same as that of state_T. Computes the similarity matrix between the two matrices in state_S (with the resulting shape (batch_size, length, length)) and the one between the two matrices in state_T (with the resulting shape (batch_size, length, length)), then computes the MSE loss between the two similarity matrices:

\[loss = mean((S_{1} \cdot S_{2}^T - T_{1} \cdot T_{2}^T)^2)\]
Parameters
  • state_S (torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)

  • state_T (torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)

Example in intermediate_matches:

intermediate_matches = [
    {'layer_T': [0,0], 'layer_S': [0,0], 'feature': 'hidden', 'loss': 'nst', 'weight': 1},
    ...]
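
A minimal sketch of the computation (illustrative; any normalization of the similarity matrices in the actual implementation is omitted here):

import torch

def mmd_loss_sketch(state_S, state_T, mask=None):
    # state_S = [S1, S2], state_T = [T1, T2]; each tensor has shape
    # (batch_size, length, hidden_size); the hidden sizes of S and T may differ.
    S1, S2 = state_S
    T1, T2 = state_T
    gram_S = torch.bmm(S1, S2.transpose(1, 2))   # (batch_size, length, length)
    gram_T = torch.bmm(T1, T2.transpose(1, 2))
    sq_err = (gram_S - gram_T) ** 2
    if mask is None:
        return sq_err.mean()
    mask = mask.to(sq_err)
    valid = mask.unsqueeze(1) * mask.unsqueeze(2)
    return (sq_err * valid).sum() / valid.sum()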

fsp (gram)

textbrewer.losses.fsp_loss(state_S, state_T, mask=None)[source]
  • Takes in two lists of matrices state_S and state_T. Each list contains two matrices of shape (batch_size, length, hidden_size). Computes the similarity matrix between the two matrices in state_S (with the resulting shape (batch_size, hidden_size, hidden_size)) and the one between the two matrices in state_T (with the resulting shape (batch_size, hidden_size, hidden_size)), then computes the MSE loss between the two similarity matrices:

\[loss = mean((S_{1}^T \cdot S_{2} - T_{1}^T \cdot T_{2})^2)\]
Parameters
  • state_S (torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)

  • state_T (torch.Tensor) – list of two tensors, each of shape (batch_size, length, hidden_size)

  • mask (torch.Tensor) – tensor of shape (batch_size, length)

Example in intermediate_matches:

intermediate_matches = [
    {'layer_T': [0,0], 'layer_S': [0,0], 'feature': 'hidden', 'loss': 'fsp', 'weight': 1, 'proj': ['linear', 384, 768]},
    ...]
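
A minimal sketch of the computation (illustrative; assumes the student and teacher hidden sizes already match, e.g. via the 'proj' in the example above):

import torch

def fsp_loss_sketch(state_S, state_T, mask=None):
    # state_S = [S1, S2], state_T = [T1, T2]; each tensor has shape
    # (batch_size, length, hidden_size).
    S1, S2 = state_S
    T1, T2 = state_T
    if mask is not None:
        m = mask.to(S1).unsqueeze(-1)            # zero out padded positions
        S1, S2, T1, T2 = S1 * m, S2 * m, T1 * m, T2 * m
    gram_S = torch.bmm(S1.transpose(1, 2), S2)   # (batch_size, hidden_size, hidden_size)
    gram_T = torch.bmm(T1.transpose(1, 2), T2)
    return ((gram_S - gram_T) ** 2).mean()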