Source code for heat.sparse.dcsx_matrix

"""Provides DCSR_matrix, a distributed compressed sparse row matrix"""

from __future__ import annotations

import torch
from mpi4py import MPI
from typing import Union, Tuple, TypeVar

from ..core.devices import Device
from ..core.dndarray import DNDarray
from ..core.factories import array
from ..core.types import datatype, canonical_heat_type

__all__ = ["DCSR_matrix", "DCSC_matrix"]

Communication = TypeVar("Communication")


class __DCSX_matrix:
    """
    Distributed Compressed Sparse Matrix. Base class for DCSR_matrix and DCSC_matrix.

    Parameters
    ----------
    array : torch.Tensor (layout ==> torch.sparse_csr | torch.sparse_csc)
        Local sparse array
    gnnz: int
        Total number of non-zero elements across all processes
    gshape : Tuple[int,...]
        The global shape of the array
    dtype : datatype
        The datatype of the array
    split : int or None
        If split is not None, it denotes the axis on which the array is divided between processes.
        DCSR_matrix only supports distribution along axis 0.
    device : Device
        The device on which the local arrays are using (cpu or gpu)
    comm : Communication
        The communications object for sending and receiving data
    balanced: bool or None
        Describes whether the data are evenly distributed across processes.
    """

    def __init__(
        self,
        array: torch.Tensor,
        gnnz: int,
        gshape: Tuple[int, ...],
        dtype: datatype,
        split: Union[int, None],
        device: Device,
        comm: Communication,
        balanced: bool,
    ):
        self._array = array
        self._gnnz = gnnz
        self._gshape = gshape
        self._dtype = dtype
        self._split = split
        self._device = device
        self._comm = comm
        self._balanced = balanced

    def global_indptr(self) -> DNDarray:
        """
        Global indptr of the ``__DCSX_matrix`` as a ``DNDarray``
        """
        if self.split is None:
            raise ValueError("This method works only for distributed matrices")

        # Need to know the number of non-zero elements
        # in the processes with lesser rank
        all_nnz = torch.zeros(self.comm.size + 1, device=self.device.torch_device)

        # Each process must drop their nnz in index = rank + 1
        all_nnz[self.comm.rank + 1] = self.lnnz
        self.comm.Allreduce(MPI.IN_PLACE, all_nnz, MPI.SUM)

        # Build prefix array out of all the nnz
        all_nnz = torch.cumsum(all_nnz, dim=0)

        global_indptr = self.lindptr + int(all_nnz[self.comm.rank])

        # Remove the (n+1) the element from all the processes except last
        if self.comm.rank != self.comm.size - 1:
            global_indptr = global_indptr[:-1]

        # NOTE: indptr might be unbalanced in distribution but should not be self balanced
        return array(
            global_indptr,
            dtype=self.lindptr.dtype,
            device=self.device,
            comm=self.comm,
            is_split=(
                0 if self.split is not None else None
            ),  # is_split for the indptr is either 0 or None because it is always 1 dimensional
        )

    @property
    def balanced(self) -> bool:
        """
        Boolean value indicating if the __DCSX_matrix is balanced between the MPI processes
        """
        return self._balanced

    @property
    def comm(self) -> Communication:
        """
        The :class:`~heat.core.communication.Communication` of the ``__DCSX_matrix``
        """
        return self._comm

    @property
    def device(self) -> Device:
        """
        The :class:`~heat.core.devices.Device` of the ``__DCSX_matrix``
        """
        return self._device

    @property
    def larray(self) -> torch.Tensor:
        """
        Local data of the ``__DCSX_matrix``
        """
        return self._array

    @property
    def data(self) -> torch.Tensor:
        """
        Global data of the ``__DCSX_matrix``
        """
        if self.split is None:
            return self.ldata

        data_buffer = torch.zeros(
            size=(self.gnnz,), dtype=self.dtype.torch_type(), device=self.device.torch_device
        )
        counts, displs = self.counts_displs_nnz()
        self.comm.Allgatherv(self.ldata, (data_buffer, counts, displs))
        return data_buffer

    @property
    def gdata(self) -> torch.Tensor:
        """
        Global data of the ``__DCSX_matrix``
        """
        return self.data

    @property
    def ldata(self) -> torch.Tensor:
        """
        Local data of the ``__DCSX_matrix``
        """
        return self._array.values()

    @property
    def indptr(self) -> torch.Tensor:
        """
        Global indptr of the ``__DCSX_matrix``
        """
        if self.split is None:
            return self.lindptr

        return self.global_indptr().resplit(axis=None).larray

    @property
    def gindptr(self) -> torch.Tensor:
        """
        Global indptr of the ``__DCSX_matrix``
        """
        return self.indptr

    @property
    def lindptr(self) -> torch.Tensor:
        """
        Local indptr of the ``__DCSX_matrix``
        """
        raise NotImplementedError("Local indptr is not implemented for __DCSX_matrix")

    @property
    def indices(self) -> torch.Tensor:
        """
        Global indices of the ``__DCSX_matrix``
        """
        if self.split is None:
            return self.lindices

        indices_buffer = torch.zeros(
            size=(self.gnnz,), dtype=self.lindices.dtype, device=self.device.torch_device
        )
        counts, displs = self.counts_displs_nnz()
        self.comm.Allgatherv(self.lindices, (indices_buffer, counts, displs))
        return indices_buffer

    @property
    def gindices(self) -> torch.Tensor:
        """
        Global indices of the ``__DCSX_matrix``
        """
        return self.indices

    @property
    def lindices(self) -> torch.Tensor:
        """
        Local indices of the ``__DCSX_matrix``
        """
        raise NotImplementedError("Local indices is not implemented for __DCSX_matrix")

    @property
    def ndim(self) -> int:
        """
        Number of dimensions of the ``__DCSX_matrix``
        """
        return len(self._gshape)

    @property
    def nnz(self) -> int:
        """
        Total number of non-zero elements of the ``__DCSX_matrix``
        """
        return self._gnnz

    @property
    def gnnz(self) -> int:
        """
        Total number of non-zero elements of the ``__DCSX_matrix``
        """
        return self.nnz

    @property
    def lnnz(self) -> int:
        """
        Number of non-zero elements on the local process of the ``__DCSX_matrix``
        """
        return self._array._nnz()

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        Global shape of the ``__DCSX_matrix``
        """
        return self._gshape

    @property
    def gshape(self) -> Tuple[int, ...]:
        """
        Global shape of the ``__DCSX_matrix``
        """
        return self.shape

    @property
    def lshape(self) -> Tuple[int, ...]:
        """
        Local shape of the ``__DCSX_matrix``
        """
        return tuple(self._array.size())

    @property
    def dtype(self):
        """
        The :class:`~heat.core.types.datatype` of the ``__DCSX_matrix``
        """
        return self._dtype

    @property
    def split(self) -> int:
        """
        Returns the axis on which the ``__DCSX_matrix`` is split
        """
        return self._split

    def is_distributed(self) -> bool:
        """
        Determines whether the data of this ``__DCSX_matrix`` is distributed across multiple processes.
        """
        return self.split is not None and self.comm.is_distributed()

    def counts_displs_nnz(self) -> Tuple[Tuple[int], Tuple[int]]:
        """
        Returns actual counts (number of non-zero items per process) and displacements (offsets) of the __DCSX_matrix.
        Does not assume load balance.
        """
        if self.split is not None:
            counts = torch.zeros(self.comm.size)
            counts[self.comm.rank] = self.lnnz
            self.comm.Allreduce(MPI.IN_PLACE, counts, MPI.SUM)
            displs = [0] + torch.cumsum(counts, dim=0)[:-1].tolist()
            return tuple(counts.tolist()), tuple(displs)
        else:
            raise ValueError(
                f"Non-distributed {self.__class__.__name__}. Cannot calculate counts and displacements."
            )

    def astype(self, dtype, copy=True) -> __DCSX_matrix:
        """
        Returns a casted version of this matrix.
        Casted matrix is a new matrix of the same shape but with given type of this matrix. If copy is ``True``, the
        same matrix is returned instead.

        Parameters
        ----------
        dtype : datatype
            HeAT type to which the matrix is cast
        copy : bool, optional
            By default the operation returns a copy of this matrix. If copy is set to ``False`` the cast is performed
            in-place and this matrix is returned
        """
        dtype = canonical_heat_type(dtype)
        casted_matrix = self._array.to(dtype.torch_type(), copy=copy)
        if copy:
            return self.__class__(
                casted_matrix,
                self.gnnz,
                self.gshape,
                dtype,
                self.split,
                self.device,
                self.comm,
                self.balanced,
            )

        self._array = casted_matrix
        self._dtype = dtype

        return self

    def __repr__(self) -> str:
        """
        Computes a printable representation of the passed __DCSX_matrix.
        """
        print_string = (
            f"(indptr: {self.indptr}, indices: {self.indices}, data: {self.data}, "
            f"dtype=ht.{self.dtype.__name__}, device={self.device}, split={self.split})"
        )

        # Check has to happen after generating string because
        # generation of string invokes functions that require
        # participation from all processes
        if self.comm.rank != 0:
            return ""
        return print_string


[docs] class DCSR_matrix(__DCSX_matrix): """ Distributed Compressed Sparse Row Matrix. It is composed of PyTorch sparse_csr_tensors local to each process. Parameters ---------- array : torch.Tensor (layout ==> torch.sparse_csr) Local sparse array gnnz: int Total number of non-zero elements across all processes gshape : Tuple[int,...] The global shape of the array dtype : datatype The datatype of the array split : int or None If split is not None, it denotes the axis on which the array is divided between processes. DCSR_matrix only supports distribution along axis 0. device : Device The device on which the local arrays are using (cpu or gpu) comm : Communication The communications object for sending and receiving data balanced: bool or None Describes whether the data are evenly distributed across processes. """ @property def lindptr(self) -> torch.Tensor: """ Local indptr of the ``DCSR_matrix`` """ return self._array.crow_indices() @property def lindices(self) -> torch.Tensor: """ Local indices of the ``DCSR_matrix`` """ return self._array.col_indices()
[docs] class DCSC_matrix(__DCSX_matrix): """ Distributed Compressed Sparse Column Matrix. It is composed of PyTorch sparse_csc_tensors local to each process. Parameters ---------- array : torch.Tensor (layout ==> torch.sparse_csc) Local sparse array gnnz: int Total number of non-zero elements across all processes gshape : Tuple[int,...] The global shape of the array dtype : datatype The datatype of the array split : int or None If split is not None, it denotes the axis on which the array is divided between processes. DCSR_matrix only supports distribution along axis 0. device : Device The device on which the local arrays are using (cpu or gpu) comm : Communication The communications object for sending and receiving data balanced: bool or None Describes whether the data are evenly distributed across processes. """ @property def lindptr(self) -> torch.Tensor: """ Local indptr of the ``DCSC_matrix`` """ return self._array.ccol_indices() @property def lindices(self) -> torch.Tensor: """ Local indices of the ``DCSC_matrix`` """ return self._array.row_indices()