import torch.nn.functional as F
import torch
from typing import List
from .compatibility import mask_dtype

def kd_mse_loss(logits_S, logits_T, temperature=1):
    '''
    Calculates the mse loss between `logits_S` and `logits_T`.

    :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
    '''
    if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
        temperature = temperature.unsqueeze(-1)  # broadcast over the num_labels dimension
    beta_logits_T = logits_T / temperature
    beta_logits_S = logits_S / temperature
    loss = F.mse_loss(beta_logits_S, beta_logits_T)
    return loss
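
# Usage sketch (illustrative only; the shapes and the temperature value are
# assumptions for demonstration, not fixed by the library):
def _demo_kd_mse_loss():
    logits_S = torch.randn(4, 10)  # student logits: (batch_size, num_labels)
    logits_T = torch.randn(4, 10)  # teacher logits: (batch_size, num_labels)
    return kd_mse_loss(logits_S, logits_T, temperature=4.0)
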
def kd_ce_loss(logits_S, logits_T, temperature=1):
    '''
    Calculates the cross-entropy loss between `logits_S` and `logits_T`.

    :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
    '''
    if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
        temperature = temperature.unsqueeze(-1)  # broadcast over the num_labels dimension
    beta_logits_T = logits_T / temperature
    beta_logits_S = logits_S / temperature
    p_T = F.softmax(beta_logits_T, dim=-1)
    loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
    return loss
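
# Usage sketch (illustrative only): per-sample temperatures are supported by
# passing a tensor of shape (batch_size,); all values here are assumptions.
def _demo_kd_ce_loss():
    logits_S = torch.randn(4, 10)        # student logits
    logits_T = torch.randn(4, 10)        # teacher logits
    temperature = torch.full((4,), 2.0)  # one temperature per sample
    return kd_ce_loss(logits_S, logits_T, temperature=temperature)
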
def att_mse_loss(attention_S, attention_T, mask=None):
    '''
    * Calculates the mse loss between `attention_S` and `attention_T`.
    * If `mask` is given, masks the positions where ``mask==0``.

    :param torch.Tensor attention_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
    :param torch.Tensor attention_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
    :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
    '''
    if mask is None:
        # Treat entries below -1e-3 (e.g. masked positions filled with large negative scores) as zero
        attention_S_select = torch.where(attention_S <= -1e-3, torch.zeros_like(attention_S), attention_S)
        attention_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), attention_T)
        loss = F.mse_loss(attention_S_select, attention_T_select)
    else:
        mask = mask.to(attention_S).unsqueeze(1).expand(-1, attention_S.size(1), -1)  # (bs, num_of_heads, len)
        valid_count = torch.pow(mask.sum(dim=2), 2).sum()
        loss = (F.mse_loss(attention_S, attention_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(2)).sum() / valid_count
    return loss
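
# Usage sketch (illustrative only; batch size, head count, and lengths are
# assumptions): matching attention matrices with padding positions masked out.
def _demo_att_mse_loss():
    att_S = torch.randn(2, 12, 8, 8)  # (batch_size, num_heads, length, length)
    att_T = torch.randn(2, 12, 8, 8)
    mask = torch.ones(2, 8)
    mask[:, 6:] = 0                   # last two positions are padding
    return att_mse_loss(att_S, att_T, mask)
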
def att_mse_sum_loss(attention_S, attention_T, mask=None):
    '''
    * Calculates the mse loss between `attention_S` and `attention_T`.
    * If the shape is (*batch_size*, *num_heads*, *length*, *length*), sums along the `num_heads` dimension and then calculates the mse loss between the two matrices.
    * If `mask` is given, masks the positions where ``mask==0``.

    :param torch.Tensor attention_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*) or (*batch_size*, *length*, *length*)
    :param torch.Tensor attention_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*) or (*batch_size*, *length*, *length*)
    :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
    '''
    if len(attention_S.size()) == 4:
        attention_T = attention_T.sum(dim=1)  # (bs, len, len)
        attention_S = attention_S.sum(dim=1)
    if mask is None:
        # Treat entries below -1e-3 (e.g. masked positions filled with large negative scores) as zero
        attention_S_select = torch.where(attention_S <= -1e-3, torch.zeros_like(attention_S), attention_S)
        attention_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), attention_T)
        loss = F.mse_loss(attention_S_select, attention_T_select)
    else:
        mask = mask.to(attention_S)
        valid_count = torch.pow(mask.sum(dim=1), 2).sum()
        loss = (F.mse_loss(attention_S, attention_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(1)).sum() / valid_count
    return loss
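
# Usage sketch (illustrative only; shapes are assumptions): 4-D inputs are
# summed over num_heads inside the function before the mse is computed.
def _demo_att_mse_sum_loss():
    att_S = torch.randn(2, 12, 8, 8)
    att_T = torch.randn(2, 12, 8, 8)
    mask = torch.ones(2, 8)
    return att_mse_sum_loss(att_S, att_T, mask)
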
def att_ce_loss(attention_S, attention_T, mask=None):
    '''
    * Calculates the cross-entropy loss between `attention_S` and `attention_T`, where softmax is applied along ``dim=-1``.
    * If `mask` is given, masks the positions where ``mask==0``.

    :param torch.Tensor attention_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
    :param torch.Tensor attention_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
    :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
    '''
    probs_T = F.softmax(attention_T, dim=-1)
    if mask is None:
        # Zero out teacher probabilities where the raw scores are large negatives (masked positions)
        probs_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), probs_T)
        loss = -((probs_T_select * F.log_softmax(attention_S, dim=-1)).sum(dim=-1)).mean()
    else:
        mask = mask.to(attention_S).unsqueeze(1).expand(-1, attention_S.size(1), -1)  # (bs, num_of_heads, len)
        loss = -((probs_T * F.log_softmax(attention_S, dim=-1) * mask.unsqueeze(2)).sum(dim=-1) * mask).sum() / mask.sum()
    return loss
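
# Usage sketch (illustrative only; shapes are assumptions): the inputs are raw
# attention scores; the softmax is applied internally along the last dimension.
def _demo_att_ce_loss():
    scores_S = torch.randn(2, 12, 8, 8)  # student attention scores
    scores_T = torch.randn(2, 12, 8, 8)  # teacher attention scores
    mask = torch.ones(2, 8)
    return att_ce_loss(scores_S, scores_T, mask)
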
def att_ce_mean_loss(attention_S, attention_T, mask=None):
    '''
    * Calculates the cross-entropy loss between `attention_S` and `attention_T`, where softmax is applied along ``dim=-1``.
    * If the shape is (*batch_size*, *num_heads*, *length*, *length*), averages over the `num_heads` dimension and then computes the cross-entropy loss between the two matrices.
    * If `mask` is given, masks the positions where ``mask==0``.

    :param torch.Tensor attention_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*) or (*batch_size*, *length*, *length*)
    :param torch.Tensor attention_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*) or (*batch_size*, *length*, *length*)
    :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
    '''
    if len(attention_S.size()) == 4:
        attention_S = attention_S.mean(dim=1)  # (bs, len, len)
        attention_T = attention_T.mean(dim=1)
    probs_T = F.softmax(attention_T, dim=-1)
    if mask is None:
        # Zero out teacher probabilities where the raw scores are large negatives (masked positions)
        probs_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), probs_T)
        loss = -((probs_T_select * F.log_softmax(attention_S, dim=-1)).sum(dim=-1)).mean()
    else:
        mask = mask.to(attention_S)
        loss = -((probs_T * F.log_softmax(attention_S, dim=-1) * mask.unsqueeze(1)).sum(dim=-1) * mask).sum() / mask.sum()
    return loss
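
# Usage sketch (illustrative only; shapes are assumptions): 4-D inputs are
# averaged over num_heads before the cross-entropy is computed.
def _demo_att_ce_mean_loss():
    scores_S = torch.randn(2, 12, 8, 8)
    scores_T = torch.randn(2, 12, 8, 8)
    mask = torch.ones(2, 8)
    return att_ce_mean_loss(scores_S, scores_T, mask)
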
def hid_mse_loss(state_S, state_T, mask=None):
    '''
    * Calculates the mse loss between `state_S` and `state_T`, which are the hidden states of the models.
    * If `mask` is given, masks the positions where ``mask==0``.
    * If the hidden sizes of student and teacher differ, the 'proj' option is required in `intermediate_matches` to match the dimensions.

    :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
    '''
    if mask is None:
        loss = F.mse_loss(state_S, state_T)
    else:
        mask = mask.to(state_S)
        valid_count = mask.sum() * state_S.size(-1)
        loss = (F.mse_loss(state_S, state_T, reduction='none') * mask.unsqueeze(-1)).sum() / valid_count
    return loss
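
# Usage sketch (illustrative only): the matching hidden sizes here are an
# assumption; otherwise a 'proj' entry in intermediate_matches bridges them.
def _demo_hid_mse_loss():
    state_S = torch.randn(2, 8, 16)  # (batch_size, length, hidden_size)
    state_T = torch.randn(2, 8, 16)
    mask = torch.ones(2, 8)
    return hid_mse_loss(state_S, state_T, mask)
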
def cos_loss(state_S, state_T, mask=None):
    '''
    * Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT, see `DistilBERT <https://arxiv.org/abs/1910.01108>`_.
    * If `mask` is given, masks the positions where ``mask==0``.
    * If the hidden sizes of student and teacher differ, the 'proj' option is required in `intermediate_matches` to match the dimensions.

    :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
    '''
    if mask is None:
        state_S = state_S.view(-1, state_S.size(-1))
        state_T = state_T.view(-1, state_T.size(-1))
    else:
        mask = mask.to(state_S).unsqueeze(-1).expand_as(state_S).to(mask_dtype)  # (bs, len, dim)
        state_S = torch.masked_select(state_S, mask).view(-1, mask.size(-1))  # (bs * select, dim)
        state_T = torch.masked_select(state_T, mask).view(-1, mask.size(-1))  # (bs * select, dim)
    target = state_S.new(state_S.size(0)).fill_(1)  # cosine-similarity target of +1 for every pair
    loss = F.cosine_embedding_loss(state_S, state_T, target, reduction='mean')
    return loss
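
# Usage sketch (illustrative only; shapes are assumptions): with a mask, only
# the unmasked positions contribute cosine-similarity pairs.
def _demo_cos_loss():
    state_S = torch.randn(2, 8, 16)
    state_T = torch.randn(2, 8, 16)
    mask = torch.ones(2, 8)
    mask[:, 6:] = 0  # last two positions are padding
    return cos_loss(state_S, state_T, mask)
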
def pkd_loss(state_S, state_T, mask=None):
    '''
    * Computes the normalized vector mse loss at position 0 along the `length` dimension. This is the loss used in BERT-PKD, see `Patient Knowledge Distillation for BERT Model Compression <https://arxiv.org/abs/1908.09355>`_.
    * If the hidden sizes of student and teacher differ, the 'proj' option is required in `intermediate_matches` to match the dimensions.
    * If the input tensors are of shape (*batch_size*, *hidden_size*), it directly computes the loss between the tensors without taking the hidden states at position 0.

    :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*) or (*batch_size*, *hidden_size*)
    :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*) or (*batch_size*, *hidden_size*)
    :param mask: not used.
    '''
    if state_T.dim() == 3:
        cls_T = state_T[:, 0]  # (batch_size, hidden_dim)
    else:
        cls_T = state_T
    if state_S.dim() == 3:
        cls_S = state_S[:, 0]  # (batch_size, hidden_dim)
    else:
        cls_S = state_S
    normed_cls_T = cls_T / torch.norm(cls_T, dim=1, keepdim=True)
    normed_cls_S = cls_S / torch.norm(cls_S, dim=1, keepdim=True)
    loss = (normed_cls_S - normed_cls_T).pow(2).sum(dim=-1).mean()
    return loss
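
# Usage sketch (illustrative only; shapes are assumptions): the loss compares
# the L2-normalized hidden states at position 0, e.g. the [CLS] vectors.
def _demo_pkd_loss():
    state_S = torch.randn(2, 8, 16)  # position 0 is taken internally
    state_T = torch.randn(2, 8, 16)
    return pkd_loss(state_S, state_T)
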
def fsp_loss(state_S, state_T, mask=None):
    '''
    * Takes in two lists of matrices `state_S` and `state_T`. Each list contains two matrices of shape (*batch_size*, *length*, *hidden_size*). Computes the similarity matrix between the two matrices in `state_S` (with the resulting shape (*batch_size*, *hidden_size*, *hidden_size*)) and the ones in `state_T` (with the resulting shape (*batch_size*, *hidden_size*, *hidden_size*)), then computes the mse loss between the similarity matrices:

    .. math::

        loss = mean((S_1^T \cdot S_2 - T_1^T \cdot T_2)^2)

    * It is a variant of the FSP loss in `A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning <http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf>`_.
    * If `mask` is given, masks the positions where ``mask==0``.
    * If the hidden sizes of student and teacher differ, the 'proj' option is required in `intermediate_matches` to match the dimensions.

    :param torch.Tensor state_S: list of two tensors, each of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: list of two tensors, each of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)

    Example in `intermediate_matches`::

        intermediate_matches = [
            {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden', 'loss':'fsp', 'weight':1, 'proj':['linear',384,768]},
            ...]
    '''
    if mask is None:
        state_S_0 = state_S[0]  # (batch_size, length, hidden_dim)
        state_S_1 = state_S[1]  # (batch_size, length, hidden_dim)
        state_T_0 = state_T[0]
        state_T_1 = state_T[1]
        gram_S = torch.bmm(state_S_0.transpose(1, 2), state_S_1) / state_S_1.size(1)  # (batch_size, hidden_dim, hidden_dim)
        gram_T = torch.bmm(state_T_0.transpose(1, 2), state_T_1) / state_T_1.size(1)
    else:
        mask = mask.to(state_S[0]).unsqueeze(-1)
        lengths = mask.sum(dim=1, keepdim=True)  # number of valid positions per sample
        state_S_0 = state_S[0] * mask
        state_S_1 = state_S[1] * mask
        state_T_0 = state_T[0] * mask
        state_T_1 = state_T[1] * mask
        gram_S = torch.bmm(state_S_0.transpose(1, 2), state_S_1) / lengths
        gram_T = torch.bmm(state_T_0.transpose(1, 2), state_T_1) / lengths
    loss = F.mse_loss(gram_S, gram_T)
    return loss
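
# Usage sketch (illustrative only): each argument is a list of two hidden-state
# tensors; the equal hidden sizes here are an assumption (otherwise use 'proj').
def _demo_fsp_loss():
    state_S = [torch.randn(2, 8, 16), torch.randn(2, 8, 16)]
    state_T = [torch.randn(2, 8, 16), torch.randn(2, 8, 16)]
    mask = torch.ones(2, 8)
    return fsp_loss(state_S, state_T, mask)
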
def mmd_loss(state_S, state_T, mask=None):
    '''
    * Takes in two lists of matrices `state_S` and `state_T`. Each list contains two matrices of shape (*batch_size*, *length*, *hidden_size*). The `hidden_size` of the matrices in `state_S` need not be the same as that of `state_T`. Computes the similarity matrix between the two matrices in `state_S` (with the resulting shape (*batch_size*, *length*, *length*)) and the ones in `state_T` (with the resulting shape (*batch_size*, *length*, *length*)), then computes the mse loss between the similarity matrices:

    .. math::

        loss = mean((S_1 \cdot S_2^T - T_1 \cdot T_2^T)^2)

    * It is a variant of the NST loss in `Like What You Like: Knowledge Distill via Neuron Selectivity Transfer <https://arxiv.org/abs/1707.01219>`_.
    * If `mask` is given, masks the positions where ``mask==0``.

    :param torch.Tensor state_S: list of two tensors, each of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: list of two tensors, each of shape (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)

    Example in `intermediate_matches`::

        intermediate_matches = [
            {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden', 'loss':'nst', 'weight':1},
            ...]
    '''
    state_S_0 = state_S[0]  # (batch_size, length, hidden_dim_S)
    state_S_1 = state_S[1]  # (batch_size, length, hidden_dim_S)
    state_T_0 = state_T[0]  # (batch_size, length, hidden_dim_T)
    state_T_1 = state_T[1]  # (batch_size, length, hidden_dim_T)
    if mask is None:
        gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2)  # (batch_size, length, length)
        gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)
        loss = F.mse_loss(gram_S, gram_T)
    else:
        mask = mask.to(state_S[0])
        valid_count = torch.pow(mask.sum(dim=1), 2).sum()
        gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2)  # (batch_size, length, length)
        gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)
        loss = (F.mse_loss(gram_S, gram_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(1)).sum() / valid_count
    return loss
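
# Usage sketch (illustrative only; shapes are assumptions): unlike fsp_loss,
# the student and teacher hidden sizes may differ here, since the gram
# matrices compared are (length x length).
def _demo_mmd_loss():
    state_S = [torch.randn(2, 8, 16), torch.randn(2, 8, 16)]  # hidden_dim_S = 16
    state_T = [torch.randn(2, 8, 32), torch.randn(2, 8, 32)]  # hidden_dim_T = 32
    mask = torch.ones(2, 8)
    return mmd_loss(state_S, state_T, mask)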