import os
import math
import time
import torch
import datetime
import threading
from ..lib import utils
from ..lib.options import Options
from ..lib.logger import Logger

class Engine(object):
    """Contains training and evaluation procedures.

    The engine owns the training loop, the evaluation loop, the checkpoint
    loading/saving logic and a small hook system that lets callers run
    callbacks at fixed points of those loops.

    The attributes `dataset`, `model`, `optimizer` and `view` are initialised
    to None here and are expected to be assigned by the caller before
    `train`, `eval` or `resume` is used.
    """

    def __init__(self):
        self.hooks = {}  # hook name -> list of callbacks taking no argument
        self.epoch = 0
        self.dataset = None
        self.model = None
        self.optimizer = None
        self.view = None
        self.best_out = {}  # criteria name -> best value seen so far

        # generate_view will be executed at the end of each
        # training and evaluation epoch
        self.register_hook('train_on_flush', self.generate_view)
        self.register_hook('eval_on_flush', self.generate_view)

    def generate_view(self):
        """ Generate a view.html via an asynchronous call to `self.view.generate()`

        Does nothing when no view has been assigned to the engine.
        """
        if self.view is not None:
            threading.Thread(target=self.view.generate).start()
        # path_opts = os.path.join(Options()['exp']['dir'], 'options.yaml')
        # os.system('python -m bootstrap.views.view --path_opts {}'.format(path_opts))

    def load_state_dict(self, state):
        """ Restore the engine state produced by `state_dict`

        Args:
            state(dict): must contain the keys 'epoch' and 'best_out'
        """
        self.epoch = state['epoch']
        self.best_out = state['best_out']

    def state_dict(self):
        """ Return the engine state to be stored in a checkpoint

        Returns:
            state(dict): the current epoch under 'epoch' and the best
                values per saving criteria under 'best_out'
        """
        return {
            'epoch': self.epoch,
            'best_out': self.best_out,
        }

    def hook(self, name):
        """ Run all the callback functions that have been registered
            for a hook.

            Args:
                name: the name of the hook
        """
        # Unknown hook names are silently ignored.
        for func in self.hooks.get(name, ()):
            func()

    def register_hook(self, name, func):
        """ Register a callback function to be triggered when the hook
            is called.

            Args:
                name: the name of the hook
                func: the callback function (no argument)

            Example usage:

            .. code-block:: python

                def func():
                    print('hooked!')

                engine.register_hook('train_on_start_batch', func)
        """
        self.hooks.setdefault(name, []).append(func)

    def resume(self):
        """ Resume a checkpoint using the `bootstrap.lib.options.Options`
        """
        Logger()('Loading {} checkpoint'.format(Options()['exp']['resume']))
        self.load(Options()['exp']['dir'],
                  Options()['exp']['resume'],
                  self.model, self.optimizer)
        # The checkpoint stores the last epoch that was run, so training
        # restarts from the following one.
        self.epoch += 1

    def eval(self):
        """ Launch evaluation procedures
        """
        Logger()('Launching evaluation procedures')

        if Options()['dataset']['eval_split']:
            # self.epoch-1 to be equal to the same resumed epoch
            # or to be equal to -1 when not resumed
            self.eval_epoch(self.model, self.dataset['eval'], self.epoch - 1, logs_json=True)

        Logger()('Ending evaluation procedures')

    def train(self):
        """ Launch training procedures

            List of the hooks:

            - train_on_start: before the full training procedure

        """
        Logger()('Launching training procedures')

        self.hook('train_on_start')
        while self.epoch < Options()['engine']['nb_epochs']:
            self.train_epoch(self.model, self.dataset['train'], self.optimizer, self.epoch)

            if Options()['dataset']['eval_split']:
                out = self.eval_epoch(self.model, self.dataset['eval'], self.epoch)

                if 'saving_criteria' in Options()['engine'] and Options()['engine']['saving_criteria'] is not None:
                    for saving_criteria in Options()['engine']['saving_criteria']:
                        if self.is_best(out, saving_criteria):
                            name = saving_criteria.split(':')[0]
                            Logger()('Saving best checkpoint for strategy {}'.format(name))
                            self.save(Options()['exp']['dir'], 'best_{}'.format(name), self.model, self.optimizer)

            Logger()('Saving last checkpoint')
            self.save(Options()['exp']['dir'], 'last', self.model, self.optimizer)
            self.epoch += 1

        Logger()('Ending training procedures')

    @staticmethod
    def _scalar_outputs(out):
        """ Select the entries of a model output that can be logged as scalars

            Single-element tensors of dimension 0 or 1 are converted to
            Python numbers. Other tensors, lists and dicts are dropped.
            Every other value is kept unchanged.

            Args:
                out(dict): outputs of the model for one batch

            Returns:
                scalars(dict): the retained entries, in the order of `out`
        """
        scalars = {}
        for key, value in out.items():
            if torch.is_tensor(value):
                # `.item()` raises on tensors holding more than one element,
                # so those are skipped like the higher dimensional ones.
                if value.dim() > 1 or value.numel() != 1:
                    continue
                value = value.item()  # get number from a torch scalar
            elif isinstance(value, (list, dict)):
                continue
            scalars[key] = value
        return scalars

    def train_epoch(self, model, dataset, optimizer, epoch, mode='train'):
        """ Launch training procedures for one epoch

            List of the hooks:

            - train_on_start_epoch: before the training procedure for an epoch
            - train_on_start_batch: before the training procedure for a batch
            - train_on_forward: after the forward of the model
            - train_on_backward: after the backward of the loss
            - train_on_update: after the optimization step
            - train_on_print: after the print to the terminal
            - train_on_end_batch: end of the training procedure for a batch
            - train_on_end_epoch: before saving the logs in logs.json
            - train_on_flush: end of the training procedure for an epoch

            Args:
                model: callable on a batch, returns a dict with a 'loss' entry
                dataset: provides `split` and `make_batch_loader()`
                optimizer: optimizer stepped once per batch
                epoch(int): index of the epoch, also added to the random seed
                mode(str): prefix used in the lines printed to the terminal
        """
        utils.set_random_seed(Options()['misc']['seed'] + epoch)  # to be able to reproduce exps on reload
        Logger()('Training model on {}set for epoch {}'.format(dataset.split, epoch))
        model.train()

        timer = {
            'begin': time.time(),
            'elapsed': time.time(),
            'process': None,
            'load': None,
            'run_avg': 0
        }
        out_epoch = {}
        batch_loader = dataset.make_batch_loader()

        self.hook('train_on_start_epoch')
        for i, batch in enumerate(batch_loader):
            timer['load'] = time.time() - timer['elapsed']
            self.hook('train_on_start_batch')

            optimizer.zero_grad()
            out = model(batch)
            self.hook('train_on_forward')

            out['loss'].backward()
            #torch.cuda.synchronize()
            self.hook('train_on_backward')

            optimizer.step()
            #torch.cuda.synchronize()
            self.hook('train_on_update')

            # Exponential moving average of the time spent per batch, used
            # to estimate the time left in the epoch.
            timer['process'] = time.time() - timer['elapsed']
            if i == 0:
                timer['run_avg'] = timer['process']
            else:
                timer['run_avg'] = timer['run_avg'] * 0.8 + timer['process'] * 0.2

            Logger().log_value('train_batch.epoch', epoch, should_print=False)
            Logger().log_value('train_batch.batch', i, should_print=False)
            Logger().log_value('train_batch.timer.process', timer['process'], should_print=False)
            Logger().log_value('train_batch.timer.load', timer['load'], should_print=False)

            for key, value in self._scalar_outputs(out).items():
                out_epoch.setdefault(key, []).append(value)
                Logger().log_value('train_batch.' + key, value, should_print=False)

            if i % Options()['engine']['print_freq'] == 0:
                Logger()("{}: epoch {} | batch {}/{}".format(mode, epoch, i, len(batch_loader) - 1))
                Logger()("{} elapsed: {} | left: {}".format(
                    ' ' * len(mode),
                    datetime.timedelta(seconds=math.floor(time.time() - timer['begin'])),
                    datetime.timedelta(seconds=math.floor(timer['run_avg'] * (len(batch_loader) - 1 - i)))))
                Logger()("{} process: {:.5f} | load: {:.5f}".format(' ' * len(mode), timer['process'], timer['load']))
                Logger()("{} loss: {:.5f}".format(' ' * len(mode), out['loss'].data.item()))
                self.hook('train_on_print')

            timer['elapsed'] = time.time()
            self.hook('train_on_end_batch')

            # In debug mode only the first batches of the epoch are run.
            if Options()['engine']['debug']:
                if i > 2:
                    break

        Logger().log_value('train_epoch.epoch', epoch, should_print=True)
        for key, value in out_epoch.items():
            Logger().log_value('train_epoch.' + key, sum(value) / len(value), should_print=True)

        self.hook('train_on_end_epoch')
        Logger().flush()
        self.hook('train_on_flush')

    def eval_epoch(self, model, dataset, epoch, mode='eval', logs_json=True):
        """ Launch evaluation procedures for one epoch

            List of the hooks (``mode='eval'`` by default):

            - mode_on_start_epoch: before the evaluation procedure for an epoch
            - mode_on_start_batch: before the evaluation procedure for a batch
            - mode_on_forward: after the forward of the model
            - mode_on_print: after the print to the terminal
            - mode_on_end_batch: end of the evaluation procedure for a batch
            - mode_on_end_epoch: before saving the logs in logs.json
            - mode_on_flush: end of the evaluation procedure for an epoch

            Args:
                model: callable on a batch, returns a dict of outputs
                dataset: provides `split` and `make_batch_loader()`
                epoch(int): index of the epoch, also added to the random seed
                mode(str): prefix of the hook names and of the logged values
                logs_json(bool): whether the logger is flushed at the end

            Returns:
                out(dict): mean of all the scalar outputs of the model, indexed by output name, for this epoch
        """
        utils.set_random_seed(Options()['misc']['seed'] + epoch)  # to be able to reproduce exps on reload
        Logger()('Evaluating model on {}set for epoch {}'.format(dataset.split, epoch))
        model.eval()

        timer = {
            'begin': time.time(),
            'elapsed': time.time(),
            'process': None,
            'load': None,
            'run_avg': 0
        }
        out_epoch = {}
        batch_loader = dataset.make_batch_loader()

        self.hook('{}_on_start_epoch'.format(mode))
        for i, batch in enumerate(batch_loader):
            timer['load'] = time.time() - timer['elapsed']
            self.hook('{}_on_start_batch'.format(mode))

            with torch.no_grad():
                out = model(batch)
            #torch.cuda.synchronize()
            self.hook('{}_on_forward'.format(mode))

            # Exponential moving average of the time spent per batch, used
            # to estimate the time left in the epoch.
            timer['process'] = time.time() - timer['elapsed']
            if i == 0:
                timer['run_avg'] = timer['process']
            else:
                timer['run_avg'] = timer['run_avg'] * 0.8 + timer['process'] * 0.2

            Logger().log_value('{}_batch.batch'.format(mode), i, should_print=False)
            Logger().log_value('{}_batch.epoch'.format(mode), epoch, should_print=False)
            Logger().log_value('{}_batch.timer.process'.format(mode), timer['process'], should_print=False)
            Logger().log_value('{}_batch.timer.load'.format(mode), timer['load'], should_print=False)

            for key, value in self._scalar_outputs(out).items():
                out_epoch.setdefault(key, []).append(value)
                Logger().log_value('{}_batch.{}'.format(mode, key), value, should_print=False)

            if i % Options()['engine']['print_freq'] == 0:
                Logger()("{}: epoch {} | batch {}/{}".format(mode, epoch, i, len(batch_loader) - 1))
                Logger()("{} elapsed: {} | left: {}".format(
                    ' ' * len(mode),
                    datetime.timedelta(seconds=math.floor(time.time() - timer['begin'])),
                    datetime.timedelta(seconds=math.floor(timer['run_avg'] * (len(batch_loader) - 1 - i)))))
                Logger()("{} process: {:.5f} | load: {:.5f}".format(' ' * len(mode), timer['process'], timer['load']))
                self.hook('{}_on_print'.format(mode))

            timer['elapsed'] = time.time()
            self.hook('{}_on_end_batch'.format(mode))

            # In debug mode only the first batches of the epoch are run.
            if Options()['engine']['debug']:
                if i > 10:
                    break

        out = {}
        for key, value in out_epoch.items():
            try:
                out[key] = sum(value) / len(value)
            except TypeError:
                # The values of this output cannot be averaged (e.g. strings):
                # it is left out of the epoch summary instead of stopping
                # the evaluation.
                Logger()("Cannot average output '{}' over the epoch, skipping it".format(key),
                         log_level=Logger.WARNING)

        Logger().log_value('{}_epoch.epoch'.format(mode), epoch, should_print=True)
        for key, value in out.items():
            Logger().log_value('{}_epoch.{}'.format(mode, key), value, should_print=True)

        self.hook('{}_on_end_epoch'.format(mode))
        if logs_json:
            Logger().flush()

        self.hook('{}_on_flush'.format(mode))
        return out

    def is_best(self, out, saving_criteria):
        """ Verify if the last model is the best for a specific saving criteria

            The first value seen for a criteria is recorded and counts as
            the best. Later values must be strictly better than the
            recorded one.

            Args:
                out(dict): mean of all the scalar outputs of model indexed by output name
                saving_criteria(str): name of an output followed by ':min' or ':max'

            Returns:
                is_best(bool)

            Raises:
                ValueError: if `saving_criteria` contains neither ':min' nor ':max'
                KeyError: if the criteria name is not a key of `out`

            Example usage:

            .. code-block:: python

                out = {
                    'loss': 0.2,
                    'acctop1': 87.02
                }
                engine.is_best(out, 'loss:min')
        """
        if ':min' in saving_criteria:
            name = saving_criteria.replace(':min', '')
            minimize = True
        elif ':max' in saving_criteria:
            name = saving_criteria.replace(':max', '')
            minimize = False
        else:
            error_msg = ("'--engine.saving_criteria' named '{}' does not specify order, "
                         "you need to chose between '{}' or '{}' to specify if the criteria "
                         "needs to be minimize or maximize").format(
                             saving_criteria, saving_criteria + ':min', saving_criteria + ':max')
            raise ValueError(error_msg)

        if name not in out:
            raise KeyError("'--engine.saving_criteria' named '{}' not in outputs '{}'".format(name, list(out.keys())))

        current = out[name]
        if name in self.best_out:
            best = self.best_out[name]
            # Direct comparison instead of eval() on a formatted string,
            # which fails on values such as inf or nan.
            improved = current < best if minimize else current > best
            if not improved:
                return False

        self.best_out[name] = current
        return True

    def load(self, dir_logs, name, model, optimizer):
        """ Load a checkpoint

            The optimizer state is only looked up when the option
            `dataset.train_split` is set. Missing optimizer or engine
            files are reported with a warning.

            Args:
                dir_logs: directory of the checkpoint
                name: name of the checkpoint
                model: model associated to the checkpoint
                optimizer: optimizer associated to the checkpoint
        """
        path_template = os.path.join(dir_logs, 'ckpt_{}_{}.pth.tar')

        Logger()('Loading model...')
        model_state = torch.load(path_template.format(name, 'model'))
        model.load_state_dict(model_state)

        if Options()['dataset']['train_split'] is not None:
            if os.path.isfile(path_template.format(name, 'optimizer')):
                Logger()('Loading optimizer...')
                optimizer_state = torch.load(path_template.format(name, 'optimizer'))
                optimizer.load_state_dict(optimizer_state)
            else:
                Logger()('No optimizer checkpoint', log_level=Logger.WARNING)

        if os.path.isfile(path_template.format(name, 'engine')):
            Logger()('Loading engine...')
            engine_state = torch.load(path_template.format(name, 'engine'))
            self.load_state_dict(engine_state)
        else:
            Logger()('No engine checkpoint', log_level=Logger.WARNING)

    def save(self, dir_logs, name, model, optimizer):
        """ Save a checkpoint

            Three files are written in `dir_logs`, one for each of the
            model, the optimizer and the engine.

            Args:
                dir_logs: directory of the checkpoint
                name: name of the checkpoint
                model: model associated to the checkpoint
                optimizer: optimizer associated to the checkpoint
        """
        path_template = os.path.join(dir_logs, 'ckpt_{}_{}.pth.tar')

        Logger()('Saving model...')
        model_state = model.state_dict()
        torch.save(model_state, path_template.format(name, 'model'))

        Logger()('Saving optimizer...')
        optimizer_state = optimizer.state_dict()
        torch.save(optimizer_state, path_template.format(name, 'optimizer'))

        Logger()('Saving engine...')
        engine_state = self.state_dict()
        torch.save(engine_state, path_template.format(name, 'engine'))