Commit 433f429e authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Dataset augmentation script

parent 603710b7
This diff is collapsed.
# %% [markdown]
# ## Augment dataset class-wise
# %%
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
from dataset import NumpyCSVDataset, augment_3D_HN
import SimpleITK as sitk
from config import get_project_root
#%%
PROJECT_ROOT = get_project_root()
DATASET = 'HN_val'
BBOX_DIR = PROJECT_ROOT / 'data' / DATASET / 'processed' / 'bbox'
BBOX_SUBDATASET = 'bbox_64'
BBOX_SUBDATASET_AUG = f'{BBOX_SUBDATASET}_aug'
IN_DIR = BBOX_DIR / BBOX_SUBDATASET
OUT_DIR = BBOX_DIR / BBOX_SUBDATASET_AUG
os.makedirs(OUT_DIR, exist_ok=False)
CLINICAL_DATA_PATH = (
PROJECT_ROOT / 'data' / DATASET / 'processed' / f'clinical_{DATASET}.csv'
)
CLINICAL_DATA_AUG_PATH = (
PROJECT_ROOT
/ 'data'
/ DATASET
/ 'processed'
/ f'clinical_{BBOX_SUBDATASET_AUG}.csv'
)
LABEL_COLUMN = 'locoregional'
SIZE = 64
dataset = NumpyCSVDataset(IN_DIR, CLINICAL_DATA_PATH, LABEL_COLUMN, SIZE, mode='test')
clinical_original = pd.read_csv(CLINICAL_DATA_PATH)
labels = dataset.labels
idx_to_augment = np.where(labels == 1)[0]
ratio_NP = int((len(labels) - len(idx_to_augment)) / len(idx_to_augment))
print('Augmentation factor: ', ratio_NP)
clinical_augumented_rows = []
#%%
for sample in tqdm(dataset):
filename = sample['filename']
filename_no_ext = filename.split('.')[-2]
# print(filename)
image_orig = sample['data']
label = sample['target']
ratio = 1 if label == 0 else ratio_NP
clinical_filename = clinical_original[clinical_original['filename'] == filename]
assert (
len(clinical_filename) == 1
), f'There is more than one row corresponding to the filename {filename} in the clinical file.'
id_image = 0
for j in range(ratio):
image_aug = augment_3D_HN(image_orig, 'train', SIZE)
filename_aug = f'{filename_no_ext}_{id_image}.npy'
np.save(OUT_DIR / filename_aug, image_aug)
clinical_filename['filename'] = filename_aug
clinical_augumented_rows.append(clinical_filename)
id_image += 1
# %%
clinical_augmented = pd.concat(clinical_augumented_rows, ignore_index=True)
#%%
clinical_augmented.to_csv(CLINICAL_DATA_AUG_PATH, index=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment