Commit 1c4c5799 authored by Marco Cristoforetti's avatar Marco Cristoforetti
Browse files

output

parent 56f760c8
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -183,6 +183,7 @@ sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples= l ...@@ -183,6 +183,7 @@ sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples= l
dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr]) dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr])
data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler) data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
#data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=True)
class DSTnet(nn.Module): class DSTnet(nn.Module):
def __init__(self, nvars, nhidden_i, nhidden_o, n_out_i, before, after): def __init__(self, nvars, nhidden_i, nhidden_o, n_out_i, before, after):
...@@ -232,9 +233,9 @@ loss_f = nn.L1Loss() ...@@ -232,9 +233,9 @@ loss_f = nn.L1Loss()
loss_mse = nn.MSELoss() loss_mse = nn.MSELoss()
nhidden_i = 1 nhidden_i = 2
nhidden_o = 64 nhidden_o = 96
n_out_i = 16 n_out_i = 8
before = BEFORE before = BEFORE
nvars = data_in_scaled.shape[-1] nvars = data_in_scaled.shape[-1]
...@@ -341,8 +342,8 @@ for epoch in range(num_epochs): ...@@ -341,8 +342,8 @@ for epoch in range(num_epochs):
# dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr]) # dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr])
# data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler) # data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
torch.save(dst_net.state_dict(), '/home/mcristofo/DST/models/dst_reg_64_16_all.pth') torch.save(dst_net.state_dict(), '/home/mcristofo/DST/models/dst_reg_96_8_ns.pth')
np.savetxt('/home/mcristofo/DST/hist/history_tr_rmse_mae_reg_64_16_all.txt', history_tr) np.savetxt('/home/mcristofo/DST/hist/history_tr_rmse_mae_reg_96_8_ns.txt', history_tr)
np.savetxt('/home/mcristofo/DST/hist/history_valid_rmse_mae_reg_64_16_all.txt', history_valid) np.savetxt('/home/mcristofo/DST/hist/history_valid_rmse_mae_reg_96_8_ns.txt', history_valid)
np.savetxt('/home/mcristofo/DST/hist/history_ts_rmse_mae_reg_64_16_all.txt', history_ts) np.savetxt('/home/mcristofo/DST/hist/history_ts_rmse_mae_reg_96_8_ns.txt', history_ts)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment