# Source code for heat.utils.data.mnist

"""
File for the MNIST dataset definition in heat
"""

import torch

from torchvision import datasets
from typing import Callable, Union

from ...core import factories
from . import datatools

__all__ = ["MNISTDataset"]


class MNISTDataset(datasets.MNIST):
    """
    Dataset wrapper for `torchvision.datasets.MNIST
    <https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.MNIST>`_.
    This implements all of the required functions mentioned in :class:`heat.utils.data.Dataset`.
    The ``__getitem__`` and ``__len__`` functions are inherited from `torchvision.datasets.MNIST
    <https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.MNIST>`_.

    Parameters
    ----------
    root : str
        Directory containing the MNIST dataset
    train : bool, optional
        If the data is the training dataset or not, default is True
    transform : Callable, optional
        Transform to be applied to the data in the ``__getitem__`` function, default is ``None``
    target_transform : Callable, optional
        Transform to be applied to the targets in the ``__getitem__`` function, default is ``None``
    download : bool, optional
        If the data does not exist in the directory, download it if True (default)
    split : int or None, optional
        Axis along which to split the data when it is loaded into a ``DNDarray``. Only ``0`` (default)
        and ``None`` are accepted. If ``test_set`` is True the data is never split, regardless of this value.
    ishuffle : bool, optional
        Flag indicating whether to use non-blocking communications for shuffling the data between epochs
        Note: if True, the ``Ishuffle()`` function must be defined within the class
        Default: ``False``
    test_set : bool, optional
        If this dataset is the testing set then keep all of the data local
        Default: ``False``

    Attributes
    ----------
    htdata : DNDarray
        full data
    httargets : DNDarray
        full target data
    comm : communication.MPICommunicator
        heat communicator for sending data between processes
    _cut_slice : slice or None
        slice that trims each process's local data to ``gshape[0] // comm.size`` rows, so that every
        process holds the same number of samples; ``None`` when the data is not split
    lcl_half : int
        integer value of half of the data on the process
    data : torch.Tensor
        the local data on a process
    targets : torch.Tensor
        the local targets on a process
    ishuffle : bool
        flag indicating if non-blocking communications are used for shuffling the data between epochs
    test_set : bool
        if this dataset is the testing set then keep all of the data local
    partial_dataset : bool
        always ``False`` for this class

    Raises
    ------
    ValueError
        If ``split`` is neither ``0`` nor ``None``. The check happens even when ``test_set`` is True.

    Notes
    -----
    For other attributes see `torchvision.datasets.MNIST
    <https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.MNIST>`_.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Union[Callable, None] = None,
        target_transform: Union[Callable, None] = None,
        download: bool = True,
        split: Union[int, None] = 0,
        ishuffle: bool = False,
        test_set: bool = False,
    ):  # noqa: D107
        # The parent constructor populates ``self.data`` / ``self.targets``, which are wrapped below.
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        if split != 0 and split is not None:
            raise ValueError("split must be 0 or None")
        # The test set keeps all of its data local, so it is never distributed.
        split = None if test_set else split

        array = factories.array(self.data, split=split)
        targets = factories.array(self.targets, split=split)

        self.test_set = test_set
        self.partial_dataset = False
        self.comm = array.comm
        self.htdata = array
        self.httargets = targets
        self.ishuffle = ishuffle

        if split is not None:
            # Trim every process's local chunk to the same length (gshape[0] // comm.size) so all
            # processes hold an equal number of samples. NOTE(review): this assumes each local chunk
            # has at least that many rows, i.e. a balanced distribution -- confirm.
            min_data_split = array.gshape[0] // array.comm.size
            self._cut_slice = slice(min_data_split)
            self.lcl_half = min_data_split // 2
            self.data = array._DNDarray__array[self._cut_slice]
            self.targets = targets._DNDarray__array[self._cut_slice]
        else:
            self._cut_slice = None
            self.lcl_half = array.gshape[0] // 2
            self.data = array._DNDarray__array
            self.targets = targets._DNDarray__array
        # getitem and len are defined by torch's MNIST class

    def Shuffle(self):
        """
        Uses the :func:`datatools.dataset_shuffle` function to shuffle the data between the processes.
        Does nothing when ``test_set`` is True.
        """
        if not self.test_set:
            # Each pair names a local-tensor attribute and the DNDarray attribute it was taken from.
            datatools.dataset_shuffle(
                dataset=self, attrs=[["data", "htdata"], ["targets", "httargets"]]
            )

    def Ishuffle(self):
        """
        Uses the :func:`datatools.dataset_ishuffle` function to shuffle the data between the processes.
        Does nothing when ``test_set`` is True.
        """
        if not self.test_set:
            # Each pair names a local-tensor attribute and the DNDarray attribute it was taken from.
            datatools.dataset_ishuffle(
                dataset=self, attrs=[["data", "htdata"], ["targets", "httargets"]]
            )