Source code for textbrewer.distiller_basic

from .distiller_utils import *

class BasicDistiller(AbstractDistiller):
    """
    Performs **single-teacher single-task** distillation, provides basic distillation strategies.

    Args:
        train_config (:class:`TrainingConfig`): training configuration.
        distill_config (:class:`DistillationConfig`): distillation configuration.
        model_T (:class:`torch.nn.Module`): teacher model.
        model_S (:class:`torch.nn.Module`): student model.
        adaptor_T (Callable): teacher model's adaptor.
        adaptor_S (Callable): student model's adaptor.

    The roles of `adaptor_T` and `adaptor_S` are explained in :py:func:`adaptor`.
    """
    def __init__(self, train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S):
        super(BasicDistiller, self).__init__(train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)

    def save_and_callback(self, global_step, step, epoch, callback):
        if self.rank != 0:
            torch.distributed.barrier()  # save and eval with single process
        else:
            logger.info(f"Saving at global step {global_step}, epoch step {step+1} epoch {epoch+1}")
            coreModel = self.model_S.module if hasattr(self.model_S, "module") else self.model_S
            state_dict = coreModel.state_dict()
            torch.save(state_dict, os.path.join(self.t_config.output_dir, f"gs{global_step}.pkl"))
            if self.local_rank == 0:
                torch.distributed.barrier()
        if callback is not None:
            logger.info("Running callback function...")
            callback(model=self.model_S, step=global_step)
            self.model_S.train()

    def write_loss(self, total_loss, writer_step, losses_dict=None):
        if self.rank == 0:
            cpu_total_loss = total_loss.cpu().item()
            self.tb_writer.add_scalar('scalar/total_loss', cpu_total_loss, writer_step)
            if losses_dict is not None:
                for name, loss in losses_dict.items():
                    cpu_loss = loss.cpu().item()
                    self.tb_writer.add_scalar(f"scalar/{name}", cpu_loss, writer_step)

    def initialize_training(self, optimizer, scheduler_class, scheduler_args, scheduler):
        # update optimizer for projection layers (used in GeneralDistiller)
        if hasattr(self, 'projs'):
            for proj, proj_group in zip(self.projs, self.projs_group):
                if proj is not None:
                    assert isinstance(proj, nn.Module)
                    optimizer.add_param_group({**{'params': proj.parameters()}, **proj_group})
        if hasattr(self, 'has_custom_matches') and self.has_custom_matches:
            for proj_func, proj_group in zip(self.custom_matches_cache['match_proj_funcs'],
                                             self.custom_matches_cache['match_proj_groups']):
                if isinstance(proj_func, nn.Module):
                    optimizer.add_param_group({**{'params': proj_func.parameters()}, **proj_group})

        logger.debug("Optimizer param group: ")
        logger.debug(f"{[[s.shape for s in g['params']] for g in optimizer.param_groups]}")

        # update scheduler
        if scheduler_class is not None:
            # overwrite scheduler
            scheduler = scheduler_class(**{'optimizer': optimizer}, **scheduler_args)

        if self.t_config.fp16:
            if not has_apex:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            if isinstance(self.model_T, (list, tuple)):
                models = [self.model_S] + list(self.model_T)
                models, optimizer = amp.initialize(models, optimizer, opt_level=self.t_config.fp16_opt_level)
                self.model_S = models[0]
                self.model_T = models[1:]
            elif isinstance(self.model_T, dict):
                tasknames, model_Ts = zip(*self.model_T.items())
                models = [self.model_S] + list(model_Ts)
                models, optimizer = amp.initialize(models, optimizer, opt_level=self.t_config.fp16_opt_level)
                self.model_S = models[0]
                self.model_T = dict(zip(tasknames, models[1:]))
            else:
                (self.model_S, self.model_T), optimizer = amp.initialize([self.model_S, self.model_T], optimizer, opt_level=self.t_config.fp16_opt_level)

        if self.local_rank != -1:
            self.model_S = torch.nn.parallel.DistributedDataParallel(self.model_S, device_ids=[self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
            if isinstance(self.model_T, (list, tuple)):
                self.model_T = [torch.nn.parallel.DistributedDataParallel(model_t, device_ids=[self.local_rank], output_device=self.local_rank, find_unused_parameters=True) for model_t in self.model_T]
            elif isinstance(self.model_T, dict):
                self.model_T = {k: torch.nn.parallel.DistributedDataParallel(v, device_ids=[self.local_rank], output_device=self.local_rank, find_unused_parameters=True) for k, v in self.model_T.items()}
            else:
                self.model_T = torch.nn.parallel.DistributedDataParallel(self.model_T, device_ids=[self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
            if hasattr(self, 'projs'):
                for i, proj in enumerate(self.projs):
                    if proj is not None:
                        assert isinstance(proj, nn.Module)
                        self.projs[i] = torch.nn.parallel.DistributedDataParallel(proj, device_ids=[self.local_rank], output_device=self.local_rank)
        elif self.t_config.data_parallel:
            self.model_S = torch.nn.DataParallel(self.model_S)
            if isinstance(self.model_T, (list, tuple)):
                self.model_T = [torch.nn.DataParallel(model_t) for model_t in self.model_T]
            elif isinstance(self.model_T, dict):
                self.model_T = {k: torch.nn.DataParallel(v) for k, v in self.model_T.items()}
            else:
                self.model_T = torch.nn.DataParallel(self.model_T)
        tqdm_disable = None if self.rank == 0 else True
        return optimizer, scheduler, tqdm_disable

    def train_with_num_steps(self, optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_steps, callback, batch_postprocessor, **args):
        if self.d_config.is_caching_logits is True:
            raise AssertionError("You cannot set is_caching_logits to True with num_steps not None!")
        total_global_steps = num_steps
        ckpt_steps = int(self.t_config.ckpt_steps)
        num_steps = int(num_steps)
        print_every = ckpt_steps // self.print_freq
        if print_every == 0:
            print_every = ckpt_steps
        checkpoints = [i * ckpt_steps for i in range(1, num_steps//ckpt_steps+1)] + [total_global_steps]
        logger.info(f"Total training steps: {total_global_steps}")
        logger.info(f"Checkpoints(step): {checkpoints}")

        global_step = 0
        writer_step = 0
        for step, batch in tqdm(enumerate(cycle(dataloader)), disable=tqdm_disable):
            if batch_postprocessor is not None:
                batch = batch_postprocessor(batch)
            total_loss, losses_dict = self.train_on_batch(batch, args)
            self.write_loss(total_loss, writer_step, losses_dict)
            writer_step += 1

            total_loss /= self.t_config.gradient_accumulation_steps
            if self.t_config.fp16:
                with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                total_loss.backward()

            if (step+1) % self.t_config.gradient_accumulation_steps == 0:
                if max_grad_norm > 0:
                    if self.t_config.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                if self.d_config.kd_loss_weight_scheduler is not None:
                    self.d_config.kd_loss_weight = \
                        self.d_config.kd_loss_weight_scheduler(global_step/total_global_steps)
                if self.d_config.hard_label_weight_scheduler is not None:
                    self.d_config.hard_label_weight = \
                        self.d_config.hard_label_weight_scheduler(global_step/total_global_steps)

                if (global_step) % print_every == 0:
                    logger.info(f"Global step: {global_step}, epoch step:{step+1}")
                if (global_step % ckpt_steps == 0) or global_step == total_global_steps:
                    self.save_and_callback(global_step, step, 0, callback)
            if global_step >= total_global_steps:
                logger.info("Training finished")
                return
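    # Note on gradient accumulation (added commentary; a sketch of the arithmetic
    # used in both train_with_num_steps and train_with_num_epochs, not original
    # code). Each forward batch contributes total_loss / gradient_accumulation_steps
    # to the gradients, and optimizer.step() runs once every
    # gradient_accumulation_steps batches. For example, with
    # gradient_accumulation_steps=4, four batches are accumulated per optimizer
    # update, giving an effective batch size of 4x the dataloader's; num_steps and
    # ckpt_steps count these optimizer updates (global steps), not forward batches.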
    def train_with_num_epochs(self, optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_epochs, callback, batch_postprocessor, **args):
        train_steps_per_epoch = len(dataloader) // self.t_config.gradient_accumulation_steps
        total_global_steps = train_steps_per_epoch * num_epochs
        print_every = train_steps_per_epoch // self.print_freq
        if print_every == 0:
            print_every = train_steps_per_epoch
        checkpoints = [int(train_steps_per_epoch*ci/self.t_config.ckpt_frequency) for ci in range(self.t_config.ckpt_frequency)]
        logger.info(f"Training steps per epoch: {train_steps_per_epoch}")
        logger.info(f"Checkpoints(step): {checkpoints}")

        global_step = 0
        writer_step = 0
        if self.d_config.is_caching_logits is True:
            logger.info("Caching batches and teacher's logits...")
            for step, batch in tqdm(enumerate(dataloader), disable=tqdm_disable):
                self.cache_logits(batch, args, batch_postprocessor)

        for current_epoch in tqdm(range(int(num_epochs)), disable=tqdm_disable):
            if self.local_rank != -1 and hasattr(dataloader, 'sampler'):
                dataloader.sampler.set_epoch(current_epoch)  # in distributed mode, calling the set_epoch method is needed to make shuffling work
            logger.info(f"Epoch {current_epoch+1}")
            optimizer.zero_grad()
            if self.d_config.is_caching_logits:
                random.shuffle(self.logits_cache)
                dataloader = self.logits_cache
            logger.info(f"Length of current epoch in forward batch: {len(dataloader)}")
            for step, batch in tqdm(enumerate(dataloader), disable=tqdm_disable):
                if self.d_config.is_caching_logits is False and batch_postprocessor is not None:
                    batch = batch_postprocessor(batch)
                total_loss, losses_dict = self.train_on_batch(batch, args)
                self.write_loss(total_loss, writer_step, losses_dict)
                writer_step += 1

                total_loss /= self.t_config.gradient_accumulation_steps
                if self.t_config.fp16:
                    with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    total_loss.backward()

                if (step+1) % self.t_config.gradient_accumulation_steps == 0:
                    if max_grad_norm > 0:
                        if self.t_config.fp16:
                            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    if self.d_config.kd_loss_weight_scheduler is not None:
                        self.d_config.kd_loss_weight = \
                            self.d_config.kd_loss_weight_scheduler(global_step/total_global_steps)
                    if self.d_config.hard_label_weight_scheduler is not None:
                        self.d_config.hard_label_weight = \
                            self.d_config.hard_label_weight_scheduler(global_step/total_global_steps)

                    if (global_step) % print_every == 0:
                        logger.info(f"Global step: {global_step}, epoch step:{step+1}")
                    if (global_step % train_steps_per_epoch in checkpoints) \
                            and ((current_epoch+1) % self.t_config.ckpt_epoch_frequency == 0 or current_epoch+1 == num_epochs):
                        self.save_and_callback(global_step, step, current_epoch, callback)

            logger.info(f"Epoch {current_epoch+1} finished")
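    # Checkpoint scheduling (added commentary): train_with_num_epochs saves
    # ckpt_frequency times per epoch at evenly spaced steps. For example, with
    # train_steps_per_epoch=1000 and ckpt_frequency=2, checkpoints=[0, 500],
    # so a save triggers when global_step % 1000 equals 500 (mid-epoch) or 0
    # (epoch boundary), provided the epoch also satisfies ckpt_epoch_frequency.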
    def train(self, optimizer, dataloader, num_epochs=None, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm=-1.0, num_steps=None, callback=None, batch_postprocessor=None, **args):
        """
        Trains the student model.

        Args:
            optimizer: optimizer.
            dataloader: dataset iterator.
            num_epochs (int): number of training epochs.
            num_steps (int): number of training steps. If it is not None, the distiller ignores `num_epochs` and trains for `num_steps`; the dataloader may then have an unknown size, i.e., no `__len__` attribute. The dataloader is cycled automatically after iterating over the whole dataset.
            callback (Callable): function called after each epoch, can be None. It is called as ``callback(model=self.model_S, step=global_step)``. It can be used to evaluate the model at each checkpoint.
            batch_postprocessor (Callable): a function for post-processing batches. It should take a batch and return a batch. Its output is fed to the models and adaptors.
            scheduler_class (class): the class of the scheduler to be constructed.
            scheduler_args (dict): arguments (excluding `optimizer`) passed to the `scheduler_class` to construct the scheduler object. See the example below.
            scheduler (deprecated): used to adjust the learning rate, optional, can be None. Deprecated in favor of `scheduler_class` and `scheduler_args`.
            max_grad_norm (float): maximum norm for the gradients (-1 means no clipping). Default: -1.0
            **args: additional arguments fed to the model.

        Note:
            * If the batch is a list or tuple, the model is called as ``model(*batch, **args)``. Make sure the order of elements in the batch matches the order of the arguments of ``model.forward``.
            * If the batch is a dict, the model is called as ``model(**batch, **args)``. Make sure the keys of the batch match the arguments of ``model.forward``.

        Note:
            If you want to provide a lr scheduler, DON'T USE `scheduler`; use `scheduler_class` and `scheduler_args` instead. Example:

            .. code-block::

                from transformers import get_linear_schedule_with_warmup
                distiller.train(optimizer, scheduler_class=get_linear_schedule_with_warmup, scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 1000})
        """
        optimizer, scheduler, tqdm_disable = self.initialize_training(optimizer, scheduler_class, scheduler_args, scheduler)

        assert not (num_epochs is None and num_steps is None)
        if num_steps is not None:
            self.train_with_num_steps(optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_steps, callback, batch_postprocessor, **args)
        else:
            self.train_with_num_epochs(optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_epochs, callback, batch_postprocessor, **args)
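    # Callback sketch (added for illustration; evaluate_on_dev is a hypothetical
    # user-supplied function). save_and_callback above invokes the callback with
    # keyword arguments, so a matching signature is:
    #
    #     def my_callback(model, step):
    #         model.eval()
    #         acc = evaluate_on_dev(model)
    #         logger.info(f"Step {step}: dev accuracy {acc:.4f}")
    #
    # The distiller switches the student back to train() mode after the callback
    # returns, so the callback is free to call model.eval().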
    def train_on_batch(self, batch, args):
        if self.d_config.is_caching_logits is False:
            (teacher_batch, results_T), (student_batch, results_S) = get_outputs_from_batch(batch, self.t_config.device, self.model_T, self.model_S, args)
            results_T = post_adaptor(self.adaptor_T(teacher_batch, results_T))
            results_S = post_adaptor(self.adaptor_S(student_batch, results_S))
        else:
            batch, cached_logits = batch
            _, (student_batch, results_S) = get_outputs_from_batch(batch, self.t_config.device, self.model_T, self.model_S, args, no_teacher_forward=True)
            results_S = post_adaptor(self.adaptor_S(student_batch, results_S))
            results_T = {'logits': [logits.to(self.t_config.device) for logits in cached_logits]}
            if 'logits_mask' in results_S:
                results_T['logits_mask'] = results_S['logits_mask']

        total_loss, losses_dict = self.compute_loss(results_S, results_T)

        return total_loss, losses_dict

    def compute_loss(self, results_S, results_T):
        total_loss = 0
        losses_dict = dict()

        logits_list_T = results_T['logits']  # list of tensor
        logits_list_S = results_S['logits']  # list of tensor

        if 'logits_mask' in results_S:
            masks_list_S = results_S['logits_mask']
            logits_list_S = select_logits_with_mask(logits_list_S, masks_list_S)  # (mask_sum, num_of_class)
        if 'logits_mask' in results_T:
            masks_list_T = results_T['logits_mask']
            logits_list_T = select_logits_with_mask(logits_list_T, masks_list_T)  # (mask_sum, num_of_class)

        total_kd_loss = 0
        if self.d_config.probability_shift is True:
            labels_list = results_S['labels']
            for l_T, l_S, labels in zip(logits_list_T, logits_list_S, labels_list):
                l_T = probability_shift_(l_T, labels)
                if self.d_config.temperature_scheduler is not None:
                    temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
                else:
                    temperature = self.d_config.temperature
                total_kd_loss += self.kd_loss(l_S, l_T, temperature)
        else:
            for l_T, l_S in zip(logits_list_T, logits_list_S):
                if self.d_config.temperature_scheduler is not None:
                    temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
                else:
                    temperature = self.d_config.temperature
                total_kd_loss += self.kd_loss(l_S, l_T, temperature)
        total_loss += total_kd_loss * self.d_config.kd_loss_weight
        losses_dict['unweighted_kd_loss'] = total_kd_loss

        if 'losses' in results_S:
            total_hl_loss = 0
            for loss in results_S['losses']:
                # in case of multi-GPU
                total_hl_loss += loss.mean()
            total_loss += total_hl_loss * self.d_config.hard_label_weight
            losses_dict['unweighted_hard_label_loss'] = total_hl_loss

        return total_loss, losses_dict

    def cache_logits(self, batch, args, batch_postprocessor):
        if batch_postprocessor is not None:
            batch = batch_postprocessor(batch)
        batch = move_to_device(batch, self.t_config.device)
        with torch.no_grad():
            if type(batch) is dict:
                results_T = self.model_T(**batch, **args)
            else:
                results_T = self.model_T(*batch, **args)
        results_T = post_adaptor(self.adaptor_T(batch, results_T))

        self.logits_cache.append([batch, [logits.to('cpu') for logits in results_T['logits']]])
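# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# The adaptor contract assumed here is the one compute_loss() reads above:
# 'logits' must be a list of tensors, and 'losses' (optional) a list of scalar
# losses used for the hard-label term; kd_loss is typically the cross-entropy
# between the temperature-scaled teacher and student distributions. The names
# teacher, student, optimizer and dataloader are hypothetical placeholders.
#
#     from textbrewer import TrainingConfig, DistillationConfig
#
#     def simple_adaptor(batch, model_outputs):
#         # suitable for models returning objects with .loss and .logits,
#         # e.g. transformers' SequenceClassifierOutput
#         return {'logits': [model_outputs.logits], 'losses': [model_outputs.loss]}
#
#     train_config = TrainingConfig(output_dir='saved_models')
#     distill_config = DistillationConfig(temperature=4, kd_loss_weight=1.0, hard_label_weight=1.0)
#     distiller = BasicDistiller(train_config, distill_config,
#                                model_T=teacher, model_S=student,
#                                adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)
#     with distiller:
#         distiller.train(optimizer, dataloader, num_epochs=3)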