Commit c03b92f6 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Use Path for save_dir and specify filename for weights saving

parent 43c5b9fc
import os
from abc import ABC, abstractmethod
from collections import OrderedDict
from pathlib import Path
import torch
......@@ -21,7 +22,9 @@ class BaseModel(ABC):
In this fucntion, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define these lists:
-- self.network_names (str list): define networks used in our training.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group
them. See cycle_gan_model.py for an example.
"""
self.configuration = configuration
self.is_train = configuration['is_train']
......@@ -37,6 +40,18 @@ class BaseModel(ABC):
self.optimizers = []
self.visual_names = []
def __setattr__(self, name, value):
if name == 'save_dir':
if not isinstance(value, Path):
value = Path(value)
super().__setattr__(name, value)
if not value.exists():
os.makedirs(value)
else:
super().__setattr__(name, value)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
The implementation here is just a basic setting of input and label. You may implement
......@@ -116,13 +131,18 @@ class BaseModel(ABC):
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = {0:.7f}'.format(lr))
def save_networks(self, epoch):
def save_networks(self, epoch, filename=None):
"""Save all the networks to the disk.
"""
for name in self.network_names:
if isinstance(name, str):
if filename is None:
if filename is None:
save_filename = '{0}_net_{1}.pth'.format(epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
else:
save_filename = filename
save_path = self.save_dir / 'weights' / save_filename
net = getattr(self, name.lower())
if self.use_cuda:
......@@ -137,7 +157,7 @@ class BaseModel(ABC):
for name in self.network_names:
if isinstance(name, str):
load_filename = '{0}_net_{1}.pth'.format(epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
load_path = self.save_dir / 'weights' / load_filename
net = getattr(self, name.lower())
if isinstance(net, torch.nn.DataParallel):
net = net.module
......@@ -153,7 +173,7 @@ class BaseModel(ABC):
"""
for i, optimizer in enumerate(self.optimizers):
save_filename = '{0}_optimizer_{1}.pth'.format(epoch, i)
save_path = os.path.join(self.save_dir, save_filename)
save_path = self.save_dir / 'weights' / save_filename
torch.save(optimizer.state_dict(), save_path)
......@@ -162,7 +182,7 @@ class BaseModel(ABC):
"""
for i, optimizer in enumerate(self.optimizers):
load_filename = '{0}_optimizer_{1}.pth'.format(epoch, i)
load_path = os.path.join(self.save_dir, load_filename)
load_path = self.save_dir / 'weights' / load_filename
print('loading the optimizer from {0}'.format(load_path))
state_dict = torch.load(load_path)
if hasattr(state_dict, '_metadata'):
......
......@@ -67,6 +67,9 @@ class ClassificationTileModel(BaseModel):
self.val_labels = []
# self.val_images = []
def __setattr__(self, name, value):
super().__setattr__(name, value)
def forward(self, features_extraction_only=False):
"""Run forward pass.
"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment