Commit 67aae446 authored by Marco Di Francesco

Add singularity support to docker file

parent ad017f30
.git
.venv
-dataset
\ No newline at end of file
+dataset/
+# Singularity file
+ida-lstm
@@ -10,3 +10,6 @@ dataset_generated/
__pycache__/
.pyc
.ipynb_checkpoints
+# Singularity file
+ida-lstm
\ No newline at end of file
FROM nvcr.io/nvidia/pytorch:21.05-py3
-WORKDIR /app
+# Workdir not used by singularity
+# WORKDIR /app
# Install opencv dependencies
ENV DEBIAN_FRONTEND noninteractive
RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6
COPY requirements.txt requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
-COPY . .
-CMD python cikm_inter_dst_predrnn_run_taasss.py
\ No newline at end of file
+COPY . /app
+CMD python /app/cikm_inter_dst_predrnn_run_taasss.py
\ No newline at end of file
# IDA-LSTM
This is a PyTorch implementation of IDA-LSTM, a recurrent model for radar echo extrapolation (precipitation nowcasting), as described in the following paper:
A Novel LSTM Model with Interaction Dual Attention for Radar Echo Extrapolation, by Chuyao Luo, Xutao Li, Yongliang Wen, Yunming Ye, Xiaofeng Zhang.
# Setup
Required Python libraries: torch (>=1.3.0), opencv, numpy, scipy (==1.0.0), and jpype1.
Tested on Ubuntu with an NVIDIA Titan GPU and CUDA (>=10.0).
# Datasets
We conduct experiments on the CIKM AnalytiCup 2017 dataset: [CIKM_AnalytiCup_Address](https://tianchi.aliyun.com/competition/entrance/231596/information) or [CIKM_Radar](https://drive.google.com/drive/folders/1IqQyI8hTtsBbrZRRht3Es9eES_S4Qv2Y?usp=sharing)
# Training
Use the '.py' scripts to train these models. To train the proposed model on the radar data, simply run cikm_inter_dst_predrnn_run.py or cikm_dst_predrnn_run.py.
To change parameters and settings, edit the `args` variable in each model's script.
The preprocessing method and data root path can be modified in the data/data_iterator.py file.
All trained models can be downloaded here: [trained model](https://drive.google.com/file/d/1pnTSDoaKuKouu7y_j-QTq8dDBKVA-mPD/view)
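Each run script configures itself through a single `args` object. A hedged sketch of the kind of fields it holds (the field names appear in the training code in this repo; the values here are purely illustrative, not the scripts' defaults):

```python
from argparse import Namespace

# Illustrative values only; the real defaults live in each cikm_*_run.py script.
args = Namespace(
    is_training=True,
    pretrained_model="",   # path to a checkpoint; empty trains from scratch
    batch_size=4,
    img_width=128,
    img_channel=1,
    patch_size=4,
    input_length=10,       # frames fed to the model
    total_length=15,       # input frames + predicted frames
    max_iterations=80000,
    display_interval=100,  # print training loss every N iterations
    test_interval=1000,    # run validation every N iterations
    sampling_start_value=1.0,
)

# number of frames the model must predict
print(args.total_length - args.input_length)  # 5
```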
# Evaluation
We provide two approaches to evaluate our models.
The first (and faster) method is to check all predictions by running the Java code under CIKM_Eva/src; you need to adjust the paths and build a .jar file to run it.
The second method is to run evaluation.py under data_provider/CIKM/.
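The quantity both approaches report is the frame-wise mean squared error between predictions and ground truth. A minimal numpy sketch of that metric (illustrative only; use the official scripts above for reported numbers):

```python
import numpy as np

def sequence_mse(targets, predictions):
    """MSE over all frames, i.e. np.mean(np.square(targets - predictions))."""
    targets = np.asarray(targets, dtype=np.float64)
    predictions = np.asarray(predictions, dtype=np.float64)
    return float(np.mean(np.square(targets - predictions)))

# toy example: batch of 1 sequence, 5 frames of 4x4 radar maps
t = np.zeros((1, 5, 4, 4, 1))
p = np.full((1, 5, 4, 4, 1), 0.5)
print(sequence_mse(t, p))  # 0.25
```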
# Docker
Build the Docker image:
```sh
sudo docker build -t ida-lstm .
```
Run the container:
```sh
sudo nvidia-docker run -v /data2/franch/meteotn_traindata:/data2/franch/meteotn_traindata --rm ida-lstm
```
Build the Singularity image from the local Docker image:
```sh
sudo singularity build ida-lstm docker-daemon://ida-lstm:latest
```
# Prediction samples
5 frames are predicted given the last 10 frames.
![Prediction visualization](https://github.com/luochuyao/IDA_LSTM/blob/master/radar_res.png)
%% Cell type:code id:3b85eefc-e3a2-47fb-8458-3a6d7b5f62a1 tags:
``` python
```
%% Cell type:code id:0d2542a1-bdc6-4279-9b24-549da20c4948 tags:
``` python
```
%% Cell type:code id:3a6ef2aa-4209-4221-a801-57f56e5fb085 tags:
``` python
def wrapper_train(model):
if args.pretrained_model:
model.load(args.pretrained_model)
# load data
# train_input_handle, test_input_handle = datasets_factory.data_provider(
# args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width,
# seq_length=args.total_length, is_training=True)
eta = args.sampling_start_value
best_mse = math.inf
tolerate = 0
limit = 3
best_iter = None
for itr in range(1, args.max_iterations + 1):
ims = sample(batch_size=batch_size)
ims = padding_CIKM_data(ims)
ims = preprocess.reshape_patch(ims, args.patch_size)
ims = nor(ims)
eta, real_input_flag = schedule_sampling(eta, itr)
cost = trainer.train(model, ims, real_input_flag, args, itr)
if itr % args.display_interval == 0:
print("itr: " + str(itr))
print("training loss: " + str(cost))
if itr % args.test_interval == 0:
print("validation one ")
valid_mse = wrapper_valid(model)
print("validation mse is:", str(valid_mse))
if valid_mse < best_mse:
best_mse = valid_mse
best_iter = itr
tolerate = 0
model.save()
else:
tolerate = tolerate + 1
if tolerate == limit:
model.load()
test_mse = wrapper_test(model)
print("the best valid mse is:", str(best_mse))
print("the test mse is ", str(test_mse))
break

def wrapper_valid(model):
loss = 0
count = 0
index = 1
flag = True
# img_mse, ssim = [], []
# for i in range(args.total_length - args.input_length):
# img_mse.append(0)
# ssim.append(0)
real_input_flag = np.zeros(
(
args.batch_size,
args.total_length - args.input_length - 1,
args.img_width // args.patch_size,
args.img_width // args.patch_size,
args.patch_size ** 2 * args.img_channel,
)
)
output_length = args.total_length - args.input_length
while flag:
dat, (index, b_cup) = sample(batch_size, data_type="validation", index=index)
dat = nor(dat)
tars = dat[:, -output_length:]
ims = padding_CIKM_data(dat)
ims = preprocess.reshape_patch(ims, args.patch_size)
img_gen, _ = model.test(ims, real_input_flag)
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
mse = np.mean(np.square(tars - img_out))
loss = loss + mse
count = count + 1
if b_cup == args.batch_size - 1:
pass
else:
flag = False
return loss / count
```
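The `schedule_sampling` call in `wrapper_train` above is defined elsewhere in the repo. A hedged sketch of the usual PredRNN-style linear decay (the mask shape is taken from `real_input_flag` in the code above; `sampling_stop_iter` and `sampling_changing_rate` are hypothetical field names):

```python
import numpy as np
from argparse import Namespace

def schedule_sampling(eta, itr, args, rng):
    """Linearly decay eta (probability of feeding ground-truth frames) and
    sample a per-timestep mask: 1 -> use real frame, 0 -> use model output."""
    if itr < args.sampling_stop_iter:
        eta = max(eta - args.sampling_changing_rate, 0.0)
    else:
        eta = 0.0
    mask_shape = (
        args.batch_size,
        args.total_length - args.input_length - 1,
        args.img_width // args.patch_size,
        args.img_width // args.patch_size,
        args.patch_size ** 2 * args.img_channel,
    )
    # one coin flip per (sequence, timestep); the whole patch map for that
    # timestep is then kept real or replaced by the prediction
    flip = rng.random((mask_shape[0], mask_shape[1]))
    real_input_flag = np.zeros(mask_shape)
    real_input_flag[flip < eta] = 1.0
    return eta, real_input_flag

# illustrative configuration (values are assumptions, not the repo's defaults)
args = Namespace(batch_size=2, total_length=15, input_length=10,
                 img_width=128, patch_size=4, img_channel=1,
                 sampling_stop_iter=50000, sampling_changing_rate=2e-5)
eta, flag = schedule_sampling(1.0, 1, args, np.random.default_rng(0))
print(flag.shape)  # (2, 4, 32, 32, 16)
```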
%% Cell type:code id:37c862ea-8a05-44bd-b840-c19ec98a4119 tags:
``` python
# if args.is_training:
# wrapper_train(model)
# else:
# wrapper_test(model)
```
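`wrapper_train` and `wrapper_valid` above rely on `preprocess.reshape_patch` / `reshape_patch_back`. A minimal numpy sketch of the standard PredRNN patch transform (an assumption about the layout — check the repo's preprocess module for the exact implementation):

```python
import numpy as np

def reshape_patch(img_tensor, patch_size):
    # [batch, seq, H, W, C] -> [batch, seq, H/p, W/p, p*p*C]
    b, s, h, w, c = img_tensor.shape
    p = patch_size
    x = img_tensor.reshape(b, s, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 2, 4, 3, 5, 6)
    return x.reshape(b, s, h // p, w // p, p * p * c)

def reshape_patch_back(patch_tensor, patch_size):
    # inverse transform: [batch, seq, H/p, W/p, p*p*C] -> [batch, seq, H, W, C]
    b, s, hh, ww, cc = patch_tensor.shape
    p = patch_size
    c = cc // (p * p)
    x = patch_tensor.reshape(b, s, hh, ww, p, p, c)
    x = x.transpose(0, 1, 2, 4, 3, 5, 6)
    return x.reshape(b, s, hh * p, ww * p, c)

# round trip on a toy batch: 2 sequences, 3 frames of 8x8 single-channel maps
frames = np.random.default_rng(1).random((2, 3, 8, 8, 1))
patches = reshape_patch(frames, 2)
print(patches.shape)  # (2, 3, 4, 4, 4)
```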
%% Cell type:code id:b87b38b7-633b-4310-8c5c-f3addc992459 tags:
``` python
```