Commit 2db327de authored by Gabriele Franch's avatar Gabriele Franch
Browse files

removed padding

parent f9fcad05
......@@ -136,9 +136,10 @@ def wrapper_train(model: Model):
for itr in progress_bar:
imgs = next(iterator)
imgs = change_taasss_dims(imgs)
imgs = padding_taasss(imgs, args)
# imgs = padding_taasss(imgs, args)
imgs = preprocess.reshape_patch(imgs, args.patch_size)
eta, real_input_flag = schedule_sampling(eta, itr)
print(real_input_flag.shape)
cost = trainer.train(model, imgs, real_input_flag, args)
progress_bar.set_description(f"Loss: {cost}")
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -7,8 +7,10 @@ from data_provider.CIKM.taasss import infinite_batcher
def get_batcher(args):
data_dir = Path("/") / "data1" / "meteotn_data_2010_2016"
metadata_file = data_dir / "run_metadata.csv"
data_dir = Path("/") / "home" / "gabriele" / "Documents" / "dottorato" / "data" / "meteotn_data_new"
metadata_file = data_dir / "hdf_metadata.csv"
# data_dir = Path("/") / "data1" / "meteotn_data_2010_2016"
# metadata_file = data_dir / "run_metadata.csv"
all_data = h5py.File(
data_dir / "hdf_archives" / "all_data.hdf5",
"r",
......@@ -38,9 +40,9 @@ def get_batcher(args):
def padding_taasss(array: np.ndarray, args) -> np.ndarray:
"""
Add padding
(1, 25, 512, 512, 1)
(4, 25, 480, 480, 1)
to
(1, 25, 480, 480, 1)
(4, 25, 512, 512, 1)
"""
zeros = np.zeros((args.batch_size, args.total_length, args.img_width, args.img_width, args.img_channel))
if args.img_width == 512:
......@@ -55,9 +57,9 @@ def padding_taasss(array: np.ndarray, args) -> np.ndarray:
def unpadding_taasss(array: np.ndarray, args) -> np.ndarray:
"""
Remove padding
(1, 25, 480, 480, 1)
(4, 25, 512, 512, 1)
to
(1, 25, 512, 512, 1)
(4, 25, 480, 480, 1)
"""
if args.img_width == 512:
return array[:, :, 16:496, 16:496, :]
......@@ -69,7 +71,6 @@ def unpadding_taasss(array: np.ndarray, args) -> np.ndarray:
def change_taasss_dims(a: np.ndarray) -> np.ndarray:
"""(25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)"""
a = np.squeeze(a)
a = np.expand_dims(a, axis=0)
return np.expand_dims(a, axis=4)
"""(25, 4, 1, 480, 480) to (4, 25, 480, 480, 1)"""
a = np.moveaxis(a, 2, 4)
return np.swapaxes(a, 0, 1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment