Commit e8cbaafc authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Change network attribute naming + black formatter

parent cf4a3c63
......@@ -34,7 +34,6 @@ class BaseModel(ABC):
self.optimizers = []
self.visual_names = []
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
The implementation here is just a basic setting of input and label. You may implement
......@@ -43,7 +42,6 @@ class BaseModel(ABC):
self.input = transfer_to_device(input[0], self.device)
self.label = transfer_to_device(input[1], self.device)
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
......@@ -68,9 +66,14 @@ class BaseModel(ABC):
if self.is_train:
self.load_optimizers(last_checkpoint)
for o in self.optimizers:
o.param_groups[0]['lr'] = o.param_groups[0]['initial_lr'] # reset learning rate
o.param_groups[0]['lr'] = o.param_groups[0][
'initial_lr'
] # reset learning rate
self.schedulers = [get_scheduler(optimizer, self.configuration) for optimizer in self.optimizers]
self.schedulers = [
get_scheduler(optimizer, self.configuration)
for optimizer in self.optimizers
]
if last_checkpoint > 0:
for s in self.schedulers:
......@@ -84,14 +87,14 @@ class BaseModel(ABC):
"""Make models train mode during test time."""
for name in self.network_names:
if isinstance(name, str):
net = getattr(self, name)
net = getattr(self, name.lower())
net.train()
def eval(self):
    """Put every registered network into evaluation mode.

    Iterates over ``self.network_names`` (non-string entries are skipped)
    and calls ``.eval()`` on the attribute holding each network.  The
    attribute is looked up lowercased, matching the commit's renamed
    network attributes.
    """
    for name in self.network_names:
        if isinstance(name, str):
            # Only the committed lookup survives; the stale pre-refactor
            # getattr(self, name) line from the diff rendering is dropped.
            net = getattr(self, name.lower())
            net.eval()
def test(self):
......@@ -102,7 +105,6 @@ class BaseModel(ABC):
with torch.no_grad():
self.forward()
def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
for scheduler in self.schedulers:
......@@ -111,7 +113,6 @@ class BaseModel(ABC):
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = {0:.7f}'.format(lr))
def save_networks(self, epoch):
"""Save all the networks to the disk.
"""
......@@ -119,7 +120,7 @@ class BaseModel(ABC):
if isinstance(name, str):
save_filename = '{0}_net_{1}.pth'.format(epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, name)
net = getattr(self, name.lower())
if self.use_cuda:
torch.save(net.cpu().state_dict(), save_path)
......@@ -127,7 +128,6 @@ class BaseModel(ABC):
else:
torch.save(net.cpu().state_dict(), save_path)
def load_networks(self, epoch):
"""Load all the networks from the disk.
"""
......@@ -135,7 +135,7 @@ class BaseModel(ABC):
if isinstance(name, str):
load_filename = '{0}_net_{1}.pth'.format(epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
net = getattr(self, name)
net = getattr(self, name.lower())
if isinstance(net, torch.nn.DataParallel):
net = net.module
print('loading the model from {0}'.format(load_path))
......@@ -145,7 +145,6 @@ class BaseModel(ABC):
net.load_state_dict(state_dict)
def save_optimizers(self, epoch):
"""Save all the optimizers to the disk for restarting training.
"""
......@@ -155,7 +154,6 @@ class BaseModel(ABC):
torch.save(optimizer.state_dict(), save_path)
def load_optimizers(self, epoch):
"""Load all the optimizers from the disk.
"""
......@@ -168,7 +166,6 @@ class BaseModel(ABC):
del state_dict._metadata
optimizer.load_state_dict(state_dict)
def print_networks(self):
"""Print the total number of parameters in the network and network architecture.
"""
......@@ -180,61 +177,65 @@ class BaseModel(ABC):
for param in net.parameters():
num_params += param.numel()
print(net)
print('[Network {0}] Total number of parameters : {1:.3f} M'.format(name, num_params / 1e6))
print(
'[Network {0}] Total number of parameters : {1:.3f} M'.format(
name, num_params / 1e6
)
)
def set_requires_grad(self, requires_grad=False):
    """Set ``requires_grad`` on every parameter of every registered network.

    Used to avoid unnecessary autograd computation when a network is
    frozen.  Networks are looked up by lowercased name, matching the
    commit's renamed network attributes.

    Args:
        requires_grad (bool): value assigned to each parameter's
            ``requires_grad`` flag.  Defaults to ``False``.
    """
    for name in self.network_names:
        if isinstance(name, str):
            net = getattr(self, name.lower())
            for param in net.parameters():
                param.requires_grad = requires_grad
def get_current_losses(self):
    """Return training losses/errors as a plain ``dict`` of floats.

    train.py prints these on the console.  For each string entry ``name``
    in ``self.loss_names`` the attribute ``loss_<name>`` is read and
    converted with ``float(...)``.
    """
    errors_ret = dict()
    for name in self.loss_names:
        if isinstance(name, str):
            # float(...) works for both scalar tensor and float number
            errors_ret[name] = float(getattr(self, 'loss_' + name))
    return errors_ret
def pre_epoch_callback(self, epoch):
    """Hook invoked before each training epoch; no-op in the base class.

    Subclasses may override to run per-epoch setup.
    """
def post_epoch_callback(self, epoch, visualizer):
    """Hook invoked after each training epoch; no-op in the base class.

    Subclasses may override to log or visualize per-epoch results.
    """
def get_hyperparam_result(self):
    """Return the final training result for hyperparameter tuning
    (e.g. best validation loss).

    Base implementation yields ``None``; subclasses should override.
    """
    return None
def export(self):
    """Export all the networks of the model using JIT tracing.

    Requires that ``self.input`` is set (e.g. via ``set_input``).  Each
    network is traced with the current input and saved to
    ``<export_path>/exported_net_<name>.pth``.  Networks are looked up by
    lowercased name, matching the commit's renamed network attributes.
    """
    for name in self.network_names:
        if isinstance(name, str):
            net = getattr(self, name.lower())
            export_path = os.path.join(
                self.configuration['export_path'],
                'exported_net_{}.pth'.format(name),
            )
            # we have to modify the input for tracing
            if isinstance(self.input, list):
                self.input = [tuple(self.input)]
            traced_script_module = torch.jit.trace(net, self.input)
            traced_script_module.save(export_path)
def get_current_visuals(self):
    """Return visualization images keyed by name.

    train.py displays these.  An ``OrderedDict`` keeps a stable display
    order; each attribute is looked up by lowercased name, matching the
    commit's renamed attributes, while the original name is kept as the
    dictionary key.
    """
    visual_ret = OrderedDict()
    for name in self.visual_names:
        if isinstance(name, str):
            visual_ret[name] = getattr(self, name.lower())
    return visual_ret
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment