Commit 104bdaa0 authored by Gabriele Franch's avatar Gabriele Franch
Browse files

Move testing on cpu

parent 2cf9d0e9
......@@ -17,6 +17,7 @@ parser = argparse.ArgumentParser(
parser.add_argument("--is_training", type=int, default=1)
parser.add_argument("--device", type=str, default="cuda") # cuda
parser.add_argument("--n_gpu", type=int, default=0)
# data
parser.add_argument("--dataset_name", type=str, default="radar")
......@@ -53,7 +54,6 @@ parser.add_argument("--max_iterations", type=int, default=80000)
parser.add_argument("--test_interval", type=int, default=5000) # 5000
# parser.add_argument("--snapshot_interval", type=int, default=5000)
parser.add_argument("--num_save_samples", type=int, default=10)
parser.add_argument("--n_gpu", type=int, default=0)
args = parser.parse_args()
batch_size = args.batch_size
......
......@@ -2,27 +2,13 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"source": [
"import os\n",
"\n",
"import argparse\n",
"import math\n",
"import shutil\n",
"\n",
"import numpy as np\n",
"\n",
"import core.trainer as trainer\n",
"from core.models.model_factory import Model\n",
"from core.utils import preprocess\n",
"from data_provider.CIKM.data_iterator import clean_fold, sample, imsave\n",
"from core.utils.util import nor, de_nor\n",
"from data_provider.CIKM.taasss import infinite_batcher\n",
"from pathlib import Path\n",
"import h5py\n",
"import cv2\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from matplotlib import pyplot as plt\n",
"from cikm_inter_dst_predrnn_run_taasss_utils import (\n",
" padding_taasss,\n",
......@@ -35,7 +21,8 @@
")\n",
"\n",
"parser.add_argument(\"--is_training\", type=int, default=1)\n",
"parser.add_argument(\"--device\", type=str, default=\"cuda\") # cuda\n",
"parser.add_argument(\"--device\", type=str, default=\"cpu\") # cuda\n",
"parser.add_argument(\"--n_gpu\", type=int, default=0)\n",
"\n",
"# data\n",
"parser.add_argument(\"--dataset_name\", type=str, default=\"radar\")\n",
......@@ -44,8 +31,8 @@
"parser.add_argument(\"--save_dir\", type=str, default=\"checkpoints/inter_dst_predrnn\")\n",
"parser.add_argument(\"--gen_frm_dir\", type=str, default=\"dataset_generated/\")\n",
"parser.add_argument(\"--input_length\", type=int, default=5)\n",
"parser.add_argument(\"--total_length\", type=int, default=25) # 15\n",
"parser.add_argument(\"--img_width\", type=int, default=512) # 512\n",
"parser.add_argument(\"--total_length\", type=int, default=25)\n",
"parser.add_argument(\"--img_width\", type=int, default=512)\n",
"parser.add_argument(\"--img_channel\", type=int, default=1)\n",
"\n",
"# model\n",
......@@ -66,13 +53,12 @@
"# optimization\n",
"parser.add_argument(\"--lr\", type=float, default=0.001)\n",
"parser.add_argument(\"--reverse_input\", type=int, default=1)\n",
"parser.add_argument(\"--batch_size\", type=int, default=1) # 4\n",
"parser.add_argument(\"--max_iterations\", type=int, default=80000)\n",
"parser.add_argument(\"--display_interval\", type=int, default=1) # 200\n",
"parser.add_argument(\"--test_interval\", type=int, default=1) # 2000\n",
"parser.add_argument(\"--snapshot_interval\", type=int, default=5000)\n",
"parser.add_argument(\"--num_save_samples\", type=int, default=10)\n",
"parser.add_argument(\"--n_gpu\", type=int, default=0)\n",
"parser.add_argument(\"--batch_size\", type=int, default=1)\n",
"#parser.add_argument(\"--max_iterations\", type=int, default=80000)\n",
"#parser.add_argument(\"--display_interval\", type=int, default=1)\n",
"#parser.add_argument(\"--test_interval\", type=int, default=1)\n",
"#parser.add_argument(\"--snapshot_interval\", type=int, default=5000)\n",
"#parser.add_argument(\"--num_save_samples\", type=int, default=10)\n",
"\n",
"args, unknown = parser.parse_known_args()\n",
"batch_size = args.batch_size\n",
......@@ -94,8 +80,7 @@
"output_type": "stream",
"name": "stdout",
"text": [
"Initializing models\n",
"Model loaded\n"
"Initializing models\n"
]
},
{
......@@ -105,16 +90,22 @@
"/home/meteotn/IDA_LSTM/.venv/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
" warnings.warn(warning.format(ret))\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model loaded\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"source": [
"def wrapper_test(model: Model):\n",
" test_save_root = args.gen_frm_dir\n",
" loss = 0\n",
" real_input_flag = np.zeros(\n",
" (\n",
......@@ -128,7 +119,7 @@
" output_length = args.total_length - args.input_length\n",
" iterator = get_batcher(args)\n",
" steps = 2\n",
" for index in range(steps):\n",
" for _ in range(steps):\n",
" dat = next(iterator)\n",
" # (25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)\n",
" dat = np.squeeze(dat)\n",
......@@ -148,7 +139,6 @@
" # Should already by 0 to 1\n",
" # img_out = de_nor(img_out)\n",
" loss += mse\n",
" bat_ind = 0\n",
"\n",
" img_out = img_out.squeeze()\n",
" dat = dat.squeeze()\n",
......@@ -165,7 +155,7 @@
" return loss / steps\n",
"\n",
"\n",
"wrapper_test(model)\n"
"wrapper_test(model)"
],
"outputs": [
{
......@@ -183,11 +173,11 @@
"output_type": "execute_result",
"data": {
"text/plain": [
"0.00010389353701611981"
"0.00010389316958026029"
]
},
"metadata": {},
"execution_count": 2
"execution_count": 3
},
{
"output_type": "display_data",
......
%% Cell type:code id: tags:
```
import os
import argparse
import math
import shutil
import numpy as np
import core.trainer as trainer
from core.models.model_factory import Model
from core.utils import preprocess
from data_provider.CIKM.data_iterator import clean_fold, sample, imsave
from core.utils.util import nor, de_nor
from data_provider.CIKM.taasss import infinite_batcher
from pathlib import Path
import h5py
import cv2
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
from cikm_inter_dst_predrnn_run_taasss_utils import (
padding_taasss,
unpadding_taasss,
get_batcher,
)
parser = argparse.ArgumentParser(
description="PyTorch video prediction model - DST PredRNN"
)
parser.add_argument("--is_training", type=int, default=1)
parser.add_argument("--device", type=str, default="cuda") # cuda
parser.add_argument("--device", type=str, default="cpu") # cuda
parser.add_argument("--n_gpu", type=int, default=0)
# data
parser.add_argument("--dataset_name", type=str, default="radar")
parser.add_argument("--r", type=int, default=4)
parser.add_argument("--is_parallel", type=bool, default=False)
parser.add_argument("--save_dir", type=str, default="checkpoints/inter_dst_predrnn")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=25) # 15
parser.add_argument("--img_width", type=int, default=512) # 512
parser.add_argument("--total_length", type=int, default=25)
parser.add_argument("--img_width", type=int, default=512)
parser.add_argument("--img_channel", type=int, default=1)
# model
parser.add_argument("--model_name", type=str, default="interact_dst_predrnn")
parser.add_argument("--pretrained_model", type=str, default="")
parser.add_argument("--num_hidden", type=str, default="64,64,64,64")
parser.add_argument("--filter_size", type=int, default=5)
parser.add_argument("--stride", type=int, default=1)
# parser.add_argument("--patch_size", type=int, default=4)
parser.add_argument("--layer_norm", type=int, default=1)
# scheduled sampling
parser.add_argument("--scheduled_sampling", type=int, default=1)
parser.add_argument("--sampling_stop_iter", type=int, default=50000)
parser.add_argument("--sampling_start_value", type=float, default=1.0)
parser.add_argument("--sampling_changing_rate", type=float, default=0.00002)
# optimization
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--reverse_input", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1) # 4
parser.add_argument("--max_iterations", type=int, default=80000)
parser.add_argument("--display_interval", type=int, default=1) # 200
parser.add_argument("--test_interval", type=int, default=1) # 2000
parser.add_argument("--snapshot_interval", type=int, default=5000)
parser.add_argument("--num_save_samples", type=int, default=10)
parser.add_argument("--n_gpu", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=1)
#parser.add_argument("--max_iterations", type=int, default=80000)
#parser.add_argument("--display_interval", type=int, default=1)
#parser.add_argument("--test_interval", type=int, default=1)
#parser.add_argument("--snapshot_interval", type=int, default=5000)
#parser.add_argument("--num_save_samples", type=int, default=10)
args, unknown = parser.parse_known_args()
batch_size = args.batch_size
if args.img_width == 128:
args.patch_size = 4
elif args.img_width == 256:
args.patch_size = 8
elif args.img_width == 512:
args.patch_size = 16
print("Initializing models")
model = Model(args)
model_path = Path("/") / "data1" / "IDA_LSTM_checkpoints" / "80000.pth"
model.load(model_path)
```
%%%% Output: stream
Initializing models
Model loaded
%%%% Output: stream
/home/meteotn/IDA_LSTM/.venv/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
%%%% Output: stream
Model loaded
%% Cell type:code id: tags:
```
def wrapper_test(model: Model):
test_save_root = args.gen_frm_dir
loss = 0
real_input_flag = np.zeros(
(
args.batch_size,
args.total_length - args.input_length - 1,
args.img_width // args.patch_size,
args.img_width // args.patch_size,
args.patch_size ** 2 * args.img_channel,
)
)
output_length = args.total_length - args.input_length
iterator = get_batcher(args)
steps = 2
for index in range(steps):
for _ in range(steps):
dat = next(iterator)
# (25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)
dat = np.squeeze(dat)
dat = np.expand_dims(dat, axis=0)
dat = np.expand_dims(dat, axis=4)
# Should already by 0 to 1
# dat = nor(dat)
tars = dat[:, -output_length:]
ims = padding_taasss(dat, args)
ims = preprocess.reshape_patch(ims, args.patch_size)
ims = ims.astype(np.float64)
img_gen, _ = model.test(ims, real_input_flag)
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
img_out = unpadding_taasss(img_gen[:, -output_length:], args)
mse = np.mean(np.square(tars - img_out))
# Should already by 0 to 1
# img_out = de_nor(img_out)
loss += mse
bat_ind = 0
img_out = img_out.squeeze()
dat = dat.squeeze()
fig, axs = plt.subplots(1, 10, figsize=(20, 2))
for i in range(5):
# 5 frames of input
axs[i].set_title(f"Input {i}")
axs[i].imshow(dat[i])
for i in range(5, 10):
# First 5 frames of output
axs[i].set_title(f"Output {i-5}")
axs[i].imshow(img_out[i])
fig.show()
return loss / steps
wrapper_test(model)
```
%%%% Output: stream
start_datetime 2016-12-11 00:00:00
end_datetime 2016-12-11 23:55:00
run_length 288
avg_cell_value 0.221938
Name: 3126, dtype: object
%%%% Output: execute_result
0.00010389353701611981
0.00010389316958026029
%%%% Output: display_data
%%%% Output: display_data
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment