Commit 8453acfc authored by Marco Cristoforetti's avatar Marco Cristoforetti
Browse files

out

parent 00da0786
data_path = '/home/marco/projects/projects_data/DST/data/'
data_path = '/storage/DSIP/DST/data/'
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -225,10 +225,10 @@ n_out_i = 8
before = BEFORE
nvars = data_in_scaled.shape[-1]
# dst_net = DSTnet(nvars, nhidden_i, nhidden_o, n_out_i, before, AFTER).to(device)
dst_net = DSTnet(nvars, nhidden_i, nhidden_o, n_out_i, before, AFTER).to(device)
print(dst_net)
num_epochs = 2000
num_epochs = 10000
lr = 1e-5
optimizer = torch.optim.Adam(dst_net.parameters(), lr=lr)#, weight_decay=1e-5)
......
......@@ -243,8 +243,8 @@ nvars = data_in_scaled.shape[-1]
dst_net = DSTnet(nvars, nhidden_i, nhidden_o, n_out_i, before, AFTER).to(device)
print(dst_net)
num_epochs = 2000
lr = 1e-4
num_epochs = 10000
lr = 1e-5
optimizer = torch.optim.Adam(dst_net.parameters(), lr=lr, weight_decay=1e-5)
history_tr = np.zeros((num_epochs, 3))
......@@ -352,6 +352,7 @@ for epoch in range(num_epochs):
dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr], data_out_c[ixs_tr], weights)
# data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=True)
loss_fc= nn.CrossEntropyLoss(weight = weights_c)
torch.save(dst_net.state_dict(), os.path.join('/home/mcristofo/DST/models','dst_regr_class.pth'))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment