from .distiller_utils import *
from .distiller_basic import BasicDistiller

class GeneralDistiller(BasicDistiller):
    """
    Supports intermediate feature matching. **Recommended for single-teacher single-task distillation**.

    Args:
        train_config (:class:`TrainingConfig`): training configuration.
        distill_config (:class:`DistillationConfig`): distillation configuration.
        model_T (:class:`torch.nn.Module`): teacher model.
        model_S (:class:`torch.nn.Module`): student model.
        adaptor_T (Callable): teacher model's adaptor.
        adaptor_S (Callable): student model's adaptor.
        custom_matches (list): supports more flexible user-defined matches (experimental).

    The roles of `adaptor_T` and `adaptor_S` are explained in :py:func:`adaptor`.
    """

    def __init__(self, train_config,
                 distill_config,
                 model_T,
                 model_S,
                 adaptor_T,
                 adaptor_S,
                 custom_matches: Optional[List[CustomMatch]] = None):
        # custom_matches = [{'module_T': module_T, 'module_S': module_S,
        #                    'loss': loss, 'weight': weight}, ...]
        super(GeneralDistiller, self).__init__(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)

        # Build one projection layer (or None) per intermediate match, mapping
        # student features to the teacher's dimension when `proj` is given.
        self.projs = []
        self.projs_group = []
        for im in self.d_config.intermediate_matches:
            if im.proj is not None:
                projection = im.proj[0]
                dim_in = im.proj[1]
                dim_out = im.proj[2]
                self.projs_group.append(im.proj[3])
                self.projs.append(PROJ_MAP[projection](dim_in, dim_out))
                self.projs[-1].to(self.t_config.device)
            else:
                self.projs.append(None)
                self.projs_group.append(None)

        self.has_custom_matches = False
        if custom_matches:
            self.handles_T = []
            self.handles_S = []
            self.custom_matches_cache = {'hook_outputs_T': [], 'hook_outputs_S': [], 'match_proj_funcs': [],
                                         'match_weights': [], 'match_losses': [], 'match_proj_groups': []}
            for match in custom_matches:
                self.add_match(match)
            self.has_custom_matches = True

        self.d_config.is_caching_logits = False
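
    # Sketch of the `im.proj` layout consumed above (an assumption based on
    # the indexing in __init__: [projection_name, dim_in, dim_out, kwargs]):
    #
    #   IntermediateMatch(layer_T=8, layer_S=2, feature='hidden',
    #                     loss='hidden_mse', weight=1,
    #                     proj=['linear', 384, 768, {'lr': 1e-3}])
    #
    # 'linear' must be a key of PROJ_MAP; the trailing dict is the
    # projection's optimizer parameter-group options (kept in self.projs_group).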

    def save_and_callback(self, global_step, step, epoch, callback):
        if self.has_custom_matches:
            # Detach the forward hooks before saving and running the callback,
            # so that forward passes triggered by the callback (e.g. evaluation)
            # do not append stale outputs to the hook caches.
            handles_T = self.model_T._forward_hooks
            handles_S = self.model_S._forward_hooks
            self.model_S._forward_hooks = OrderedDict()  # clear hooks
            self.model_T._forward_hooks = OrderedDict()

        super(GeneralDistiller, self).save_and_callback(global_step, step, epoch, callback)

        if self.has_custom_matches:
            self.model_S._forward_hooks = handles_S  # restore hooks
            self.model_T._forward_hooks = handles_T

    def train_on_batch(self, batch, args):
        (teacher_batch, results_T), (student_batch, results_S) = get_outputs_from_batch(batch, self.t_config.device, self.model_T, self.model_S, args)

        results_T = post_adaptor(self.adaptor_T(teacher_batch, results_T))
        results_S = post_adaptor(self.adaptor_S(student_batch, results_S))

        total_loss, losses_dict = self.compute_loss(results_S, results_T)

        return total_loss, losses_dict
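
    # Sketch of the adaptor contract assumed here: an adaptor maps (batch,
    # model_outputs) to a dict whose optional keys include 'logits',
    # 'logits_mask', 'hidden', 'attention', 'inputs_mask', 'labels' and
    # 'losses'; post_adaptor then normalizes single entries into lists.
    # A hypothetical HuggingFace-style student adaptor:
    #
    #   def adaptor_S(batch, model_outputs):
    #       return {'logits': model_outputs.logits,
    #               'hidden': model_outputs.hidden_states,
    #               'losses': model_outputs.loss}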

    def compute_loss(self, results_S, results_T):
        losses_dict = dict()
        total_loss = 0

        # 1. Logit (KD) loss between teacher and student output distributions.
        if 'logits' in results_T and 'logits' in results_S:
            logits_list_T = results_T['logits']  # list of tensors
            logits_list_S = results_S['logits']  # list of tensors
            total_kd_loss = 0
            if 'logits_mask' in results_S:
                masks_list_S = results_S['logits_mask']
                logits_list_S = select_logits_with_mask(logits_list_S, masks_list_S)  # (mask_sum, num_of_class)
            if 'logits_mask' in results_T:
                masks_list_T = results_T['logits_mask']
                logits_list_T = select_logits_with_mask(logits_list_T, masks_list_T)  # (mask_sum, num_of_class)

            if self.d_config.probability_shift is True:
                labels_list = results_S['labels']
                for l_T, l_S, labels in zip(logits_list_T, logits_list_S, labels_list):
                    # Shift the teacher's logits so the gold label gets the largest score.
                    l_T = probability_shift_(l_T, labels)
                    if self.d_config.temperature_scheduler is not None:
                        temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
                    else:
                        temperature = self.d_config.temperature
                    total_kd_loss += self.kd_loss(l_S, l_T, temperature)
            else:
                for l_T, l_S in zip(logits_list_T, logits_list_S):
                    if self.d_config.temperature_scheduler is not None:
                        temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
                    else:
                        temperature = self.d_config.temperature
                    total_kd_loss += self.kd_loss(l_S, l_T, temperature)
            total_loss += total_kd_loss * self.d_config.kd_loss_weight
            losses_dict['unweighted_kd_loss'] = total_kd_loss

        # 2. Intermediate feature-matching losses.
        inters_T = {feature: results_T.get(feature, []) for feature in FEATURES}
        inters_S = {feature: results_S.get(feature, []) for feature in FEATURES}
        inputs_mask_T = results_T.get('inputs_mask', None)
        inputs_mask_S = results_S.get('inputs_mask', None)
        for ith, inter_match in enumerate(self.d_config.intermediate_matches):
            layer_T = inter_match.layer_T
            layer_S = inter_match.layer_S
            feature = inter_match.feature
            loss_type = inter_match.loss
            match_weight = inter_match.weight
            match_loss = MATCH_LOSS_MAP[loss_type]

            if isinstance(layer_S, list) and isinstance(layer_T, list):
                inter_S = [inters_S[feature][s] for s in layer_S]
                inter_T = [inters_T[feature][t] for t in layer_T]
                name_S = '-'.join(map(str, layer_S))
                name_T = '-'.join(map(str, layer_T))
                if self.projs[ith]:
                    # inter_T = [self.projs[ith](t) for t in inter_T]
                    inter_S = [self.projs[ith](s) for s in inter_S]
            else:
                inter_S = inters_S[feature][layer_S]
                inter_T = inters_T[feature][layer_T]
                name_S = str(layer_S)
                name_T = str(layer_T)
                if self.projs[ith]:
                    # inter_T = self.projs[ith](inter_T)
                    inter_S = self.projs[ith](inter_S)
            intermediate_loss = match_loss(inter_S, inter_T, mask=inputs_mask_S)
            total_loss += intermediate_loss * match_weight
            losses_dict[f'unweighted_{feature}_{loss_type}_{name_S}_{name_T}'] = intermediate_loss

        # 3. User-defined custom matches computed on hooked module outputs.
        if self.has_custom_matches:
            for hook_T, hook_S, match_weight, match_loss, proj_func in \
                    zip(self.custom_matches_cache['hook_outputs_T'], self.custom_matches_cache['hook_outputs_S'],
                        self.custom_matches_cache['match_weights'], self.custom_matches_cache['match_losses'],
                        self.custom_matches_cache['match_proj_funcs']):
                if proj_func is not None:
                    hook_S = proj_func(hook_S)
                total_loss += match_weight * match_loss(hook_S, hook_T, inputs_mask_S, inputs_mask_T)
            self.custom_matches_cache['hook_outputs_T'] = []
            self.custom_matches_cache['hook_outputs_S'] = []

        # 4. Hard-label loss computed by the student itself.
        if 'losses' in results_S:
            total_hl_loss = 0
            for loss in results_S['losses']:
                total_hl_loss += loss.mean()  # .mean() in case of multi-GPU
            total_loss += total_hl_loss * self.d_config.hard_label_weight
            losses_dict['unweighted_hard_label_loss'] = total_hl_loss

        return total_loss, losses_dict
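
    # For reference, a common form of the logit loss used above is the
    # temperature-scaled soft cross-entropy; a sketch only, not necessarily
    # the exact implementation behind self.kd_loss:
    #
    #   def kd_ce_loss(logits_S, logits_T, temperature):
    #       p_T = F.softmax(logits_T / temperature, dim=-1)
    #       return -(p_T * F.log_softmax(logits_S / temperature, dim=-1)).sum(dim=-1).mean()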

    def add_match(self, match: CustomMatch):
        if isinstance(match.module_T, str) or isinstance(match.module_S, str):
            raise NotImplementedError("Matching modules by name is not supported yet; pass module objects instead.")
        else:
            module_T = match.module_T
            module_S = match.module_S
            weight = match.weight
            loss = match.loss
            proj_func = match.proj_func
            proj_group = match.proj_group
        self.add_match_by_module(module_T, module_S, proj_func, proj_group, weight, loss)

    def add_match_by_module(self, module_T: torch.nn.Module,
                            module_S: torch.nn.Module,
                            proj_func, proj_group,
                            match_weight, match_loss):
        # Append (rather than assign) the hook handles: assigning would
        # discard the handles registered for previously added matches.
        self.handles_T.append(module_T.register_forward_hook(self._hook_T))
        self.handles_S.append(module_S.register_forward_hook(self._hook_S))
        self.custom_matches_cache['match_weights'].append(match_weight)
        self.custom_matches_cache['match_losses'].append(match_loss)
        self.custom_matches_cache['match_proj_funcs'].append(proj_func)
        if isinstance(proj_func, nn.Module):
            self.custom_matches_cache['match_proj_funcs'][-1].to(self.t_config.device)
        self.custom_matches_cache['match_proj_groups'].append(proj_group)

    def _hook_T(self, module, input, output):
        self.custom_matches_cache['hook_outputs_T'].append(output)

    def _hook_S(self, module, input, output):
        self.custom_matches_cache['hook_outputs_S'].append(output)
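
# An end-to-end sketch of the custom-match path (hypothetical module paths and
# loss; assumes CustomMatch exposes the fields read in add_match above, and
# that the loss takes the four arguments passed in compute_loss):
#
#   def my_match_loss(hook_S, hook_T, mask_S, mask_T):
#       return F.mse_loss(hook_S, hook_T)
#
#   match = CustomMatch(module_T=model_T.encoder.layer[8],
#                       module_S=model_S.encoder.layer[2],
#                       weight=1.0, loss=my_match_loss,
#                       proj_func=nn.Linear(384, 768), proj_group=None)
#   distiller = GeneralDistiller(train_config, distill_config, model_T, model_S,
#                                adaptor_T, adaptor_S, custom_matches=[match])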