Commit 44ce0f7c authored by luochuyao's avatar luochuyao
Browse files

Update the evaluation file

parent a986bc65
......@@ -5,7 +5,7 @@ import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from scipy.misc import imsave,imread
from skimage.measure import compare_ssim
evaluate_root = '/mnt/A/meteorological/2500_ref_seq/'
test_root = '/mnt/A/CIKM2017/CIKM_datasets/test/'
......@@ -279,27 +279,157 @@ def seq_hss_csi_test(test_model_list,model_names,is_java=True,is_plot=True):
return test_model_hss,test_model_csi
def all_seq_test():
def eval_test(true_fold,pred_fold,eval_type):
res = 0
# valid_root_path = '/home/ices/PycharmProject/IDAST_LSTM/data_provider/valid_test.txt'
# with open(valid_root_path) as f:
# sample_indexes = f.read().split('\n')[:-1]
sample_indexes = list(range(1,4001,1))
for index in sample_indexes:
true_current_fold = true_fold+'sample_'+str(index)+'/'
pre_current_fold = pred_fold+'sample_'+str(index)+'/'
pred_imgs = []
true_imgs = []
for t in range(6, 16, 1):
pred_path = pre_current_fold+'img_'+str(t)+'.png'
true_path = true_current_fold+'img_'+str(t)+'.png'
pred_img = imread(pred_path)
true_img = imread(true_path)
pred_img = pred_img.astype(np.float32)
true_img = true_img.astype(np.float32)
pred_imgs.append(pred_img)
true_imgs.append(true_img)
pred_imgs = np.array(pred_imgs)
true_imgs = np.array(true_imgs)
# pred_imgs = pixel_to_dBZ(pred_imgs)
# true_imgs = pixel_to_dBZ(true_imgs)
pred_imgs = pred_imgs.astype(np.float)
true_imgs = true_imgs.astype(np.float)
if eval_type == 'mse':
# sample_res = np.square(pred_imgs - true_imgs).mean()
sample_res = np.mean(np.square(pred_imgs - true_imgs))
elif eval_type == 'mae':
# sample_res = np.abs(pred_imgs - true_imgs).mean()
sample_res = np.mean(np.abs(pred_imgs - true_imgs))
elif eval_type == 'ssim':
sample_res = 0
for t in range(10):
ssim = compare_ssim(pred_imgs[t],true_imgs[t])
sample_res = sample_res+ssim
sample_res = sample_res/10.0
elif eval_type == 'rmse':
sample_res = np.sqrt(np.mean(np.square(pred_imgs - true_imgs)))
res = res+sample_res
res = res/len(sample_indexes)
return res
def all_test():
true_test_root = '/mnt/A/CIKM2017/CIKM_datasets/test/'
# true_validation_root = '/mnt/A/CIKM2017/CIKM_datasets/validation/'
pred_root = '/mnt/A/meteorological/2500_ref_seq/'
test_model_list = [
"CIKM_convlstm",
"CIKM_ConvGRU_test",
"CIKM_TrajGRU_test",
"CIKM_predrnn",
"CIKM_predrnn_plus",
"e3d_s_lstm_test_",
"CIKM_MIM_test",
"CIKM_dst_predrnn",
"CIKM_inter_dst_predrnn_r2"
]
# test_model_list = [
# "CIKM_convlstm",
# "CIKM_ConvGRU_test",
# "CIKM_TrajGRU_test",
# "CIKM_predrnn",
# "CIKM_sst_predrnn",
# "CIKM_cst_predrnn",
# "CIKM_dst_predrnn",
# ]
# test_model_list = [
# "CIKM_convlstm",
# "CIKM_convlstm_test_r1",
# "CIKM_convlstm_test_r2",
# "CIKM_convlstm_test_r3",
# "CIKM_convlstm_test_r4"
# ]
# test_model_list = [
# "CIKM_predrnn",
# "CIKM_predrnn_r1",
# "CIKM_predrnn_r2",
# "CIKM_predrnn_r3",
# "CIKM_predrnn_r4"
# ]
# test_model_list = [
# "CIKM_predrnn_plus",
# "e3d_s_lstm_test_",
# "CIKM_MIM_test",
# "CIKM_predrnn_plus_r1",
# "CIKM_predrnn_plus_r2",
# "CIKM_predrnn_plus_r3",
# "CIKM_predrnn_plus_r4"
# ]
# test_model_list = [
# "CIKM_dst_predrnn",
# "CIKM_inter_dst_predrnn_r2"
# "CIKM_inter_dst_predrnn_r1",
# "CIKM_inter_dst_predrnn_r2",
# "CIKM_inter_dst_predrnn_r3",
# "CIKM_inter_dst_predrnn_r4"
# ]
# model_names = ['ConvLSTM',
# 'ConvGRU',
# 'TrajGRU',
# 'PredRNN',
# 'PredRNN++',
# 'E3D-LSTM',
# 'MIM',
# 'DA-LSTM',
# 'IDA-LSTM']
# test_model_mse = {}
# for model in test_model_list:
# mse = eval_test(true_test_root, pred_root + model + '/', 'weight_mse')
# test_model_mse[model] = mse
# print('weight_mse model is:', model)
# print(test_model_mse[model])
test_model_mse = {}
for model in test_model_list:
mse = eval_test(true_test_root, pred_root + model + '/', 'rmse')
test_model_mse[model] = mse
print('mse model is:', model)
print(test_model_mse[model])
# test_model_mae = {}
# for model in test_model_list:
# mae = eval_test(true_test_root, pred_root + model + '/', 'mae')
# test_model_mae[model] = mae
# print('mae model is:', model)
# print(test_model_mae[model])
#
# print('ssim')
# test_model_ssim = {}
# for id, model in enumerate(test_model_list):
# ssim = eval_test(true_test_root, pred_root + model + '/', 'ssim')
# test_model_ssim[model] = ssim
# print('ssim model is:', model)
# print(test_model_ssim[model])
# print('*' * 80)
def all_seq_test():
test_model_list = [
"CIKM_convlstm",
"CIKM_ConvGRU_test",
"CIKM_TrajGRU_test",
"CIKM_predrnn",
"CIKM_predrnn_plus",
"e3d_s_lstm_test_",
"CIKM_MIM_test",
"CIKM_dst_predrnn",
"CIKM_inter_dst_predrnn_r2"
]
model_names = ['ConvLSTM',
'ConvGRU',
'TrajGRU',
'PredRNN',
'PredRNN++',
'E3D-LSTM',
'MIM',
'DA-LSTM',
'IDA-LSTM']
# test_model_list = [
# "CIKM_predrnn",
......@@ -316,8 +446,8 @@ def all_seq_test():
# "CIKM_convlstm_test_r4",
# "CIKM_predrnn",
# "CIKM_predrnn_r4",
"CIKM_predrnn_plus",
"CIKM_predrnn_plus_r4",
# "CIKM_predrnn_plus",
# "CIKM_predrnn_plus_r4",
# "CIKM_dst_predrnn",
# "CIKM_inter_dst_predrnn_r2"
]
......@@ -326,12 +456,15 @@ def all_seq_test():
# 'IConvLSTM',
# 'PredRNN',
# 'IPredRNN',
'PredRNN++',
'IPredRNN++',
# 'PredRNN++',
# 'IPredRNN++',
# 'DA-LSTM',
# 'IDA-LSTM'
]
seq_hss_csi_test(test_model_list, model_names, is_java=True, is_plot=True)
if __name__ == '__main__':
all_seq_test()
\ No newline at end of file
all_test()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment