Commit 4c92f32b authored by Gabriele Franch's avatar Gabriele Franch
Browse files

updated model save/load to include optimizer state

parent dec5882e
......@@ -40,18 +40,22 @@ class Model(object):
self.MAE_criterion = nn.L1Loss(size_average=False)
def save(self, itr):
stats = {}
stats["net_param"] = self.network.state_dict()
state_dict = {
'iter': itr,
'model_state_dict': self.network.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
"net_param": self.network.state_dict()
}
save_path = Path("/") / "data1" / "IDA_LSTM_checkpoints"
save_path.mkdir(exist_ok=True)
save_path = save_path / f"{itr}.pth"
torch.save(stats, save_path)
# print(f"Saving model to {save_path}")
torch.save(state_dict, save_path)
def load(self, path):
assert os.path.exists(path), "Weights dir does not exist"
stats = torch.load(path, map_location=torch.device(self.configs.device))
self.network.load_state_dict(stats["net_param"])
self.network.load_state_dict(stats["model_state_dict"])
self.optimizer.load_state_dict(stats["optimizer_state_dict"])
print("Model loaded")
def train(self, frames, mask):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment