Commit 83625e31 authored by Gabriele Franch's avatar Gabriele Franch
Browse files

add print statement

parent c29bc1a2
......@@ -24,7 +24,7 @@ parser.add_argument("--n_gpu", type=int, default=0)
parser.add_argument("--dataset_name", type=str, default="radar")
parser.add_argument("--r", type=int, default=4)
parser.add_argument("--is_parallel", type=bool, default=False)
# parser.add_argument("--save_dir", type=str, default="checkpoints/inter_dst_predrnn")
parser.add_argument("--save_dir", type=str, default="/data1/IDA_LSTM_checkpoints")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=25) # 15
......@@ -139,6 +139,7 @@ def wrapper_train(model: Model):
if file.endswith(".pth"):
maxiter = max([int(file[:-4]), maxiter])
args.pretrained_model = f'{args.pretrained_model}/{maxiter}.pth'
print(f"Loading pretrained model {args.pretrained_model}")
itr, eta = model.load(args.pretrained_model)
iterator = get_batcher(args)
......
......@@ -51,7 +51,7 @@ class Model(object):
'model_state_dict': model_states,
'optimizer_state_dict': self.optimizer.state_dict(),
}
save_path = Path("/") / "data1" / "IDA_LSTM_checkpoints"
save_path = Path(self.configs.save_dir)
save_path.mkdir(exist_ok=True)
save_path = save_path / f"{itr}.pth"
torch.save(save_dict, save_path)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment