Commit 1011705c authored by Marco Cristoforetti

no dst

parent 0829ea03
-# data_path = '/storage/DSIP/DST/data/'
-data_path = '/home/marco/projects/projects_data/DST/data/'
+data_path = '/storage/DSIP/DST/data/'
+#data_path = '/home/marco/projects/projects_data/DST/data/'
(Six file diffs are collapsed in the web view and not reproduced here.)
@@ -23,7 +23,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(device)
 BEFORE = 12
-AFTER = 12
+AFTER = 6
 dst_data = pd.read_pickle(os.path.join(data_path,'dst.pkl'))
 dst_data['ora_round'] = dst_data.ora.apply(lambda x:int(x.split(':')[0]))
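Note: this change halves the forecast horizon from 12 to 6 hourly steps. The windowing code itself is not part of this hunk; as a hedged illustration only (the function and variable names below are hypothetical, not from the repository), a BEFORE/AFTER split of the hourly Dst series could look like this:

# Hedged sketch (not from this commit): one plausible way BEFORE/AFTER
# windows could be built from the hourly Dst series.
import numpy as np

BEFORE = 12   # hours of history fed to the network
AFTER = 6     # hours to forecast (was 12 before this change)

def make_windows(dst_values, before=BEFORE, after=AFTER):
    """Slide over the series and return (inputs, targets) pairs."""
    X, y = [], []
    for t in range(before, len(dst_values) - after + 1):
        X.append(dst_values[t - before:t])   # last `before` hours
        y.append(dst_values[t:t + after])    # next `after` hours
    return np.asarray(X), np.asarray(y)

# Example on 100 hours of synthetic Dst-like values
X, y = make_windows(np.random.randn(100) * 30 - 10)
print(X.shape, y.shape)   # (83, 12) (83, 6)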
@@ -294,14 +294,14 @@ for epoch in range(num_epochs):
 out_r = dst_net(data_in_scaled[last_train:].to(device).float())
-tp = 11
+tp = 5
 dst_levels = [-20,-50,-100]
 truth = data_out[last_train:].cpu().detach().numpy().copy()
 out = mm_scaler_out.inverse_transform(out_r.cpu().clone()).detach().numpy()
-for ll in range(12):
+for ll in range(6):
     print(ll, np.sqrt(mean_squared_error(truth[:,ll], out[:,ll])))
 truth[np.where(truth >= dst_levels[0])] = 0
@@ -314,7 +314,7 @@ for epoch in range(num_epochs):
 out[np.where((out < dst_levels[1]) & (out >= dst_levels[2]))] = 2
 out[np.where((out < dst_levels[2]))] = 3
-print(confusion_matrix(truth[:,11], out[:,11])/confusion_matrix(truth[:,11], out[:,11]).sum(axis=1)[:, None])
+print(confusion_matrix(truth[:,tp], out[:,tp])/confusion_matrix(truth[:,tp], out[:,tp]).sum(axis=1)[:, None])
 # np.random.shuffle(ixs_tr2)
 # ixs_tr2a = ixs_tr2[:10000]
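With only 6 forecast steps, the hard-coded column 11 would be out of range, so the evaluation now indexes the last horizon through tp (5). As a hedged, self-contained illustration on synthetic arrays (none of these values come from the repository), the Dst-level binning and the row-normalized confusion matrix printed above work like this:

# Hedged illustration with synthetic data: map Dst values to storm classes
# and row-normalize the confusion matrix so each true-class row sums to 1.
import numpy as np
from sklearn.metrics import confusion_matrix

dst_levels = [-20, -50, -100]
rng = np.random.default_rng(0)
truth = rng.uniform(-150, 20, size=500)
out = truth + rng.normal(0, 15, size=500)   # noisy stand-in "prediction"

def to_class(a):
    c = np.zeros_like(a)
    c[a >= dst_levels[0]] = 0                            # quiet
    c[(a < dst_levels[0]) & (a >= dst_levels[1])] = 1    # moderate
    c[(a < dst_levels[1]) & (a >= dst_levels[2])] = 2    # strong
    c[a < dst_levels[2]] = 3                             # severe
    return c

cm = confusion_matrix(to_class(truth), to_class(out))
print(cm / cm.sum(axis=1)[:, None])   # rows: true class, columns: predicted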
@@ -342,8 +342,8 @@ for epoch in range(num_epochs):
 # dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr])
 # data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
-torch.save(dst_net.state_dict(), '/home/mcristofo/DST/models/dst_reg_96_8_ns.pth')
+torch.save(dst_net.state_dict(), '/home/mcristofo/DST/models/dst_reg_96_8_ns_A6.pth')
-np.savetxt('/home/mcristofo/DST/hist/history_tr_rmse_mae_reg_96_8_ns.txt', history_tr)
-np.savetxt('/home/mcristofo/DST/hist/history_valid_rmse_mae_reg_96_8_ns.txt', history_valid)
-np.savetxt('/home/mcristofo/DST/hist/history_ts_rmse_mae_reg_96_8_ns.txt', history_ts)
+np.savetxt('/home/mcristofo/DST/hist/history_tr_rmse_mae_reg_96_8_ns_A6.txt', history_tr)
+np.savetxt('/home/mcristofo/DST/hist/history_valid_rmse_mae_reg_96_8_ns_A6.txt', history_valid)
+np.savetxt('/home/mcristofo/DST/hist/history_ts_rmse_mae_reg_96_8_ns_A6.txt', history_ts)
@@ -269,21 +269,21 @@ for epoch in range(num_epochs):
 dst_net.eval()
 data_out_scaled_loss = mm_scaler_out.inverse_transform(data_out_scaled.clone())
 outputs = dst_net(data_in_scaled[:last_train].to(device).float())
-outputs[:, 0] = outputs[:, 0] + data_in_scaled[:last_train][:,-1,-1]
+outputs[:, 0] = outputs[:, 0] + data_in_scaled[:last_train][:,-1,-1].to(device).float()
 for i in range(1, 12):
     outputs[:, i] += outputs[:, i-1]
 loss_tr = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(outputs.cpu().clone()).to(device), data_out_scaled_loss[:last_train].to(device).float()).item())
 loss_mae_tr = loss_f(mm_scaler_out.inverse_transform(outputs.cpu().clone()).to(device), data_out_scaled_loss[:last_train].to(device).float()).item()
 outputs = dst_net(data_in_scaled[ixs_valid].to(device).float())
-outputs[:, 0] = outputs[:, 0] + data_in_scaled[ixs_valid][:,-1,-1]
+outputs[:, 0] = outputs[:, 0] + data_in_scaled[ixs_valid][:,-1,-1].to(device).float()
 for i in range(1, 12):
     outputs[:, i] += outputs[:, i-1]
 loss_valid = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(outputs.cpu().clone()).to(device), data_out_scaled_loss[ixs_valid].to(device).float()).item())
 loss_mae_valid = loss_f(mm_scaler_out.inverse_transform(outputs.cpu().clone()).to(device), data_out_scaled_loss[ixs_valid].to(device).float()).item()
 outputs = dst_net(data_in_scaled[ixs_test].to(device).float())
-outputs[:, 0] = outputs[:, 0] + data_in_scaled[ixs_test][:,-1,-1]
+outputs[:, 0] = outputs[:, 0] + data_in_scaled[ixs_test][:,-1,-1].to(device).float()
 for i in range(1, 12):
     outputs[:, i] += outputs[:, i-1]
 loss_ts = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(outputs.cpu().clone()).to(device), data_out_scaled_loss[ixs_test].to(device).float()).item())
@@ -308,7 +308,7 @@ for epoch in range(num_epochs):
 dst_levels = [-20,-50,-100]
 truth = data_out[last_train:].cpu().detach().numpy().copy()
-out_r[:, 0] = out_r[:, 0] + data_in_scaled[last_train:][:,-1,-1]
+out_r[:, 0] = out_r[:, 0] + data_in_scaled[last_train:][:,-1,-1].to(device).float()
 for i in range(1, 12):
     out_r[:, i] += out_r[:, i-1]
 out = mm_scaler_out.inverse_transform(out_r.cpu().clone()).detach().numpy()
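The network predicts per-step deltas: the first column is anchored to the last observed input value and the remaining columns are accumulated. Appending .to(device).float() matters because the model output lives on the GPU while the slice of data_in_scaled is (apparently) a CPU tensor, and mixing the two raises a device mismatch. A hedged, minimal sketch of the pattern, with shapes and the meaning of data_in_scaled[..., -1, -1] assumed rather than taken from the repository:

# Hedged sketch of the delta-to-absolute reconstruction and the device fix.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch, horizon = 4, 6
out_r = torch.randn(batch, horizon, device=device)   # predicted per-step deltas
data_in = torch.randn(batch, 12, 8)                  # stand-in for data_in_scaled, assumed to live on the CPU

# Without .to(device).float() this addition fails on a GPU with a device mismatch.
out_r[:, 0] = out_r[:, 0] + data_in[:, -1, -1].to(device).float()
for i in range(1, horizon):                          # cumulative sum over the forecast horizons
    out_r[:, i] += out_r[:, i - 1]

Note that the accumulation loops in the hunks above still run over range(1, 12); whether that matches the new horizon depends on this file's own AFTER setting, which is not visible in the diff.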
@@ -159,13 +159,15 @@ class DSTnett(nn.Module):
     def __init__(self,):
         super().__init__()
-        self.hidden1 = 256
+        self.hidden1 = 516
         self.linear_o_1 = nn.Linear(8, self.hidden1)
         self.ln1 = nn.LayerNorm(self.hidden1)
         self.linear_o_2 = nn.Linear(self.hidden1, self.hidden1*2)
         self.linear_o_3 = nn.Linear(self.hidden1*2, self.hidden1*2)
         self.linear_o_3a = nn.Linear(self.hidden1*2, self.hidden1*2)
         self.linear_o_3b = nn.Linear(self.hidden1*2, self.hidden1*2)
+        self.linear_o_3b1 = nn.Linear(self.hidden1*2, self.hidden1*2)
+        self.linear_o_3b2 = nn.Linear(self.hidden1*2, self.hidden1*2)
         self.linear_o_3c = nn.Linear(self.hidden1*2, self.hidden1*2)
         self.linear_o_3d = nn.Linear(self.hidden1*2, self.hidden1)
         self.linear_o_3e = nn.Linear(self.hidden1, self.hidden1)
@@ -185,6 +187,10 @@ class DSTnett(nn.Module):
         x = F.dropout(x, 0.2, training=self.training)
         x = F.relu(self.linear_o_3b(x))
         x = F.dropout(x, 0.2, training=self.training)
+        x = F.relu(self.linear_o_3b1(x))
+        x = F.dropout(x, 0.2, training=self.training)
+        x = F.relu(self.linear_o_3b2(x))
+        x = F.dropout(x, 0.2, training=self.training)
         x = F.relu(self.linear_o_3c(x))
         x = F.dropout(x, 0.2, training=self.training)
         x = F.relu(self.linear_o_3d(x))
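The forward pass gains two more Linear + ReLU + Dropout blocks (linear_o_3b1, linear_o_3b2) at width hidden1*2, on top of the wider hidden size. Purely as a hedged alternative sketch, not the author's code, the same repeated pattern could be generated once with nn.Sequential (nn.Dropout behaves like F.dropout with training=self.training):

# Hedged refactoring sketch: build the repeated Linear -> ReLU -> Dropout trunk in one place.
import torch.nn as nn

def mlp_block(dim, n_blocks, p=0.2):
    layers = []
    for _ in range(n_blocks):
        layers += [nn.Linear(dim, dim), nn.ReLU(), nn.Dropout(p)]
    return nn.Sequential(*layers)

hidden1 = 516
trunk = mlp_block(hidden1 * 2, n_blocks=6)   # covers 3, 3a, 3b, 3b1, 3b2, 3c
print(trunk)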
@@ -201,7 +207,7 @@ dst_net = DSTnett().to(device)
 print(dst_net)
 num_epochs = 10000
-lr = 1e-4
+lr = 1e-5
 optimizer = torch.optim.Adam(dst_net.parameters(), lr=lr, weight_decay=1e-5)
 history_tr = np.zeros((num_epochs, 2))
@@ -189,12 +189,14 @@ class DSTnet(nn.Module):
         self.after = after
         self.n_out_i = n_out_i
         self.hidden1 = 256
         self.lstm = nn.LSTM(self.nvars, self.n_out_i, self.nhidden_i, batch_first=True)
         self.first_merged_layer = self.n_out_i * self.before
         self.bn1 = nn.BatchNorm1d(num_features=self.first_merged_layer)
         # self.bn1 = nn.LayerNorm(self.first_merged_layer)
-        self.linear_o_1 = nn.Linear(self.first_merged_layer, self.nhidden_o)
+        self.linear_o_1 = nn.Linear(self.first_merged_layer, self.hidden1)
         self.linear_o_2 = nn.Linear(self.hidden1, self.hidden1*2)
         self.linear_o_3 = nn.Linear(self.hidden1*2, self.hidden1*2)
         self.linear_o_3a = nn.Linear(self.hidden1*2, self.hidden1*2)
@@ -202,7 +204,7 @@ class DSTnet(nn.Module):
         self.linear_o_3c = nn.Linear(self.hidden1*2, self.hidden1*2)
         self.linear_o_3d = nn.Linear(self.hidden1*2, self.hidden1)
         self.linear_o_3e = nn.Linear(self.hidden1, self.hidden1)
-        self.linear_o_4 = nn.Linear(self.hidden1, 1)
+        self.linear_o_4 = nn.Linear(self.hidden1, self.after)
     def init_hidden(self, batch_size):
         hidden = torch.randn(self.nhidden_i, batch_size, self.n_out_i).to(device)
@@ -346,8 +348,8 @@ for epoch in range(num_epochs):
 # dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr])
 # data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
-torch.save(dst_net.state_dict(), '/home/mcristofo/DST/models/dst_reg_96_8_ns.pth')
+#torch.save(dst_net.state_dict(), '/home/mcristofo/DST/models/dst_reg_96_8_ns.pth')
-np.savetxt('/home/mcristofo/DST/hist/history_tr_rmse_mae_reg_96_8_ns.txt', history_tr)
-np.savetxt('/home/mcristofo/DST/hist/history_valid_rmse_mae_reg_96_8_ns.txt', history_valid)
-np.savetxt('/home/mcristofo/DST/hist/history_ts_rmse_mae_reg_96_8_ns.txt', history_ts)
+#np.savetxt('/home/mcristofo/DST/hist/history_tr_rmse_mae_reg_96_8_ns.txt', history_tr)
+#np.savetxt('/home/mcristofo/DST/hist/history_valid_rmse_mae_reg_96_8_ns.txt', history_valid)
+#np.savetxt('/home/mcristofo/DST/hist/history_ts_rmse_mae_reg_96_8_ns.txt', history_ts)