heat.utils.data.mnist
File for the MNIST dataset definition in heat
Module Contents
- class MNISTDataset(root: str, train: bool = True, transform: Callable = None, target_transform: Callable = None, download: bool = True, split: int = 0, ishuffle: bool = False, test_set: bool = False)
Bases: torchvision.datasets.MNIST
Dataset wrapper for torchvision.datasets.MNIST. This implements all of the functions required by heat.utils.data.Dataset. The __getitem__ and __len__ functions are inherited from torchvision.datasets.MNIST.
- Parameters:
root (str) – Directory containing the MNIST dataset
train (bool, optional) – If True, load the training set; otherwise, load the test set. Default: True
transform (Callable, optional) – Transform applied to each data sample in the __getitem__ function. Default: None
target_transform (Callable, optional) – Transform applied to each target in the __getitem__ function. Default: None
download (bool, optional) – If True, download the dataset into root if it does not already exist there. Default: True
split (int, optional) – Axis along which to split the data when it is loaded into a DNDarray. Default: 0
ishuffle (bool, optional) – Flag indicating whether to use non-blocking communications for shuffling the data between epochs. Note: if True, the Ishuffle() function must be defined within the class. Default: False
test_set (bool, optional) – If True, this dataset is treated as the test set and all of its data is kept local. Default: False
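With split=0, the samples are distributed row-wise across the processes. The plain-Python sketch below (no MPI, no Heat) shows how many samples each process would hold under a contiguous block distribution. It assumes that the first n_samples % n_procs processes each receive one extra sample; this balancing rule and the helper names block_counts and block_slice are illustrative assumptions, not part of Heat's API.

```python
from __future__ import annotations


def block_counts(n_samples: int, n_procs: int) -> list[int]:
    """Samples held by each process when n_samples rows are split along axis 0.

    Assumption: contiguous blocks, with the first n_samples % n_procs
    processes holding one extra row.
    """
    base, rem = divmod(n_samples, n_procs)
    return [base + (1 if rank < rem else 0) for rank in range(n_procs)]


def block_slice(n_samples: int, n_procs: int, rank: int) -> slice:
    """Global row range owned by `rank` under the same assumed scheme."""
    counts = block_counts(n_samples, n_procs)
    start = sum(counts[:rank])
    return slice(start, start + counts[rank])


if __name__ == "__main__":
    # The MNIST training set has 60,000 samples; assume 7 processes.
    print(block_counts(60_000, 7))    # [8572, 8572, 8572, 8571, 8571, 8571, 8571]
    print(block_slice(60_000, 7, 3))  # slice(25716, 34287, None)
```

Under this assumption the local lengths can differ by one sample, which is the situation the _cut_slice variable addresses.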
- Variables:
htdata (DNDarray) – full data
httargets (DNDarray) – full target data
comm (communication.MPICommunicator) – heat communicator for sending data between processes
_cut_slice (slice) – slice that removes the last element of the local data if the process-local chunks are not all equal in length
lcl_half (int) – half the number of samples on the process, as an integer
data (torch.Tensor) – the local data on a process
targets (torch.Tensor) – the local targets on a process
ishuffle (bool) – flag indicating if non-blocking communications are used for shuffling the data between epochs
test_set (bool) – if this dataset is the testing set then keep all of the data local
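The descriptions of _cut_slice and lcl_half can be illustrated with a small per-process sketch. It assumes the common length is n_samples // n_procs (the shortest block under an even block distribution), so at most the last sample of a longer local chunk is dropped. The function name trim_local is hypothetical.

```python
from __future__ import annotations


def trim_local(local: list, n_samples: int, n_procs: int) -> tuple[list, slice, int]:
    """Trim one process's local samples to the common length.

    Assumption: the common length is n_samples // n_procs, i.e. the size of
    the shortest block, so every process ends up with the same number of
    samples.
    """
    common = n_samples // n_procs
    cut_slice = slice(common)   # analogous to _cut_slice
    trimmed = local[cut_slice]  # drops the last sample if this block is longer
    lcl_half = common // 2      # analogous to lcl_half
    return trimmed, cut_slice, lcl_half


if __name__ == "__main__":
    # Rank 0 of 7 holds 8,572 of the 60,000 training samples.
    local = list(range(8572))
    trimmed, cut_slice, lcl_half = trim_local(local, 60_000, 7)
    print(len(trimmed), cut_slice, lcl_half)  # 8571 slice(None, 8571, None) 4285
```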
Notes
For other attributes see torchvision.datasets.MNIST.
- Shuffle()
Uses the datatools.dataset_shuffle() function to shuffle the data between the processes.
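This section does not describe how datatools.dataset_shuffle() moves samples between processes. As an illustration only, the sketch below simulates one plausible scheme on a single machine: every rank sends its first lcl_half samples to the next rank in a ring and shuffles the samples it receives into the ones it kept. The ring pattern and the function name ring_shuffle are assumptions, not the documented behavior of datatools.

```python
from __future__ import annotations

import random


def ring_shuffle(local_data: list[list[int]], lcl_half: int, seed: int = 0) -> list[list[int]]:
    """Simulate one assumed shuffle step for all ranks at once.

    Each rank sends its first lcl_half samples to rank + 1 (wrapping around)
    and shuffles the received samples into the ones it kept. A real
    implementation would apply the same permutation to the targets so that
    samples and labels stay aligned.
    """
    n_procs = len(local_data)
    rng = random.Random(seed)
    outgoing = [chunk[:lcl_half] for chunk in local_data]
    new_data = []
    for rank in range(n_procs):
        src = (rank - 1) % n_procs  # the rank that sends to this rank
        merged = outgoing[src] + local_data[rank][lcl_half:]
        rng.shuffle(merged)
        new_data.append(merged)
    return new_data


if __name__ == "__main__":
    data = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]  # three simulated ranks
    for rank, chunk in enumerate(ring_shuffle(data, lcl_half=2)):
        print(rank, sorted(chunk))
```

Every rank keeps the same number of samples and no sample is lost or duplicated, which is why equal local lengths matter for this kind of exchange.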
- Ishuffle()
Uses the datatools.dataset_ishuffle() function to shuffle the data between the processes using non-blocking communication.
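Ishuffle() differs from Shuffle() in using non-blocking communication. The general pattern is to start the exchange, keep working on the current data, and wait only when the new data is needed. The sketch below shows that pattern with a thread pool standing in for MPI non-blocking sends and receives. The exchange and train_epoch functions and the sample values are hypothetical, and nothing here reflects how datatools.dataset_ishuffle() is actually implemented.

```python
from __future__ import annotations

import random
import time
from concurrent.futures import ThreadPoolExecutor


def exchange(local: list[int], incoming: list[int], lcl_half: int, seed: int) -> list[int]:
    """Stand-in for the communication: keep the second half of the local
    samples, add the samples received from another process, and shuffle."""
    time.sleep(0.05)  # simulated communication latency
    merged = incoming + local[lcl_half:]
    random.Random(seed).shuffle(merged)
    return merged


def train_epoch(data: list[int]) -> int:
    """Placeholder for one epoch of work on the current local data."""
    return sum(data)


local = list(range(8))
incoming = [100, 101, 102, 103]  # hypothetical samples from the previous rank

with ThreadPoolExecutor(max_workers=1) as pool:
    # Start the exchange for the next epoch without blocking ...
    pending = pool.submit(exchange, local, incoming, 4, 0)
    # ... work on the current data while the exchange is in flight ...
    total = train_epoch(local)
    # ... and block only once the shuffled data is needed.
    local = pending.result()

print(total, sorted(local))  # 28 [4, 5, 6, 7, 100, 101, 102, 103]
```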