Commit c94a8df4 authored by Marco Cristoforetti

Merge branch 'master' of gitlab.fbk.eu:mcristofo/DST

parents 9f302acc ef043c24
(Three file diffs in this commit are collapsed and not shown.)
@@ -11,6 +11,8 @@ import torch.utils.data as utils_data
 import torch.nn.functional as F
 import datetime
+from sklearn.metrics import confusion_matrix, matthews_corrcoef
 torch.manual_seed(21894)
 np.random.seed(21894)
@@ -150,7 +152,13 @@ np.random.shuffle(ixs_valid_test)
 ixs_valid = ixs_valid_test[::2]
 ixs_test = ixs_valid_test[1::2]
-dst_min = data_out[:last_train].min(axis=1).values.flatten()
+ixs_tr1 = np.where((mm_scaler_out.inverse_transform(data_out_scaled[:last_train].clone()).min(axis=1)[0]).numpy()<-20)[0]
+ixs_tr2 = np.where((mm_scaler_out.inverse_transform(data_out_scaled[:last_train].clone()).min(axis=1)[0]).numpy()>=-20)[0]
+np.random.shuffle(ixs_tr2)
+ixs_tr2 = ixs_tr2[:10000]
+ixs_tr = list(ixs_tr1) + list(ixs_tr2)
+dst_min = data_out[ixs_tr].min(axis=1).values.flatten()
 bins = [dst_min.min() - 10] + list(np.arange(-300, dst_min.max() + 10, 10))
 h, b = np.histogram(dst_min, bins=bins)
@@ -174,7 +182,7 @@ sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples= l
 BATCH_SIZE=256
-dataset_tr = Dataset(data_in_scaled[:last_train], data_out_scaled[:last_train], data_out_c[:last_train], weights)
+dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr], data_out_c[ixs_tr], weights)
 # data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
 data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=True)
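Note: the commented-out DataLoader above relies on per-sample weights computed from the Dst-min histogram built in the previous hunk; that computation sits in collapsed lines, so the following is only a plausible sketch (inverse bin frequency, toy dst_min values, same variable names as the diff). The commit itself keeps the sampler disabled and relies on the new ixs_tr subsampling plus shuffle=True.

import numpy as np
import torch
import torch.utils.data as utils_data

# Hypothetical stand-in for the dst_min / bins computed above.
dst_min = np.array([-320.0, -150.0, -80.0, -35.0, -15.0, -12.0, -5.0])
bins = [dst_min.min() - 10] + list(np.arange(-300, dst_min.max() + 10, 10))

h, b = np.histogram(dst_min, bins=bins)             # samples per Dst-min bin
bin_ix = np.digitize(dst_min, np.asarray(b)[1:-1])  # bin index of each sample
weights = torch.tensor(1.0 / h[bin_ix])             # rare (strong-storm) bins get large weights

# If re-enabled, the sampler would draw storm windows more often per epoch:
sampler = utils_data.WeightedRandomSampler(weights, num_samples=len(weights))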
@@ -200,7 +208,7 @@ class DSTnet(nn.Module):
 self.linear_o_4_c = nn.Linear(self.nhidden_o // 2, self.after * 4)
 self.linear_o_4_r = nn.Linear(4, 16)
-self.linear_o_4b_r = nn.Linear(4, 4)
+self.linear_o_4b_r = nn.Linear(16, 16)
 self.linear_o_5_r = nn.Linear(16, 1)
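Note: changing linear_o_4b_r from Linear(4, 4) to Linear(16, 16) makes it compatible with the 16-dimensional output of linear_o_4_r, which is why the forward pass in the next hunk can uncomment it. A self-contained sketch of just this output head (layer names follow the diff; the shared trunk feeding it and the 12-step horizon suggested by reshape(-1, 12) elsewhere are assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    # Per-step head only: 4 class logits per forecast step, then a 4 -> 16 -> 16 -> 1
    # regression branch on top of those logits.
    def __init__(self, nhidden_o=96, after=12):
        super().__init__()
        self.after = after
        self.linear_o_4_c = nn.Linear(nhidden_o // 2, after * 4)   # class logits, 4 per step
        self.linear_o_4_r = nn.Linear(4, 16)
        self.linear_o_4b_r = nn.Linear(16, 16)                     # was (4, 4): shapes did not match
        self.linear_o_5_r = nn.Linear(16, 1)                       # scalar Dst per step

    def forward(self, x):                                          # x: (batch, nhidden_o // 2)
        x1 = self.linear_o_4_c(x)                                  # (batch, after * 4)
        x2 = F.relu(self.linear_o_4_r(x1.view(-1, 4)))             # (batch * after, 16)
        x2 = F.relu(self.linear_o_4b_r(x2))                        # usable now that it is 16 -> 16
        x2 = self.linear_o_5_r(x2).reshape(x.size(0), self.after)  # (batch, after) regression
        x1 = x1.reshape(x.size(0) * self.after, 4)                 # logits for CrossEntropyLoss
        return x2, x1

x2, x1 = Head()(torch.randn(8, 48))  # x2: (8, 12), x1: (96, 4)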
@@ -226,39 +234,43 @@ class DSTnet(nn.Module):
 x = F.dropout(x, 0.2, training=self.training)
 x1 = self.linear_o_4_c(x)
 x2 = F.relu(self.linear_o_4_r(x1.view(-1, 4)))
-# x2 = F.relu(self.linear_o_4b_r(x2))
+x2 = F.relu(self.linear_o_4b_r(x2))
 x2 = self.linear_o_5_r(x2)
 x2 = x2.reshape(x0.size(0), self.after)
 x1 = x1.reshape(x0.size(0) * self.after, 4)
 return x2, x1
-aa = data_out_c[:last_train]
+aa = data_out_c[ixs_tr]
-weights_c = torch.tensor([len(aa[aa==0])/len(aa[aa==0]), len(aa[aa==1])/len(aa[aa==1]), len(aa[aa==0])/len(aa[aa==2]), len(aa[aa==0])/len(aa[aa==3])]).to(device).sqrt()
+weights_c = torch.tensor([len(aa[aa==1])/len(aa[aa==0]), len(aa[aa==1])/len(aa[aa==1]), len(aa[aa==0])/len(aa[aa==2]), len(aa[aa==0])/len(aa[aa==3])]).to(device)#.sqrt()
 loss_f = nn.L1Loss()
 loss_mse = nn.MSELoss(reduction='none')
-loss_fc= nn.CrossEntropyLoss()
+#loss_fc= nn.CrossEntropyLoss()
-# loss_fc= nn.CrossEntropyLoss(weight = weights_c)
+loss_fc= nn.CrossEntropyLoss(weight = weights_c)
 nhidden_i = 2
+#nhidden_o = 96
+#n_out_i = 8
 nhidden_o = 96
-n_out_i = 8
+n_out_i = 12
 before = BEFORE
 nvars = data_in_scaled.shape[-1]
 dst_net = DSTnet(nvars, nhidden_i, nhidden_o, n_out_i, before, AFTER).to(device)
 print(dst_net)
-num_epochs = 2000
+num_epochs = 10000
-lr = 1e-4
+lr = 1e-5
-optimizer = torch.optim.Adam(dst_net.parameters(), lr=lr)#, weight_decay=1e-5)
+optimizer = torch.optim.Adam(dst_net.parameters(), lr=lr, weight_decay=1e-5)
 history_tr = np.zeros((num_epochs, 3))
 history_valid = np.zeros((num_epochs, 3))
 history_ts = np.zeros((num_epochs, 3))
+np.set_printoptions(suppress=True, precision=3)
 for epoch in range(num_epochs):
 start_time = time.time()
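Note: the hunk above turns on a class-weighted CrossEntropyLoss for the 4 Dst classes, with weights_c built from label counts of the rebalanced training set (and the earlier .sqrt() disabled). Below is a small sketch of how such a weighted criterion is applied to the (batch * after, 4) logits returned as x1; the toy labels, the inverse-frequency weighting, and how this term is combined with the L1/MSE regression losses are assumptions, since that code is collapsed here.

import torch
import torch.nn as nn

# Toy labels standing in for aa = data_out_c[ixs_tr] (one Dst class per forecast step).
aa = torch.randint(0, 4, (5000,))
counts = torch.bincount(aa, minlength=4).float()

# Hypothetical inverse-frequency weights; the commit uses its own count ratios.
weights_c = counts.max() / counts
loss_fc = nn.CrossEntropyLoss(weight=weights_c)

logits = torch.randn(96, 4)            # shaped like x1: (batch * after, 4)
labels = torch.randint(0, 4, (96,))
print(loss_fc(logits, labels))         # weighted mean, rarer classes contribute more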
@@ -304,10 +316,18 @@ for epoch in range(num_epochs):
 history_ts[epoch] = [loss_ts, loss_mae_ts, loss_c_ts]
 epoch_time = time.time() - start_time
-if (epoch % 10 == 0):
+if (epoch % 100 == 0):
 print('Epoch %d time = %.2f, tr_rmse = %0.5f, val_rmse = %0.5f, ts_rmse = %0.5f, tr_c = %.5f, val_c = %.5f, ts_c = %.5f' %
 (epoch, epoch_time, loss_tr, loss_val, loss_ts, loss_c_tr, loss_c_val, loss_c_ts))
+if (epoch % 100 == 0):
+out_r, out_c = dst_net(data_in_scaled[last_train:].to(device).float())
+tp = 11
+out_cc = F.softmax(out_c, dim=1)
+out_cc = out_cc.detach().cpu().numpy()
+print(confusion_matrix(data_out_c[last_train:][:,tp].cpu().detach().numpy(), out_cc.argmax(axis=1).reshape(-1,12)[:,tp])/(confusion_matrix(data_out_c[last_train:][:,tp].cpu().detach().numpy(), out_cc.argmax(axis=1).reshape(-1,12)[:,tp]).sum(axis=1)[:, None]))
 torch.save(dst_net.state_dict(), os.path.join('/home/mcristofo/DST/models','dst_regr_class_nsc.pth'))
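Note: the new block prints, every 100 epochs, the test-set confusion matrix for forecast step tp = 11, divided by its row sums, so each row shows how the samples of one true class are spread over the predicted classes and the diagonal is per-class recall. A toy illustration of that normalization:

import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels for 4 Dst classes; the diff uses the step-11 test labels and argmax predictions.
y_true = np.array([0, 0, 1, 2, 2, 3, 3, 3])
y_pred = np.array([0, 1, 1, 2, 3, 3, 3, 2])

cm = confusion_matrix(y_true, y_pred)
cm_norm = cm / cm.sum(axis=1)[:, None]  # same row normalization as the print() above
print(cm_norm)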
@@ -319,18 +339,13 @@ dst_net.eval()
 out_r, out_c = dst_net(data_in_scaled[last_train:].to(device).float())
-out_cc = F.softmax(out_c, dim=1)
-out_cc = out_cc.detach().cpu().numpy()
-from sklearn.metrics import confusion_matrix, matthews_corrcoef
 from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
-np.set_printoptions(suppress=True, precision=3)
+out_cc = F.softmax(out_c, dim=1)
+out_cc = out_cc.detach().cpu().numpy()
-tp = 11
 print(confusion_matrix(data_out_c[last_train:][:,tp].cpu().detach().numpy(), out_cc.argmax(axis=1).reshape(-1,12)[:,tp])/(confusion_matrix(data_out_c[last_train:][:,tp].cpu().detach().numpy(), out_cc.argmax(axis=1).reshape(-1,12)[:,tp]).sum(axis=1)[:, None]))
-dst_levels = [-20,-50,-100]
 truth = data_out[last_train:].cpu().detach().numpy().copy()
 out = mm_scaler_out.inverse_transform(out_r.cpu().clone()).detach().numpy()