Commit 801ae04a authored by Marco Cristoforetti's avatar Marco Cristoforetti
Browse files

Merge branch 'master' of gitlab.fbk.eu:mcristofo/DST

parents a67fb39c 532dd1e2
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -128,10 +128,11 @@ data_out_c[np.where((data_out_c < dst_levels[2]))] = 3
class Dataset(utils_data.Dataset):
def __init__(self, dataset_in, dataset_out, dataset_out_c):
def __init__(self, dataset_in, dataset_out, dataset_out_c, weights):
self.dataset_in = dataset_in
self.dataset_out = dataset_out
self.dataset_out_c = dataset_out_c
self.weights = weights
def __len__(self):
return self.dataset_in.size(0)
......@@ -141,7 +142,8 @@ class Dataset(utils_data.Dataset):
din_src = self.dataset_in[idx]
dout = self.dataset_out[idx]
dout_c = self.dataset_out_c[idx]
return din_src, dout, dout_c
ww = self.weights[idx]
return din_src, dout, dout_c, ww
ixs_valid_test = np.arange(int(len_valid_test)) + last_train
np.random.shuffle(ixs_valid_test)
......@@ -172,7 +174,7 @@ sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples= l
BATCH_SIZE=256
dataset_tr = Dataset(data_in_scaled[:last_train], data_out_scaled[:last_train], data_out_c[:last_train])
dataset_tr = Dataset(data_in_scaled[:last_train], data_out_scaled[:last_train], data_out_c[:last_train], weights)
# data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=True)
......@@ -229,10 +231,10 @@ class DSTnet(nn.Module):
return x1, x2
aa = data_out_c[:last_train]
weights_c = torch.tensor([len(aa[aa==0])/len(aa[aa==0]), len(aa[aa==0])/len(aa[aa==1]), len(aa[aa==0])/len(aa[aa==2]), len(aa[aa==0])/len(aa[aa==3])]).to(device).sqrt()
weights_c = torch.tensor([len(aa[aa==0])/len(aa[aa==0]), len(aa[aa==1])/len(aa[aa==1]), len(aa[aa==0])/len(aa[aa==2]), len(aa[aa==0])/len(aa[aa==3])]).to(device).sqrt()
loss_f = nn.L1Loss()
loss_mse = nn.MSELoss()
loss_mse = nn.MSELoss(reduction='none')
#loss_fc= nn.CrossEntropyLoss()
loss_fc= nn.CrossEntropyLoss(weight = weights_c)
......@@ -263,14 +265,15 @@ for epoch in range(num_epochs):
x = batch[0].float().to(device)
y_r = batch[1].float().to(device)
y_c = batch[2].flatten().long().to(device)
w = batch[3].to(device)
optimizer.zero_grad()
dst_net.train()
out_r, out_c = dst_net(x)
loss_r = loss_f(out_r, y_r)
loss_c = loss_fc(out_c, y_c)
loss = loss_r + loss_c
loss = (loss_r * w).mean() + loss_c
loss.backward()
optimizer.step()
......@@ -278,17 +281,17 @@ for epoch in range(num_epochs):
dst_net.eval()
out_r, out_c = dst_net(data_in_scaled[:last_train].to(device).float())
loss_tr = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[:last_train].to(device).float()).item())
loss_tr = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[:last_train].to(device).float()).mean().item())
loss_mae_tr = loss_f(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[:last_train].to(device).float()).item()
loss_c_tr = loss_fc(out_c, data_out_c[:last_train].flatten().long().to(device)).item()
out_r, out_c = dst_net(data_in_scaled[ixs_valid].to(device).float())
loss_val = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_valid].to(device).float()).item())
loss_val = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_valid].to(device).float()).mean().item())
loss_mae_val = loss_f(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_valid].to(device).float()).item()
loss_c_val = loss_fc(out_c, data_out_c[ixs_valid].flatten().long().to(device)).item()
out_r, out_c = dst_net(data_in_scaled[ixs_test].to(device).float())
loss_ts = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_test].to(device).float()).item())
loss_ts = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_test].to(device).float()).mean().item())
loss_mae_ts = loss_f(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_test].to(device).float()).item()
loss_c_ts = loss_fc(out_c, data_out_c[ixs_test].flatten().long().to(device)).item()
......@@ -302,11 +305,11 @@ for epoch in range(num_epochs):
(epoch, epoch_time, loss_tr, loss_val, loss_ts, loss_c_tr, loss_c_val, loss_c_ts))
torch.save(dst_net.state_dict(), os.path.join('/home/mcristofo/DST/models','dst_regr_class.pth'))
torch.save(dst_net.state_dict(), os.path.join('/home/mcristofo/DST/models','dst_regr_class_nsc.pth'))
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_tr_regr_class.txt'), history_tr)
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_valid_regr_class.txt'), history_valid)
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_ts_regr_class.txt'), history_ts)
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_tr_regr_class_nsc.txt'), history_tr)
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_valid_regr_class_nsc.txt'), history_valid)
np.savetxt(os.path.join('/home/mcristofo/DST/hist','history_ts_regr_class_nsc.txt'), history_ts)
dst_net.eval()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment