import torch.nn as nn
from ..lib.logger import Logger
from ..datasets import transforms
from .networks.factory import factory as net_factory
from .criterions.factory import factory as cri_factory
from .metrics.factory import factory as met_factory
class Model(nn.Module):
    """Model contains a network plus per-mode criterions and metrics.

    ``criterions`` and ``metrics`` are plain dicts mapping a mode name
    (e.g. ``'train'`` / ``'eval'``) to a callable. :meth:`forward` uses the
    entry matching the current ``self.mode``, if any.

    NOTE: because criterions and metrics live in plain dicts, they are not
    registered as submodules of this ``nn.Module``; only ``network`` is.

    Args:
        engine: accepted for interface compatibility with subclasses;
            not used by this class.
        cuda_tf: callable taking no arguments that returns a batch transform,
            applied by :meth:`prepare_batch` when the model is on GPU.
        detach_tf: callable taking no arguments that returns a batch
            transform, applied by :meth:`prepare_batch` in eval mode.
        network (nn.Module, optional): network fed by :meth:`forward`.
        criterions (dict, optional): mode -> criterion callable.
            Defaults to a fresh empty dict.
        metrics (dict, optional): mode -> metric callable.
            Defaults to a fresh empty dict.
    """

    def __init__(self,
                 engine=None,
                 cuda_tf=transforms.ToCuda,
                 detach_tf=transforms.ToDetach,
                 network=None,
                 criterions=None,
                 metrics=None):
        super(Model, self).__init__()
        self.cuda_tf = cuda_tf
        self.detach_tf = detach_tf
        self.network = network
        # A fresh dict per instance: a shared mutable default would leak
        # entries added to one model into every model built with defaults.
        self.criterions = {} if criterions is None else criterions
        self.metrics = {} if metrics is None else metrics
        self.is_cuda = False
        self.eval()

    def eval(self):
        """Activate evaluation mode (sets ``self.mode = 'eval'``).

        Returns:
            Model: ``self``, like :meth:`torch.nn.Module.eval`.
        """
        super(Model, self).train(mode=False)
        self.mode = 'eval'
        return self

    def train(self, mode=True):
        """Activate training mode, or evaluation mode when ``mode`` is False.

        The ``mode`` parameter mirrors :meth:`torch.nn.Module.train`, so a
        parent module calling ``child.train(mode)`` works on this model too.

        Args:
            mode (bool): True for training mode, False for evaluation mode.

        Returns:
            Model: ``self``, like :meth:`torch.nn.Module.train`.
        """
        if not mode:
            return self.eval()
        super(Model, self).train(mode=True)
        self.mode = 'train'
        return self

    def cuda(self, device_id=None):
        """Move the module's registered parameters and buffers to the GPU.

        NOTE: criterions/metrics held in the plain ``criterions`` / ``metrics``
        dicts are not registered submodules, so they are not moved here.

        Args:
            device_id (int, optional): if specified, all parameters will be
                copied to that device.

        Returns:
            Model: ``self``.
        """
        self.is_cuda = True
        return self._apply(lambda t: t.cuda(device_id))

    def cpu(self):
        """Move the module's registered parameters and buffers to the CPU.

        Returns:
            Model: ``self``.
        """
        self.is_cuda = False
        return self._apply(lambda t: t.cpu())

    def prepare_batch(self, batch):
        """Prepare a batch with ``cuda_tf`` (GPU only) and ``detach_tf``.

        ``cuda_tf`` is applied only when :meth:`cuda` was last called, and
        ``detach_tf`` only in eval mode. A fresh transform is instantiated
        on each call.

        Returns:
            The transformed batch.
        """
        if self.is_cuda:
            batch = self.cuda_tf()(batch)
        if self.mode == 'eval':
            batch = self.detach_tf()(batch)
        return batch

    def forward(self, batch):
        """Prepare the batch and feed it to the network, criterion and metric.

        The criterion for the current mode is called as
        ``criterion(net_out, batch)``. The metric is called as
        ``metric(cri_out, net_out, batch)``. A missing entry, or one that
        returns None, contributes ``{}``.

        Returns:
            out (dict): union of the network, criterion and metric outputs.
                Only outputs that are exactly ``dict`` are merged. On key
                collisions, metric overrides criterion, which overrides
                network.
        """
        batch = self.prepare_batch(batch)
        net_out = self.network(batch)

        cri_out = {}
        if self.mode in self.criterions:
            cri_tmp = self.criterions[self.mode](net_out, batch)
            if cri_tmp is not None:
                cri_out = cri_tmp

        met_out = {}
        if self.mode in self.metrics:
            met_tmp = self.metrics[self.mode](cri_out, net_out, batch)
            if met_tmp is not None:
                met_out = met_tmp

        out = {}
        # Order matters: later sources win on duplicate keys.
        for part in (net_out, cri_out, met_out):
            if type(part) is dict:
                out.update(part)
        return out

    def state_dict(self, *args, **kwgs):
        """Return a nested state: network, criterions and metrics.

        Criterions and metrics that are ``nn.Module`` instances contribute
        their own ``state_dict``; other callables are skipped. The previous
        check, ``hasattr(x, '__parameters')``, was never true for an
        ``nn.Module`` (its attribute is ``_parameters``), so their state was
        silently dropped.

        Returns:
            dict: ``{'network': ..., 'criterions': {mode: ...},
            'metrics': {mode: ...}}``.
        """
        state = {}
        state['network'] = self.network.state_dict(*args, **kwgs)
        state['criterions'] = {
            mode: criterion.state_dict(*args, **kwgs)
            for mode, criterion in self.criterions.items()
            if isinstance(criterion, nn.Module)}
        state['metrics'] = {
            mode: metric.state_dict(*args, **kwgs)
            for mode, metric in self.metrics.items()
            if isinstance(metric, nn.Module)}
        return state

    def load_state_dict(self, state, *args, **kwgs):
        """Load a state produced by :meth:`state_dict`.

        Criterion and metric entries missing from ``state`` are skipped.
        Checkpoints written before criterion/metric state was saved only
        contain empty ``'criterions'`` / ``'metrics'`` dicts, and must
        still load.
        """
        self.network.load_state_dict(state['network'], *args, **kwgs)
        for key, modules in (('criterions', self.criterions),
                             ('metrics', self.metrics)):
            saved = state.get(key, {})
            for mode, module in modules.items():
                if isinstance(module, nn.Module) and mode in saved:
                    module.load_state_dict(saved[mode], *args, **kwgs)
class DefaultModel(Model):
    """An extension of Model that builds its parts with the bootstrap factories.

    One criterion and one metric are requested per mode. Modes are the keys
    of ``engine.dataset`` when an engine is given, ``['train', 'eval']``
    otherwise. Modes for which a factory returns None are left out.
    """

    def __init__(self, engine=None,
                 cuda_tf=transforms.ToCuda,
                 detach_tf=transforms.ToDetach):
        super(DefaultModel, self).__init__(
            engine=engine,
            cuda_tf=cuda_tf,
            detach_tf=detach_tf)
        self.network = self._init_network(engine=engine)
        self.criterions = self._init_criterions(engine=engine)
        self.metrics = self._init_metrics(engine=engine)
        # The parent constructor called eval() before any network existed;
        # call it again so the freshly built network is put in eval mode too.
        self.eval()

    @staticmethod
    def _get_modes(engine=None):
        """Return the mode names for which criterions/metrics are built."""
        if engine is not None:
            return list(engine.dataset.keys())
        return ['train', 'eval']

    def _init_network(self, engine=None):
        """Create the network using the bootstrap network factory."""
        return net_factory(engine)

    def _init_criterions(self, engine=None):
        """Create one criterion per mode using the bootstrap criterion factory.

        Returns:
            dict: mode -> criterion; modes whose factory call returned None
            are omitted.
        """
        criterions = {}
        for mode in self._get_modes(engine):
            criterion = cri_factory(engine, mode)
            if criterion is not None:
                criterions[mode] = criterion
        return criterions

    def _init_metrics(self, engine=None):
        """Create one metric per mode using the bootstrap metric factory.

        Returns:
            dict: mode -> metric; modes whose factory call returned None
            are omitted.
        """
        metrics = {}
        for mode in self._get_modes(engine):
            metric = met_factory(engine, mode)
            if metric is not None:
                metrics[mode] = metric
        return metrics
class SimpleModel(DefaultModel):
    """A DefaultModel whose network is fed ``batch['data']`` only.

    Criterions and metrics still receive the whole prepared batch.
    """

    def __init__(self, engine=None,
                 cuda_tf=transforms.ToCuda,
                 detach_tf=transforms.ToDetach):
        super(SimpleModel, self).__init__(
            engine=engine,
            cuda_tf=cuda_tf,
            detach_tf=detach_tf)

    def forward(self, batch):
        """Run network, criterion and metric on a prepared batch.

        The network receives ``batch['data']``. The criterion for the
        current mode receives ``(net_out, batch)``. The metric receives
        ``(cri_out, net_out, batch)``. A missing entry, or one that returns
        None, contributes ``{}``.

        Returns:
            dict: union of the network, criterion and metric outputs that
            are exactly ``dict``. On key collisions, metric beats criterion,
            which beats network.
        """
        prepared = self.prepare_batch(batch)
        net_out = self.network(prepared['data'])

        if self.mode in self.criterions:
            cri_result = self.criterions[self.mode](net_out, prepared)
            cri_out = {} if cri_result is None else cri_result
        else:
            cri_out = {}

        if self.mode in self.metrics:
            met_result = self.metrics[self.mode](cri_out, net_out, prepared)
            met_out = {} if met_result is None else met_result
        else:
            met_out = {}

        merged = {}
        for partial in (net_out, cri_out, met_out):
            if type(partial) is dict:
                merged.update(partial)
        return merged