Commit a65286f9 authored by Marco Cristoforetti's avatar Marco Cristoforetti
Browse files

diclub

parent 66938483
data_path = '/home/marco/projects/projects_data/DST/data/' data_path = '/storage/DSIP/DST/data/'
\ No newline at end of file
...@@ -11,8 +11,6 @@ dependencies: ...@@ -11,8 +11,6 @@ dependencies:
- python>=3.7 - python>=3.7
- scikit-learn - scikit-learn
- jupyterlab - jupyterlab
- pytorch>=1.0.0+cpu - pytorch>=1.0.0
- torchvision>=0.2.0+cpu - torchvision>=0.2.0
- seaborn - seaborn
- pip
- mplhep
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import sys; sys.path.append('../DST') import sys; sys.path.append('/home/mcristofo/DST/DST')
import os import os
from DST.config import data_path from DST.config import data_path
import pandas as pd import pandas as pd
...@@ -15,6 +15,7 @@ torch.manual_seed(21894) ...@@ -15,6 +15,7 @@ torch.manual_seed(21894)
np.random.seed(21894) np.random.seed(21894)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
BEFORE = 12 BEFORE = 12
AFTER = 12 AFTER = 12
...@@ -196,8 +197,7 @@ class DSTnet(nn.Module): ...@@ -196,8 +197,7 @@ class DSTnet(nn.Module):
self.linear_o_3 = nn.Linear(self.nhidden_o, self.nhidden_o // 2) self.linear_o_3 = nn.Linear(self.nhidden_o, self.nhidden_o // 2)
self.linear_o_4 = nn.Linear(self.nhidden_o // 2, self.after) self.linear_o_4 = nn.Linear(self.nhidden_o // 2, self.after)
self.linear_o_3_c = nn.Linear(self.after, self.after*2) self.linear_o_4_c = nn.Linear(self.nhidden_o // 2, self.after*4)
self.linear_o_4_c = nn.Linear(self.after*2, self.after*4)
def init_hidden(self, batch_size): def init_hidden(self, batch_size):
...@@ -220,11 +220,9 @@ class DSTnet(nn.Module): ...@@ -220,11 +220,9 @@ class DSTnet(nn.Module):
x = F.relu(self.linear_o_3(x)) x = F.relu(self.linear_o_3(x))
x = F.dropout(x, 0.2, training=self.training) x = F.dropout(x, 0.2, training=self.training)
x = self.linear_o_4(x) x1 = self.linear_o_4(x)
x1 = x
x2 = F.relu(self.linear_o_3_c(x)) x2 = self.linear_o_4_c(x)
x2 = self.linear_o_4_c(x2)
x2 = x2.reshape(x0.size(0) * self.after, 4) x2 = x2.reshape(x0.size(0) * self.after, 4)
return x1, x2 return x1, x2
...@@ -235,7 +233,7 @@ weights_c = torch.tensor([len(aa[aa==0])/len(aa[aa==0]), len(aa[aa==0])/len(aa[a ...@@ -235,7 +233,7 @@ weights_c = torch.tensor([len(aa[aa==0])/len(aa[aa==0]), len(aa[aa==0])/len(aa[a
loss_f = nn.L1Loss() loss_f = nn.L1Loss()
loss_mse = nn.MSELoss() loss_mse = nn.MSELoss()
# loss_fc= nn.CrossEntropyLoss() #loss_fc= nn.CrossEntropyLoss()
loss_fc= nn.CrossEntropyLoss(weight = weights_c) loss_fc= nn.CrossEntropyLoss(weight = weights_c)
nhidden_i = 2 nhidden_i = 2
...@@ -247,9 +245,9 @@ nvars = data_in_scaled.shape[-1] ...@@ -247,9 +245,9 @@ nvars = data_in_scaled.shape[-1]
dst_net = DSTnet(nvars, nhidden_i, nhidden_o, n_out_i, before, AFTER).to(device) dst_net = DSTnet(nvars, nhidden_i, nhidden_o, n_out_i, before, AFTER).to(device)
print(dst_net) print(dst_net)
num_epochs = 10 num_epochs = 2000
lr = 1e-4 lr = 1e-4
optimizer = torch.optim.Adam(dst_net.parameters(), lr=lr, weight_decay=1e-5) optimizer = torch.optim.Adam(dst_net.parameters(), lr=lr)#, weight_decay=1e-5)
history_tr = np.zeros((num_epochs, 3)) history_tr = np.zeros((num_epochs, 3))
history_valid = np.zeros((num_epochs, 3)) history_valid = np.zeros((num_epochs, 3))
...@@ -298,8 +296,51 @@ for epoch in range(num_epochs): ...@@ -298,8 +296,51 @@ for epoch in range(num_epochs):
history_ts[epoch] = [loss_ts, loss_mae_ts, loss_c_ts] history_ts[epoch] = [loss_ts, loss_mae_ts, loss_c_ts]
epoch_time = time.time() - start_time epoch_time = time.time() - start_time
if (epoch % 1 == 0): if (epoch % 10 == 0):
print('Epoch %d time = %.2f, tr_rmse = %0.5f, val_rmse = %0.5f, ts_rmse = %0.5f, tr_c = %.5f, val_c = %.5f, ts_c = %.5f' % print('Epoch %d time = %.2f, tr_rmse = %0.5f, val_rmse = %0.5f, ts_rmse = %0.5f, tr_c = %.5f, val_c = %.5f, ts_c = %.5f' %
(epoch, epoch_time, loss_tr, loss_val, loss_ts, loss_c_tr, loss_c_val, loss_c_ts)) (epoch, epoch_time, loss_tr, loss_val, loss_ts, loss_c_tr, loss_c_val, loss_c_ts))
\ No newline at end of file torch.save(dst_net.state_dict(), os.path.join('/home/mcristofo/DST/models','dst_regr_class.pth'))
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_tr_regr_class.txt'), history_tr)
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_valid_regr_class.txt'), history_valid)
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_ts_regr_class.txt'), history_ts)
dst_net.eval()
out_r, out_c = dst_net(data_in_scaled[last_train:].to(device).float())
out_cc = F.softmax(out_c, dim=1)
out_cc = out_cc.detach().cpu().numpy()
from sklearn.metrics import confusion_matrix, matthews_corrcoef
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
np.set_printoptions(suppress=True, precision=3)
tp = 11
print(confusion_matrix(data_out_c[last_train:][:,tp].cpu().detach().numpy(), out_cc.argmax(axis=1).reshape(-1,12)[:,tp])/(confusion_matrix(data_out_c[last_train:][:,tp].cpu().detach().numpy(), out_cc.argmax(axis=1).reshape(-1,12)[:,tp]).sum(axis=1)[:, None]))
dst_levels = [-20,-50,-100]
truth = data_out[last_train:].cpu().detach().numpy().copy()
out = mm_scaler_out.inverse_transform(out_r.cpu().clone()).detach().numpy()
for i in range(12):
print(i, np.sqrt(mean_squared_error(truth[:,i], out[:,i])))
truth[np.where(truth >= dst_levels[0])] = 0
truth[np.where((truth < dst_levels[0]) & (truth >= dst_levels[1]))] = 1
truth[np.where((truth < dst_levels[1]) & (truth >= dst_levels[2]))] = 2
truth[np.where((truth < dst_levels[2]))] = 3
out[np.where(out >= dst_levels[0])] = 0
out[np.where((out < dst_levels[0]) & (out >= dst_levels[1]))] = 1
out[np.where((out < dst_levels[1]) & (out >= dst_levels[2]))] = 2
out[np.where((out < dst_levels[2]))] = 3
print(confusion_matrix(truth[:,11], out[:,11])/confusion_matrix(truth[:,11], out[:,11]).sum(axis=1)[:, None])
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment