Commit e4c2c3c0 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Accept *args and **kwargs in create_dataset + black formatting

parent 44fec4c8
......@@ -23,17 +23,20 @@ def find_dataset_using_name(dataset_name):
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset):
dataset = cls
if dataset is None:
raise NotImplementedError('In {0}.py, there should be a subclass of BaseDataset with class name that matches {1} in lowercase.'.format(dataset_filename, target_dataset_name))
raise NotImplementedError(
'In {0}.py, there should be a subclass of BaseDataset with class name that matches {1} in lowercase.'.format(
dataset_filename, target_dataset_name
)
)
return dataset
def create_dataset(configuration):
def create_dataset(configuration, *args, **kwargs):
"""Create a dataset given the configuration (loaded from the json file).
This function wraps the class CustomDatasetDataLoader.
......@@ -41,35 +44,40 @@ def create_dataset(configuration):
Example:
from datasets import create_dataset
dataset = create_dataset(configuration)
dataset = create_dataset(configuration, *args, **kwargs)
"""
data_loader = CustomDatasetDataLoader(configuration)
data_loader = CustomDatasetDataLoader(configuration, *args, **kwargs)
dataset = data_loader.load_data()
return dataset
class CustomDatasetDataLoader():
class CustomDatasetDataLoader:
"""Wrapper class of Dataset class that performs multi-threaded data loading
according to the configuration.
"""
def __init__(self, configuration):
def __init__(self, configuration, *args, **kwargs):
self.configuration = configuration
dataset_class = find_dataset_using_name(configuration['dataset_name'])
self.dataset = dataset_class(configuration)
self.dataset = dataset_class(configuration, *args, **kwargs)
print("dataset [{0}] was created".format(type(self.dataset).__name__))
# if we use custom collation, define it as a staticmethod in the dataset class
custom_collate_fn = getattr(self.dataset, "collate_fn", None)
if callable(custom_collate_fn):
self.dataloader = data.DataLoader(self.dataset, **configuration['loader_params'], collate_fn=custom_collate_fn)
self.dataloader = data.DataLoader(
self.dataset,
**configuration['loader_params'],
collate_fn=custom_collate_fn
)
else:
self.dataloader = data.DataLoader(self.dataset, **configuration['loader_params'])
self.dataloader = data.DataLoader(
self.dataset, **configuration['loader_params']
)
def load_data(self):
return self
def get_custom_dataloader(self, custom_configuration):
"""Get a custom dataloader (e.g. for exporting the model).
This dataloader may use different configurations than the
......@@ -77,18 +85,22 @@ class CustomDatasetDataLoader():
"""
custom_collate_fn = getattr(self.dataset, "collate_fn", None)
if callable(custom_collate_fn):
custom_dataloader = data.DataLoader(self.dataset, **self.configuration['loader_params'], collate_fn=custom_collate_fn)
custom_dataloader = data.DataLoader(
self.dataset,
**self.configuration['loader_params'],
collate_fn=custom_collate_fn
)
else:
custom_dataloader = data.DataLoader(self.dataset, **self.configuration['loader_params'])
custom_dataloader = data.DataLoader(
self.dataset, **self.configuration['loader_params']
)
return custom_dataloader
def __len__(self):
"""Return the number of data in the dataset.
"""
return len(self.dataset)
def __iter__(self):
"""Return a batch of data.
"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment