Commit f01e1734 authored by Gabriele Franch's avatar Gabriele Franch
Browse files

updated model save/load to include schedule sampling settings

parent 4c92f32b
......@@ -130,9 +130,12 @@ def schedule_sampling(eta: float, itr: int):
def wrapper_train(model: Model):
itr = 1
eta = args.sampling_start_value
if args.pretrained_model:
itr, eta = model.load(args.pretrained_model)
iterator = get_batcher(args)
progress_bar = tqdm(range(1, args.max_iterations + 1))
progress_bar = tqdm(range(itr, args.max_iterations + 1))
for itr in progress_bar:
imgs = next(iterator)
imgs = change_taasss_dims(imgs)
......@@ -143,7 +146,7 @@ def wrapper_train(model: Model):
progress_bar.set_description(f"Loss: {cost}")
if itr % args.test_interval == 0:
model.save(itr)
model.save(itr, eta)
print("Initializing models")
......
......@@ -39,24 +39,33 @@ class Model(object):
self.MSE_criterion = nn.MSELoss(size_average=False)
self.MAE_criterion = nn.L1Loss(size_average=False)
def save(self, itr, eta):
    """Persist a training checkpoint to /data1/IDA_LSTM_checkpoints/<itr>.pth.

    Args:
        itr: current training iteration, stored so training can resume here.
        eta: current scheduled-sampling value, stored alongside the weights.
    """
    # When wrapped in DataParallel the real weights live on .module; save
    # those so the checkpoint loads into both parallel and plain networks.
    if self.configs.is_parallel:
        model_states = self.network.module.state_dict()
    else:
        model_states = self.network.state_dict()
    save_dict = {
        'iter': itr,
        'eta': eta,
        'model_state_dict': model_states,
        'optimizer_state_dict': self.optimizer.state_dict(),
        # NOTE(review): legacy key that duplicates the weights; in the
        # parallel case its keys carry a "module." prefix, unlike
        # 'model_state_dict' above — confirm any reader still needs it.
        "net_param": self.network.state_dict()
    }
    save_path = Path("/") / "data1" / "IDA_LSTM_checkpoints"
    # parents=True so saving does not fail when /data1 itself is missing.
    save_path.mkdir(parents=True, exist_ok=True)
    save_path = save_path / f"{itr}.pth"
    torch.save(save_dict, save_path)
def load(self, path):
    """Restore a checkpoint previously written by save().

    Args:
        path: filesystem path to the .pth checkpoint file.

    Returns:
        (iter, eta): the iteration count and scheduled-sampling value
        stored at save time, so the caller can resume training from there.
    """
    # NOTE(review): assert is stripped under `python -O`, and the message
    # says "dir" although `path` is a checkpoint file — consider raising
    # FileNotFoundError instead (kept as-is to preserve catchable type).
    assert os.path.exists(path), "Weights dir does not exist"
    # map_location lets a GPU-saved checkpoint load onto the configured device.
    stats = torch.load(path, map_location=torch.device(self.configs.device))
    # Mirror save(): parallel models keep their weights on .module.
    if self.configs.is_parallel:
        self.network.module.load_state_dict(stats["model_state_dict"])
    else:
        self.network.load_state_dict(stats["model_state_dict"])
    self.optimizer.load_state_dict(stats["optimizer_state_dict"])
    print("Model loaded")
    return stats['iter'], stats['eta']
def train(self, frames, mask):
frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment