Commit 0a19515b authored by Marco Cristoforetti

regr

parent 1af9d81f
-data_path = '/home/marco/projects/projects_data/DST/data'
+data_path = '/storage/DSIP/DST/data/'
@@ -158,7 +158,7 @@ ixs_tr2 = np.where((mm_scaler_out.inverse_transform(data_out_scaled[:last_train]
 BATCH_SIZE=256
 np.random.shuffle(ixs_tr2)
-ixs_tr2a = ixs_tr2[:10000]
+ixs_tr2a = ixs_tr2#[:10000]
 ixs_tr = list(ixs_tr1) + list(ixs_tr2a)
 dst_min = data_out[ixs_tr].min(axis=1).values.flatten()
@@ -174,7 +174,7 @@ def fix_weight(dst_v):
     if dst_v - b[pos] < 0:
         pos = pos-1
     return np.sqrt(w[pos]/h.max())
-    # return w[pos]/h.max()
+    # return w[pos]/h.max()
 fix_weight_v = np.vectorize(fix_weight)
 weights = fix_weight_v(dst_min)
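Note (not part of the commit): the weighting above can be exercised in isolation. The sketch below mirrors the script's fix_weight logic on synthetic Dst minima (the real dst_min comes from data_out[ixs_tr].min(axis=1)); the value ranges and the -200/-50 nT thresholds in the final check are assumptions, chosen only to show that rare, strongly negative minima receive larger sampling weights than quiet intervals.

import numpy as np

# Synthetic stand-in for dst_min; real values come from the scaled dataset.
rng = np.random.default_rng(0)
dst_min = np.concatenate([rng.uniform(-60, 20, 9000),     # quiet-time windows
                          rng.uniform(-330, -60, 1000)])  # storm-time windows

# Histogram of the minima; the first edge collects everything below -300 nT.
bins = [dst_min.min() - 10] + list(np.arange(-300, dst_min.max() + 10, 10))
h, b = np.histogram(dst_min, bins=bins)
if len(np.argwhere(h == 0)) > 0:   # drop the first empty bin, as in the script
    bins = np.delete(bins, np.argwhere(h == 0)[0] + 1)
    h, b = np.histogram(dst_min, bins=bins)
w = h.max() / h                    # inverse-frequency weight per bin

def fix_weight(dst_v):
    # Find the left edge of the bin containing dst_v and return its sqrt-damped weight.
    pos = np.argwhere(np.abs(b - dst_v) == np.abs(b - dst_v).min())[0, 0]
    if dst_v - b[pos] < 0:
        pos = pos - 1
    return np.sqrt(w[pos] / h.max())

weights = np.vectorize(fix_weight)(dst_min)
print(weights[dst_min < -200].mean() > weights[dst_min > -50].mean())  # True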
@@ -315,27 +315,34 @@ for epoch in range(num_epochs):
 print(confusion_matrix(truth[:,11], out[:,11])/confusion_matrix(truth[:,11], out[:,11]).sum(axis=1)[:, None])
-np.random.shuffle(ixs_tr2)
-ixs_tr2a = ixs_tr2[:10000]
-ixs_tr = list(ixs_tr1) + list(ixs_tr2a)
-dst_min = data_out[ixs_tr].min(axis=1).values.flatten()
-bins = [dst_min.min() - 10] + list(np.arange(-300, dst_min.max() + 10, 10))
-h, b = np.histogram(dst_min, bins=bins)
-if len(np.argwhere(h == 0)) > 0:
-    bins = np.delete(bins, np.argwhere(h == 0)[0] + 1)
-    h, b = np.histogram(dst_min, bins=bins)
-w = h.max()/h
+# np.random.shuffle(ixs_tr2)
+# ixs_tr2a = ixs_tr2[:10000]
+# ixs_tr = list(ixs_tr1) + list(ixs_tr2a)
+# dst_min = data_out[ixs_tr].min(axis=1).values.flatten()
+# bins = [dst_min.min() - 10] + list(np.arange(-300, dst_min.max() + 10, 10))
+# h, b = np.histogram(dst_min, bins=bins)
+# if len(np.argwhere(h == 0)) > 0:
+#     bins = np.delete(bins, np.argwhere(h == 0)[0] + 1)
+#     h, b = np.histogram(dst_min, bins=bins)
+# w = h.max()/h
+# def fix_weight(dst_v):
+#     pos = np.argwhere(np.abs(b - dst_v) == np.abs((b - dst_v)).min())[0,0]
+#     if dst_v - b[pos] < 0:
+#         pos = pos-1
+#     #return np.sqrt(w[pos]/h.max())
+#     return w[pos]/h.max()
-fix_weight_v = np.vectorize(fix_weight)
-weights = fix_weight_v(dst_min)
+# fix_weight_v = np.vectorize(fix_weight)
+# weights = fix_weight_v(dst_min)
-sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples= len(dst_min))
-dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr])
-data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
+# sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples= len(dst_min))
+# dataset_tr = Dataset(data_in_scaled[ixs_tr], data_out_scaled[ixs_tr])
+# data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
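Note (not part of the commit): a minimal, self-contained version of the sampler wiring that this hunk disables, using toy tensors; TensorDataset stands in for the script's custom Dataset class, and the tensor shapes and batch size are assumptions.

import torch
from torch.utils import data as utils_data

X = torch.randn(1000, 24)     # toy inputs (shapes are assumptions)
y = torch.randn(1000, 12)     # toy targets
weights = torch.rand(1000)    # per-sample weights, e.g. from fix_weight_v

# WeightedRandomSampler draws indices with probability proportional to `weights`
# (with replacement by default), so heavily weighted storm windows are seen more often.
sampler = utils_data.WeightedRandomSampler(weights, num_samples=len(weights))
dataset_tr = utils_data.TensorDataset(X, y)
# shuffle stays False because the sampler already controls the iteration order.
data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=256, shuffle=False, sampler=sampler)

xb, yb = next(iter(data_loader_tr))
print(xb.shape, yb.shape)     # torch.Size([256, 24]) torch.Size([256, 12])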
-torch.save(dst_net.state_dict(), 'models/dst_reg_64_16.pth')
+torch.save(dst_net.state_dict(), 'models/dst_reg_64_16_sqrt_all.pth')
-np.savetxt('hist/history_tr_rmse_mae_reg_64_16.txt', history_tr)
-np.savetxt('hist/history_valid_rmse_mae_reg_64_16.txt', history_valid)
-np.savetxt('hist/history_ts_rmse_mae_reg_64_16.txt', history_ts)
+np.savetxt('hist/history_tr_rmse_mae_reg_64_16_sqrt_all.txt', history_tr)
+np.savetxt('hist/history_valid_rmse_mae_reg_64_16_sqrt_all.txt', history_valid)
+np.savetxt('hist/history_ts_rmse_mae_reg_64_16_sqrt_all.txt', history_ts)
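Note (not part of the commit): the renamed artifacts (the _sqrt_all suffix presumably marking the sqrt-damped weights and the use of all training windows) can be read back roughly as below; DstNet is a hypothetical name for the unshown network class behind dst_net, and the layout of the history files (RMSE/MAE values per logged epoch) is inferred from the file names.

import numpy as np
import torch

# Load the state dict saved above; rebuilding the model needs the original network class.
state = torch.load('models/dst_reg_64_16_sqrt_all.pth', map_location='cpu')
# dst_net = DstNet(...)           # constructor arguments depend on the unshown class
# dst_net.load_state_dict(state)
# dst_net.eval()

history_tr = np.loadtxt('hist/history_tr_rmse_mae_reg_64_16_sqrt_all.txt')
history_valid = np.loadtxt('hist/history_valid_rmse_mae_reg_64_16_sqrt_all.txt')
history_ts = np.loadtxt('hist/history_ts_rmse_mae_reg_64_16_sqrt_all.txt')
print(history_tr.shape, history_valid.shape, history_ts.shape)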