heat.utils.data.datatools
Functions and classes useful for loading data into neural networks.
Module Contents
- class DataLoader(dataset: torch.utils.data.Dataset | heat.utils.data.partial_dataset.PartialH5Dataset, batch_size: int = 1, num_workers: int = 0, collate_fn: Callable = None, pin_memory: bool = False, drop_last: bool = False, timeout: int | float = 0, worker_init_fn: Callable = None)
Combines either a DNDarray or a torch Dataset with a sampler. This provides an iterable over the local dataset, and the data is shuffled at the end of each iterator. If a DNDarray is given, a Dataset() will be created internally. Currently, this supports only map-style datasets with single-process loading. It uses the random batch sampler. The rest of the DataLoader functionality described in torch.utils.data.dataloader applies.
- Parameters:
dataset – Dataset, torch Dataset, or heat.utils.data.partial_dataset.PartialH5Dataset. The dataset from which the data will be returned by the created iterator
batch_size – int, optional. How many samples per batch to load. Default: 1
num_workers – int, optional. How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Default: 0
collate_fn – callable, optional. Merges a list of samples to form a mini-batch of torch.Tensor(s). Used for batched loading from a map-style dataset. Default: None
pin_memory – bool, optional. If True, the data loader will copy torch.Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the torch.utils.data documentation for an example. Default: False
drop_last – bool, optional. Set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, the last batch will be smaller. Default: False
timeout – int or float, optional. If positive, the timeout value for collecting a batch from workers. Should always be non-negative. Default: 0
worker_init_fn – callable, optional. If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. Default: None
- Variables:
dataset – The dataset created from the local data
DataLoader – The local torch DataLoader object, used to create the iterable and to determine the length
_first_iter (bool) – Flag indicating if the iterator created is the first one. If it is not, then the data will be shuffled before the iterator is created
last_epoch (bool) – Flag indicating last epoch
- __iter__() Iterator
Generate a new iterator whose type depends on the type of the dataset. Returns a partial_dataset.PartialH5DataLoaderIter if the dataset is a partial_dataset.PartialH5Dataset, otherwise self._full_dataset_shuffle_iter().
- __len__() int
Get the length of the dataloader. Returns the number of batches.
- _full_dataset_shuffle_iter()
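For orientation, a minimal usage sketch. The array shape and names below are illustrative, and it relies on the behavior described above: a DNDarray passed as the dataset is wrapped in a Dataset internally.

```python
import heat as ht
from heat.utils.data.datatools import DataLoader

# illustrative DNDarray, split along dimension 0 so each process holds a slice
data = ht.random.randn(1000, 10, split=0)

# a Dataset is created internally from the DNDarray (see above)
loader = DataLoader(data, batch_size=32, drop_last=True)

for epoch in range(2):
    for batch in loader:  # batches are local torch.Tensors
        pass              # e.g. forward/backward pass on the local batch
    # the data is shuffled across processes at the end of each iterator
```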
- class Dataset(array, transforms: List | Callable | None = None, ishuffle: bool | None = False, test_set: bool | None = False)
Bases:
torch.utils.data.Dataset
An abstract class representing a given dataset. This inherits from torch.utils.data.Dataset.
This class is a general example of what should be done to create a Dataset. When creating a dataset, all of the standard attributes should be set, and the __getitem__, __len__, and Shuffle functions must be defined:
__getitem__: how an item is given to the network
__len__: the total number of data elements to be given to the network
Shuffle(): how the data should be shuffled between the processes. The function shown below is for a dataset composed of only data, without targets. The function dataset_shuffle() abstracts this; it is given only the dataset and a list of attributes to shuffle.
Ishuffle(): a non-blocking version of Shuffle(), handled by the abstract function dataset_ishuffle(). It works similarly to dataset_shuffle().
As the amount of data across processes can be non-uniform, the dataset class will slice off the excess elements on whichever processes have more data than the others. This should only be 1 element. The shuffle function will shuffle all of the data on the process.
It is recommended that for DNDarrays the split is either 0 or None.
- Parameters:
array (DNDarray) – DNDarray from which to create the dataset
transforms (List or Callable, optional) – Transformation(s) to call before a data item is returned
ishuffle (bool, optional) – Flag indicating whether to use non-blocking communications for shuffling the data between epochs. Note: if True, the Ishuffle() function must be defined within the class. Default: False
- Variables:
These are the required attributes:
htdata (DNDarray) – Full data
_cut_slice (slice) – Slice to cut off the last element to get a uniform amount of data on each process
comm (MPICommunicator) – Communication object used to send the data between processes
lcl_half (int) – Half of the number of data elements on the process
data (torch.Tensor) – The local data to be used in training
transforms (Callable) – Transform to be called during the getitem function
ishuffle (bool) – Flag indicating if non-blocking communications are used for shuffling the data between epochs
- __getitem__(index: int | slice | tuple | list | torch.Tensor) torch.Tensor
This is the most basic form of getitem. As this function is often very specific to the dataset, it should be overridden by the user. In this form it only returns the raw items from the data.
- __len__() int
Get the number of items in the dataset. This should be overridden by custom datasets.
- Shuffle()
Send half of the local data to the process self.comm.rank + 1 if available, else wrap around. After receiving the new data, shuffle the local tensor.
- Ishuffle()
Send half of the local data to the process self.comm.rank + 1 if available, else wrap around, using non-blocking communications. After receiving the new data, shuffle the local tensor.
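To make the recipe above concrete, here is a minimal sketch of a custom dataset holding only data, no targets. The class name is hypothetical, and the shuffle methods simply delegate to the abstract helpers documented below.

```python
import heat as ht
from heat.utils.data import datatools


class MyDataset(datatools.Dataset):
    # data-only dataset: htdata is the global DNDarray, data the local tensor

    def __getitem__(self, index):
        # return the raw item, applying the transform if one was given
        # (assumes a single callable transform)
        if self.transforms:
            return self.transforms(self.data[index])
        return self.data[index]

    def __len__(self):
        # number of local elements (after the uniformizing cut)
        return self.data.shape[0]

    def Shuffle(self):
        # blocking shuffle of the local attribute 'data' backed by 'htdata'
        datatools.dataset_shuffle(dataset=self, attrs=[["data", "htdata"]])

    def Ishuffle(self):
        # non-blocking variant
        datatools.dataset_ishuffle(dataset=self, attrs=[["data", "htdata"]])


# usage: wrap a split DNDarray (shape illustrative)
dataset = MyDataset(ht.random.randn(1000, 10, split=0))
```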
- dataset_shuffle(dataset: Dataset | torch.utils.data.Dataset, attrs: List[list])
Shuffle the given attributes of a dataset across multiple processes. This will send half of the data to rank + 1. Once the new data is received, it will be shuffled into the existing data on the process. This function will be called by the DataLoader automatically if dataset.ishuffle = False. attrs should have the form [[torch.Tensor, DNDarray], ...], e.g. [["data", "htdata"]]; it is assumed that all of the attrs have the same dim0 shape as the local data.
- Parameters:
dataset (Dataset) – the dataset to shuffle
attrs (List[List[str, str], ...]) – List of lists, each of which contains 2 strings. The strings are the handles of the Dataset attributes corresponding to the global data DNDarray and the local data of that array, i.e. [["data", "htdata"]] would shuffle the htdata around and set the correct amount of data for the dataset.data attribute. For multiple parameters, multiple lists are required, i.e. [["data", "htdata"], ["targets", "httargets"]]
Notes
dataset.comm must be defined for this function to work.
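An illustrative call, assuming a dataset object my_dataset that exposes data/htdata and targets/httargets attributes plus a comm attribute as noted:

```python
from heat.utils.data.datatools import dataset_shuffle

# shuffle both the data and the targets across processes; each pair names
# the local torch.Tensor attribute and the global DNDarray attribute
dataset_shuffle(my_dataset, attrs=[["data", "htdata"], ["targets", "httargets"]])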
- dataset_ishuffle(dataset: Dataset | torch.utils.data.Dataset, attrs: List[list])
Shuffle the given attributes of a dataset across multiple processes, using non-blocking communications. This will send half of the data to rank + 1. The data must be received by the dataset_irecv() function. This function will be called by the DataLoader automatically if dataset.ishuffle = True. This is set either during the definition of the class or at its initialization by the given parameter.
- Parameters:
dataset (Dataset) – the dataset to shuffle
attrs (List[List[str, str], ...]) – List of lists, each of which contains 2 strings. The strings are the handles of the Dataset attributes corresponding to the global data DNDarray and the local data of that array, i.e. [["htdata", "data"]] would shuffle the htdata around and set the correct amount of data for the dataset.data attribute. For multiple parameters, multiple lists are required, i.e. [["htdata", "data"], ["httargets", "targets"]]
Notes
dataset.comm must be defined for this function to work.
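A sketch of opting in to the non-blocking path; the array and sizes are illustrative. With ishuffle=True the DataLoader triggers dataset_ishuffle() automatically, with the data later received via dataset_irecv() as described above:

```python
import heat as ht
from heat.utils.data.datatools import DataLoader, Dataset

data = ht.random.randn(1000, 10, split=0)  # illustrative DNDarray
dataset = Dataset(data, ishuffle=True)     # Ishuffle() is defined on Dataset
loader = DataLoader(dataset, batch_size=32)

for epoch in range(5):
    for batch in loader:
        pass  # training step on the local batch
    # the shuffle communication overlaps with the next epoch instead of blocking
```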