Commit a697d6aa authored by Marco Cristoforetti's avatar Marco Cristoforetti
Browse files

results

parent 0a19515b
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -173,8 +173,8 @@ def fix_weight(dst_v):
pos = np.argwhere(np.abs(b - dst_v) == np.abs((b - dst_v)).min())[0,0]
if dst_v - b[pos] < 0:
pos = pos-1
return np.sqrt(w[pos]/h.max())
# return w[pos]/h.max()
# return np.sqrt(w[pos]/h.max())
return w[pos]/h.max()
fix_weight_v = np.vectorize(fix_weight)
weights = fix_weight_v(dst_min)
......@@ -341,8 +341,8 @@ for epoch in range(num_epochs):
# dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr])
# data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
torch.save(dst_net.state_dict(), 'models/dst_reg_64_16_sqrt_all.pth')
torch.save(dst_net.state_dict(), '/home/mcristofo/DST/models/dst_reg_64_16_all.pth')
np.savetxt('hist/history_tr_rmse_mae_reg_64_16_sqrt_all.txt', history_tr)
np.savetxt('hist/history_valid_rmse_mae_reg_64_16_sqrt_all.txt', history_valid)
np.savetxt('hist/history_ts_rmse_mae_reg_64_16_sqrt_all.txt', history_ts)
np.savetxt('/home/mcristofo/DST/hist/history_tr_rmse_mae_reg_64_16_all.txt', history_tr)
np.savetxt('/home/mcristofo/DST/hist/history_valid_rmse_mae_reg_64_16_all.txt', history_valid)
np.savetxt('/home/mcristofo/DST/hist/history_ts_rmse_mae_reg_64_16_all.txt', history_ts)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment