from .distiller_utils import *
class BasicTrainer:
"""
It performs supervised training, not distillation. It can be used for training the teacher model.
Args:
train_config (:class:`TrainingConfig`): training configuration.
model (:class:`torch.nn.Module`): model to be trained.
        adaptor (Callable): adaptor of the model.
The role of `adaptor` is explained in :py:func:`adaptor`.
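
    Example (a minimal usage sketch; ``my_model``, ``my_adaptor`` and ``my_dataloader`` are
    placeholder names supplied by the caller, not objects provided by this module)::

        train_config = TrainingConfig(output_dir='saved_models')
        trainer = BasicTrainer(train_config, my_model, my_adaptor)
        optimizer = torch.optim.AdamW(my_model.parameters(), lr=2e-5)
        with trainer:
            trainer.train(optimizer, my_dataloader, num_epochs=3)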
"""
def __enter__(self):
self.model_is_training = self.model.training
self.model.train()
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore model status
self.model.train(self.model_is_training)
def __init__(self, train_config: TrainingConfig,
model: torch.nn.Module, adaptor):
super(BasicTrainer, self).__init__()
self.t_config = train_config
self.model = model
self.adaptor = adaptor
self.local_rank = self.t_config.local_rank
self.rank = 0
if self.local_rank != -1:
self.rank = torch.distributed.get_rank()
if self.t_config.log_dir is not None and self.rank == 0:
self.tb_writer = SummaryWriter(log_dir = self.t_config.log_dir)
else:
self.tb_writer = no_op
self.print_freq = 20
    def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, scheduler_args=None, scheduler=None, max_grad_norm = -1.0, num_steps=None, callback=None, batch_postprocessor=None, **args):
"""
        Trains the model. See :meth:`BasicDistiller.train` for the meaning of the arguments.
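
        Example (illustrative; ``optimizer`` and ``dataloader`` are assumed to be built by the
        caller, and ``get_linear_schedule_with_warmup`` from ``transformers`` is just one possible
        ``scheduler_class``)::

            from transformers import get_linear_schedule_with_warmup
            trainer.train(optimizer, dataloader, num_epochs=3,
                          scheduler_class=get_linear_schedule_with_warmup,
                          scheduler_args={'num_warmup_steps': 100, 'num_training_steps': 3000},
                          max_grad_norm=1.0)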
"""
        # Build the scheduler from scheduler_class if given; this overrides any
        # `scheduler` instance passed in.
        if scheduler_class is not None:
            scheduler = scheduler_class(optimizer=optimizer, **scheduler_args)
if self.t_config.fp16:
if not has_apex:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
self.model, optimizer = amp.initialize(self.model, optimizer, opt_level=self.t_config.fp16_opt_level)
        # Multi-GPU / distributed training setup
if self.local_rank != -1:
self.model = torch.nn.parallel.DistributedDataParallel(self.model,
device_ids = [self.local_rank], output_device = self.local_rank,
find_unused_parameters = True)
elif self.t_config.data_parallel: # exclusive with DDP
self.model = torch.nn.DataParallel(self.model)
tqdm_disable = None if self.rank == 0 else True
if num_steps is not None:
total_global_steps = num_steps
ckpt_steps = int(self.t_config.ckpt_steps)
num_steps = int(num_steps)
print_every = ckpt_steps // self.print_freq
if print_every == 0:
print_every = ckpt_steps
checkpoints = [ i * ckpt_steps for i in range(1,num_steps//ckpt_steps+1)] + [total_global_steps]
logger.info(f"Total training steps: {total_global_steps}")
logger.info(f"Checkpoints: {checkpoints}")
global_step = 0
writer_step = 0
for step, batch in tqdm(enumerate(cycle(dataloader)),disable=tqdm_disable):
if batch_postprocessor is not None:
batch = batch_postprocessor(batch)
total_loss = self.train_on_batch(batch,args)
total_loss /= self.t_config.gradient_accumulation_steps
if self.t_config.fp16:
with amp.scale_loss(total_loss,optimizer) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()
if self.rank == 0:
scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
writer_step += 1
if (step+1)%self.t_config.gradient_accumulation_steps == 0:
if max_grad_norm > 0:
if self.t_config.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
optimizer.step()
if scheduler is not None:
scheduler.step()
optimizer.zero_grad()
global_step += 1
                        if global_step % print_every == 0:
                            logger.info(f"Global step: {global_step}, epoch step: {step+1}")
if (global_step%ckpt_steps==0) or global_step==total_global_steps:
if self.rank != 0:
torch.distributed.barrier() # save and eval with single process
else:
logger.info(f"Saving at global step {global_step}")
coreModel = self.model.module if hasattr(self.model, "module") else self.model
state_dict = coreModel.state_dict()
torch.save(state_dict, os.path.join(self.t_config.output_dir,f"gs{global_step}.pkl"))
if self.local_rank == 0: # DDP is enabled
torch.distributed.barrier()
if callback is not None:
logger.info("Running callback function...")
callback(model=self.model, step=global_step)
self.model.train()
if global_step >= total_global_steps:
logger.info("Training finished")
return
train_steps_per_epoch = len(dataloader)//self.t_config.gradient_accumulation_steps
print_every = train_steps_per_epoch // self.print_freq
if print_every == 0:
print_every = train_steps_per_epoch
checkpoints = [int(train_steps_per_epoch*ci/self.t_config.ckpt_frequency) for ci in range(self.t_config.ckpt_frequency)]
logger.info(f"Training steps per epoch: {train_steps_per_epoch}")
logger.info(f"Checkpoints(step): {checkpoints}")
global_step = 0
writer_step = 0
for current_epoch in tqdm(range(int(num_epochs)),disable=tqdm_disable):
if self.local_rank != -1 and hasattr(dataloader,'sampler'):
                dataloader.sampler.set_epoch(current_epoch)  # In distributed mode, set_epoch must be called on the sampler so that shuffling differs across epochs
logger.info(f"Epoch {current_epoch+1}")
optimizer.zero_grad()
logger.info(f"Length of current epoch in forward batch: {len(dataloader)}")
for step, batch in tqdm(enumerate(dataloader),disable=tqdm_disable):
if batch_postprocessor is not None:
batch = batch_postprocessor(batch)
total_loss = self.train_on_batch(batch,args)
total_loss /= self.t_config.gradient_accumulation_steps
if self.t_config.fp16:
with amp.scale_loss(total_loss,optimizer) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()
if self.rank == 0:
scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
writer_step += 1
if (step+1)%self.t_config.gradient_accumulation_steps == 0:
if max_grad_norm > 0:
if self.t_config.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
optimizer.step()
if scheduler is not None:
scheduler.step()
optimizer.zero_grad()
global_step += 1
                        if global_step % print_every == 0:
                            logger.info(f"Global step: {global_step}, epoch step: {step+1}")
if (global_step%train_steps_per_epoch in checkpoints) \
and ((current_epoch+1)%self.t_config.ckpt_epoch_frequency==0 or current_epoch+1==num_epochs):
if self.rank != 0:
torch.distributed.barrier() # save and eval with single process
else:
logger.info(f"Saving at global step {global_step}, epoch step {step+1} epoch {current_epoch+1}")
coreModel = self.model.module if hasattr(self.model, "module") else self.model
state_dict = coreModel.state_dict()
torch.save(state_dict, os.path.join(self.t_config.output_dir,f"gs{global_step}.pkl"))
if self.local_rank == 0: # DDP is enabled
torch.distributed.barrier()
if callback is not None:
logger.info("Running callback function...")
callback(model=self.model, step=global_step)
self.model.train()
logger.info(f"Epoch {current_epoch+1} finished")
def train_on_batch(self, batch, args) -> torch.Tensor:
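        """
        Runs one forward pass: moves ``batch`` to the configured device, feeds it to the model,
        passes ``(batch, results)`` through the adaptor, and returns the sum of the losses
        listed under the adaptor's ``'losses'`` key.
        """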
batch = move_to_device(batch, self.t_config.device)
if type(batch) is dict:
results = self.model(**batch,**args)
else:
results = self.model(*batch, **args)
results = post_adaptor(self.adaptor(batch,results))
total_loss = 0
if 'losses' not in results:
raise KeyError("'losses' not in the output of adaptor. Nothing to optimize!")
else:
for loss in results['losses']:
                # .mean() reduces the per-GPU loss vector returned under DataParallel
total_loss += loss.mean()
return total_loss
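

# --- Illustrative sketch (not part of the trainer) ---------------------------
# `BasicTrainer.train_on_batch` expects the adaptor to return a dict with a
# 'losses' entry: a list of loss tensors, each averaged and then summed into the
# total training loss. For a model whose forward pass returns an object with a
# `.loss` attribute (as many Hugging Face models do when labels are provided),
# a hypothetical adaptor could look like this:
def example_adaptor(batch, model_outputs):
    # Report a single training loss; multiple losses may be listed and will be summed.
    return {'losses': [model_outputs.loss]}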