Marco Cristoforetti / DST / Commits

Commit 801ae04a
Authored Feb 07, 2021 by Marco Cristoforetti

Merge branch 'master' of gitlab.fbk.eu:mcristofo/DST

Parents: a67fb39c, 532dd1e2

Changes: 6 files
hist/history_tr_regr_class_nsc.txt  0 → 100644  (new file; diff collapsed)
hist/history_ts_regr_class_nsc.txt  0 → 100644  (new file; diff collapsed)
hist/history_valid_regr_class_nsc.txt  0 → 100644  (new file; diff collapsed)
models/dst_regr_class.pth  (modified; no preview for this file type)
models/dst_regr_class_nsc.pth  0 → 100644  (new file added)
scripts/training_reg_class.py  (modified)
...
...
@@ -128,10 +128,11 @@ data_out_c[np.where((data_out_c < dst_levels[2]))] = 3
 class Dataset(utils_data.Dataset):
-    def __init__(self, dataset_in, dataset_out, dataset_out_c):
+    def __init__(self, dataset_in, dataset_out, dataset_out_c, weights):
         self.dataset_in = dataset_in
         self.dataset_out = dataset_out
         self.dataset_out_c = dataset_out_c
+        self.weights = weights

     def __len__(self):
         return self.dataset_in.size(0)
...
...
@@ -141,7 +142,8 @@ class Dataset(utils_data.Dataset):
         din_src = self.dataset_in[idx]
         dout = self.dataset_out[idx]
         dout_c = self.dataset_out_c[idx]
-        return din_src, dout, dout_c
+        ww = self.weights[idx]
+        return din_src, dout, dout_c, ww

 ixs_valid_test = np.arange(int(len_valid_test)) + last_train
 np.random.shuffle(ixs_valid_test)
...
...
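With this change the Dataset stores and returns a per-sample weight alongside each example, so the weights tensor that was previously fed to WeightedRandomSampler (see the commented-out loader line in the next hunk) travels with the samples instead. The commit does not show how weights is built; a minimal sketch, assuming the common inverse-class-frequency scheme over the labels in data_out_c (the label tensor below is a purely illustrative stand-in):

import torch

# Illustrative stand-in for the class labels; in the script these would come
# from data_out_c[:last_train] (integer classes 0..3).
labels = torch.tensor([0, 0, 0, 1, 2, 0, 3, 1, 0, 2])

counts = torch.bincount(labels, minlength=4).float()  # samples per class
per_class_w = counts.max() / counts                    # inverse-frequency weight per class
weights = per_class_w[labels]                          # one weight per sample, aligned with the dataset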
@@ -172,7 +174,7 @@ sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples= l
 BATCH_SIZE = 256
-dataset_tr = Dataset(data_in_scaled[:last_train], data_out_scaled[:last_train], data_out_c[:last_train])
+dataset_tr = Dataset(data_in_scaled[:last_train], data_out_scaled[:last_train], data_out_c[:last_train], weights)
 # data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=False, sampler = sampler)
 data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=BATCH_SIZE, num_workers = 4, shuffle=True)
...
...
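Because the per-sample weights now come through the Dataset itself, the WeightedRandomSampler line stays commented out and the loader falls back to plain shuffling (DataLoader does not allow shuffle=True together with a custom sampler). Each batch therefore yields a 4-tuple. A minimal usage sketch with synthetic tensors and assumed shapes, reusing the Dataset class defined above:

import torch
from torch.utils import data as utils_data

# Synthetic stand-ins; shapes are assumptions, not taken from the script.
data_in_scaled  = torch.randn(1000, 24)
data_out_scaled = torch.randn(1000, 1)
data_out_c      = torch.randint(0, 4, (1000, 1))
last_train      = 800
weights         = torch.rand(last_train)

dataset_tr = Dataset(data_in_scaled[:last_train], data_out_scaled[:last_train],
                     data_out_c[:last_train], weights)
data_loader_tr = utils_data.DataLoader(dataset_tr, batch_size=256, num_workers=4, shuffle=True)

for x, y_r, y_c, w in data_loader_tr:  # batches now carry their per-sample weights
    pass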
@@ -229,10 +231,10 @@ class DSTnet(nn.Module):
         return x1, x2

 aa = data_out_c[:last_train]
-weights_c = torch.tensor([len(aa[aa == 0])/len(aa[aa == 0]), len(aa[aa == 0])/len(aa[aa == 1]), len(aa[aa == 0])/len(aa[aa == 2]), len(aa[aa == 0])/len(aa[aa == 3])]).to(device).sqrt()
+weights_c = torch.tensor([len(aa[aa == 0])/len(aa[aa == 0]), len(aa[aa == 1])/len(aa[aa == 1]), len(aa[aa == 0])/len(aa[aa == 2]), len(aa[aa == 0])/len(aa[aa == 3])]).to(device).sqrt()

 loss_f = nn.L1Loss()
-loss_mse = nn.MSELoss()
+loss_mse = nn.MSELoss(reduction='none')
 #loss_fc= nn.CrossEntropyLoss()
 loss_fc = nn.CrossEntropyLoss(weight=weights_c)
...
...
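The weights_c tensor gives each of the four classes an inverse-frequency weight relative to class 0 and then takes the square root, so minority classes count more in the CrossEntropyLoss term; the edited line additionally pins the class-1 weight to 1. A small self-contained sketch of the general pattern, using hypothetical class counts rather than anything from the commit:

import torch
import torch.nn as nn

counts = torch.tensor([9000., 600., 300., 100.])   # hypothetical samples per class 0..3
weights_c = (counts[0] / counts).sqrt()             # sqrt of inverse frequency relative to class 0
# -> tensor([1.0000, 3.8730, 5.4772, 9.4868])

loss_fc = nn.CrossEntropyLoss(weight=weights_c)      # rarer classes contribute more per example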
@@ -263,14 +265,15 @@ for epoch in range(num_epochs):
         x = batch[0].float().to(device)
         y_r = batch[1].float().to(device)
         y_c = batch[2].flatten().long().to(device)
+        w = batch[3].to(device)

         optimizer.zero_grad()
         dst_net.train()

         out_r, out_c = dst_net(x)
         loss_r = loss_f(out_r, y_r)
         loss_c = loss_fc(out_c, y_c)
-        loss = loss_r + loss_c
+        loss = (loss_r * w).mean() + loss_c
         loss.backward()
         optimizer.step()
...
...
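The new training loss scales the regression term by the batch's per-sample weights before averaging. Note that loss_f is nn.L1Loss() with the default 'mean' reduction in this script, so loss_r is already a scalar and (loss_r * w).mean() reduces to loss_r * w.mean(); a genuinely per-sample weighting would need an elementwise L1 term. A minimal sketch of that variant, offered as an assumption about the intent rather than as what the commit does:

import torch
import torch.nn as nn

loss_f_none = nn.L1Loss(reduction='none')   # one L1 value per element instead of a scalar

out_r = torch.randn(4, 1)                   # illustrative regression outputs for a batch of 4
y_r   = torch.randn(4, 1)
w     = torch.rand(4)                       # per-sample weights, as returned in batch[3]

loss_r = loss_f_none(out_r, y_r).squeeze(1) # shape (batch,)
loss   = (loss_r * w).mean()                # weighted mean over the batch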
@@ -278,17 +281,17 @@ for epoch in range(num_epochs):
     dst_net.eval()

     out_r, out_c = dst_net(data_in_scaled[:last_train].to(device).float())
-    loss_tr = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[:last_train].to(device).float()).item())
+    loss_tr = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[:last_train].to(device).float()).mean().item())
     loss_mae_tr = loss_f(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[:last_train].to(device).float()).item()
     loss_c_tr = loss_fc(out_c, data_out_c[:last_train].flatten().long().to(device)).item()

     out_r, out_c = dst_net(data_in_scaled[ixs_valid].to(device).float())
-    loss_val = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_valid].to(device).float()).item())
+    loss_val = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_valid].to(device).float()).mean().item())
     loss_mae_val = loss_f(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_valid].to(device).float()).item()
     loss_c_val = loss_fc(out_c, data_out_c[ixs_valid].flatten().long().to(device)).item()

     out_r, out_c = dst_net(data_in_scaled[ixs_test].to(device).float())
-    loss_ts = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_test].to(device).float()).item())
+    loss_ts = np.sqrt(loss_mse(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_test].to(device).float()).mean().item())
     loss_mae_ts = loss_f(mm_scaler_out.inverse_transform(out_r.cpu().clone()).to(device), data_out[ixs_test].to(device).float()).item()
     loss_c_ts = loss_fc(out_c, data_out_c[ixs_test].flatten().long().to(device)).item()
...
...
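Since loss_mse is now nn.MSELoss(reduction='none'), it returns elementwise squared errors rather than a scalar, which is why every RMSE line above adds .mean() before .item(). A tiny self-contained illustration of the same pattern:

import numpy as np
import torch
import torch.nn as nn

loss_mse = nn.MSELoss(reduction='none')     # keeps one squared error per element

pred   = torch.tensor([[1.0], [2.0], [4.0]])
target = torch.tensor([[1.0], [3.0], [2.0]])

rmse = np.sqrt(loss_mse(pred, target).mean().item())  # mean over elements, then sqrt
# sqrt((0 + 1 + 4) / 3) ≈ 1.29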
@@ -302,11 +305,11 @@ for epoch in range(num_epochs):
           (epoch, epoch_time, loss_tr, loss_val, loss_ts, loss_c_tr, loss_c_val, loss_c_ts))

-    torch.save(dst_net.state_dict(), os.path.join('/home/mcristofo/DST/models', 'dst_regr_class.pth'))
+    torch.save(dst_net.state_dict(), os.path.join('/home/mcristofo/DST/models', 'dst_regr_class_nsc.pth'))

-    np.savetxt(os.path.join('/home/mcristofo/DST/hist', 'history_tr_regr_class.txt'), history_tr)
-    np.savetxt(os.path.join('/home/mcristofo/DST/hist', 'history_valid_regr_class.txt'), history_valid)
-    np.savetxt(os.path.join('/home/mcristofo/DST/hist', 'history_ts_regr_class.txt'), history_ts)
+    np.savetxt(os.path.join('/home/mcristofo/DST/hist', 'history_tr_regr_class_nsc.txt'), history_tr)
+    np.savetxt(os.path.join('/home/mcristofo/DST/hist', 'history_valid_regr_class_nsc.txt'), history_valid)
+    np.savetxt(os.path.join('/home/mcristofo/DST/hist', 'history_ts_regr_class_nsc.txt'), history_ts)

 dst_net.eval()
...
...