"""
Tiling functions/classes. With these classes, you can classes you can address blocks of data in a DNDarray
"""
from __future__ import annotations
import itertools
import torch
from mpi4py import MPI
from typing import List, Tuple, Union
from .dndarray import DNDarray
from .communication import MPICommunication
__all__ = ["SplitTiles", "SquareDiagTiles"]
[docs]
class SplitTiles:
"""
Initialize tiles with the tile divisions equal to the theoretical split dimensions in
every dimension
Parameters
----------
arr : DNDarray
Base array for which to create the tiles
Attributes
----------
__DNDarray : DNDarray
the ``DNDarray`` associated with the tiles
__lshape_map : torch.Tensor
map of the shapes of the local torch tensors of arr
__tile_locations : torch.Tensor
locations of the tiles of ``arr``
__tile_ends_g : torch.Tensor
the global indices of the ends of the tiles
__tile_dims : torch.Tensor
the dimensions of all of the tiles
Examples
--------
>>> a = ht.zeros(
... (
... 10,
... 11,
... ),
... split=None,
... )
>>> a.create_split_tiles()
>>> print(a.tiles.tile_ends_g)
[0/2] tensor([[ 4, 7, 10],
[0/2] [ 4, 8, 11]], dtype=torch.int32)
[1/2] tensor([[ 4, 7, 10],
[1/2] [ 4, 8, 11]], dtype=torch.int32)
[2/2] tensor([[ 4, 7, 10],
[2/2] [ 4, 8, 11]], dtype=torch.int32)
>>> print(a.tiles.tile_locations)
[0/2] tensor([[0, 0, 0],
[0/2] [0, 0, 0],
[0/2] [0, 0, 0]], dtype=torch.int32)
[1/2] tensor([[1, 1, 1],
[1/2] [1, 1, 1],
[1/2] [1, 1, 1]], dtype=torch.int32)
[2/2] tensor([[2, 2, 2],
[2/2] [2, 2, 2],
[2/2] [2, 2, 2]], dtype=torch.int32)
>>> a = ht.zeros((10, 11), split=1)
>>> a.create_split_tiles()
>>> print(a.tiles.tile_ends_g)
[0/2] tensor([[ 4, 7, 10],
[0/2] [ 4, 8, 11]], dtype=torch.int32)
[1/2] tensor([[ 4, 7, 10],
[1/2] [ 4, 8, 11]], dtype=torch.int32)
[2/2] tensor([[ 4, 7, 10],
[2/2] [ 4, 8, 11]], dtype=torch.int32)
>>> print(a.tiles.tile_locations)
[0/2] tensor([[0, 1, 2],
[0/2] [0, 1, 2],
[0/2] [0, 1, 2]], dtype=torch.int32)
[1/2] tensor([[0, 1, 2],
[1/2] [0, 1, 2],
[1/2] [0, 1, 2]], dtype=torch.int32)
[2/2] tensor([[0, 1, 2],
[2/2] [0, 1, 2],
[2/2] [0, 1, 2]], dtype=torch.int32)
"""
def __init__(self, arr: DNDarray) -> None: # noqa: D107
# 1. get the lshape map
# 2. get the split axis numbers for the other axes
# 3. build tile map
lshape_map = arr.create_lshape_map()
tile_dims = torch.zeros((arr.ndim, arr.comm.size), device=arr.device.torch_device)
if arr.split is not None:
tile_dims[arr.split] = lshape_map[..., arr.split]
w_size = arr.comm.size
for ax in range(arr.ndim):
if arr.split is None or ax != arr.split:
size = arr.gshape[ax]
chunk = size // w_size
remainder = size % w_size
tile_dims[ax] = chunk
tile_dims[ax][:remainder] += 1
tile_ends_g = torch.cumsum(tile_dims, dim=1).int()
# tile_ends_g is the global end points of the tiles in each dimension
# create a tensor for the process rank of all the tiles
tile_locations = self.set_tile_locations(split=arr.split, tile_dims=tile_dims, arr=arr)
self.__DNDarray = arr
self.__lshape_map = lshape_map
self.__tile_locations = tile_locations
self.__tile_ends_g = tile_ends_g
self.__tile_dims = tile_dims
[docs]
@staticmethod
def set_tile_locations(split: int, tile_dims: torch.Tensor, arr: DNDarray) -> torch.Tensor:
"""
Create a `torch.Tensor` which contains the locations of the tiles of ``arr`` for the given split
Parameters
----------
split : int
Target split dimension. Does not need to be equal to ``arr.split``
tile_dims : torch.Tensor
Tensor containing the sizes of the each tile
arr : DNDarray
Array for which the tiles are being created for
"""
# this is split off specifically for the resplit function
tile_locations = torch.zeros(
[tile_dims[x].numel() for x in range(arr.ndim)],
dtype=torch.int64,
device=arr.device.torch_device,
)
if split is None:
tile_locations += arr.comm.rank
return tile_locations
arb_slice = [slice(None)] * arr.ndim
for pr in range(1, arr.comm.size):
arb_slice[split] = pr
tile_locations[tuple(arb_slice)] = pr
return tile_locations
@property
def arr(self) -> DNDarray:
"""
Get the DNDarray associated with the tiling object
"""
return self.__DNDarray
@property
def lshape_map(self) -> torch.Tensor:
"""
Return the shape of all of the local torch.Tensors
"""
return self.__lshape_map
@property
def tile_locations(self) -> torch.Tensor:
"""
Get the ``torch.Tensor`` with the locations of the tiles for SplitTiles
Examples
--------
see :class:`~SplitTiles`
"""
return self.__tile_locations
@property
def tile_ends_g(self) -> torch.Tensor:
"""
Returns a ``torch.Tensor`` with the global indices with the end points of the tiles in every dimension
Examples
--------
see :func:`SplitTiles`
"""
return self.__tile_ends_g
@property
def tile_dimensions(self) -> torch.Tensor:
"""
Returns a ``torch.Tensor`` with the sizes of the tiles
"""
return self.__tile_dims
[docs]
def __getitem__(self, key: Union[int, slice, Tuple[Union[int, slice], ...]]) -> torch.Tensor:
"""
Getitem function for getting tiles. Returns the tile which is specified is returned, but only on the process which it resides
Parameters
----------
key : int or Tuple or Slice
Key which identifies the tile/s to get
Examples
--------
>>> test = torch.arange(np.prod([i + 6 for i in range(2)])).reshape(
... [i + 6 for i in range(2)]
... )
>>> a = ht.array(test, split=0).larray
[0/2] tensor([[ 0., 1., 2., 3., 4., 5., 6.],
[0/2] [ 7., 8., 9., 10., 11., 12., 13.]])
[1/2] tensor([[14., 15., 16., 17., 18., 19., 20.],
[1/2] [21., 22., 23., 24., 25., 26., 27.]])
[2/2] tensor([[28., 29., 30., 31., 32., 33., 34.],
[2/2] [35., 36., 37., 38., 39., 40., 41.]])
>>> a.create_split_tiles()
>>> a.tiles[:2, 2]
[0/2] tensor([[ 5., 6.],
[0/2] [12., 13.]])
[1/2] tensor([[19., 20.],
[1/2] [26., 27.]])
[2/2] None
>>> a = ht.array(test, split=1)
>>> a.create_split_tiles()
>>> a.tiles[1]
[0/2] tensor([[14., 15., 16.],
[0/2] [21., 22., 23.]])
[1/2] tensor([[17., 18.],
[1/2] [24., 25.]])
[2/2] tensor([[19., 20.],
[2/2] [26., 27.]])
"""
# todo: strides can be implemented with using a list of slices for each dimension
if not isinstance(key, (tuple, slice, int, torch.tensor)):
raise TypeError(f"key type not supported: {type(key)}")
arr = self.__DNDarray
# if arr.comm.rank not in self.tile_locations[key]:
# return None
# This filters out the processes which are not involved
# next need to get the local indices
# tile_ends_g has the end points, need to get the start and stop
if arr.comm.rank not in self.tile_locations[key]:
return None
arb_slices = self.__get_tile_slices(key)
return arr.larray[tuple(arb_slices)]
def __get_tile_slices(
self, key: Union[int, slice, Tuple[Union[int, slice], ...]]
) -> Tuple[slice, ...]:
"""
Create and return slices to convert a key from the tile indices to the normal indices
"""
arr = self.__DNDarray
arb_slices = [None] * arr.ndim
end_rank = (
max(self.tile_locations[key].unique())
if self.tile_locations[key].unique().numel() > 1
else self.tile_locations[key]
)
if isinstance(key, int):
key = [key]
if len(key) < arr.ndim or key[-1] is None:
lkey = list(key)
lkey.extend([slice(0, None)] * (arr.ndim - len(key)))
key = lkey
for d in range(arr.ndim):
# todo: implement advanced indexing (lists of positions to iterate through)
lkey = key
stop = self.tile_ends_g[d][lkey[d]].max().item()
stop = (
stop
if d != arr.split or stop is None
else self.lshape_map[end_rank][d].max().item()
)
if (
isinstance(lkey[d], slice)
and d != arr.split
and lkey[d].start != 0
and lkey[d].start is not None
):
# if the key is a slice in a dimension, and the start value of the slice is not 0,
# and d is not the split dimension (-> the tiles start at 0 on all tiles in the split dim)
start = self.tile_ends_g[d][lkey[d].start - 1].item()
elif isinstance(lkey[d], int) and lkey[d] > 0 and d != arr.split:
start = self.tile_ends_g[d][lkey[d] - 1].item()
elif (
isinstance(lkey[d], torch.Tensor)
and lkey[d].numel() == 1
and lkey[d] > 0
and d != arr.split
):
start = self.tile_ends_g[d][lkey[d] - 1].item()
else:
start = 0
arb_slices[d] = slice(start, stop)
return arb_slices
[docs]
def get_tile_size(
self, key: Union[int, slice, Tuple[Union[int, slice], ...]]
) -> Tuple[int, ...]:
"""
Get the size of a tile or tiles indicated by the given key
Parameters
----------
key : int or slice or tuple
which tiles to get
"""
arb_slices = self.__get_tile_slices(key)
inds = [sl.stop - sl.start for sl in arb_slices]
return tuple(inds)
[docs]
def __setitem__(
self,
key: Union[int, slice, Tuple[Union[int, slice], ...]],
value: Union[int, float, torch.Tensor],
) -> None:
"""
Set the values of a tile
Parameters
----------
key : int or Tuple or Slice
Key which identifies the tile/s to get
value : int or torch.Tensor
Value to be set on the tile
Examples
--------
see getitem function for this class
"""
if not isinstance(key, (tuple, slice, int, torch.Tensor)):
raise TypeError(f"key type not supported: {type(key)}")
if not isinstance(value, (torch.Tensor, int, float)):
raise TypeError(f"value type not supported: {type(value)}")
# todo: is it okay for cross-split setting? this can be problematic,
# but it is fine if the data shapes match up
if self.__DNDarray.comm.rank not in self.tile_locations[key]:
return None
# this will set the tile values using the torch setitem function
arr = self.__getitem__(key)
arr.__setitem__(slice(0, None), value)
[docs]
def get_subarray_params(
self, from_axis: int, to_axis: int
) -> List[Tuple[List[int], List[int], List[int]]]:
"""Create subarray types of the local array along a new split axis. For use with Alltoallw.
Based on the work by Dalcin et al. (https://arxiv.org/abs/1804.09536)
Return type is a list of tuples, each tuple containing the shape of the local array, the shape of the subarray, and the start index of the subarray.
Parameters
----------
from_axis : int
Current split axis of global array.
to_axis : int
New split axis of of subarrays array.
"""
arr = self.__DNDarray
world_size = arr.comm.Get_size()
gshape = arr.gshape
from_shape = list(gshape)
from_shape[from_axis] = int(self.tile_dimensions[from_axis][arr.comm.rank].item())
subsizes = from_shape
substarts = [0] * len(from_shape)
tile_dimensions = self.tile_dimensions[to_axis].to(torch.int64).tolist()
tile_starts = [0] + self.tile_ends_g[to_axis][:-1].to(torch.int64).tolist()
subarray_param_list = []
lshape = from_shape.copy()
for i in range(world_size):
chunk_size = tile_dimensions[i]
chunk_start = tile_starts[i]
subsizes[to_axis] = chunk_size
substarts[to_axis] = chunk_start
subarray_param_list.append((lshape, subsizes.copy(), substarts.copy()))
return subarray_param_list
[docs]
class SquareDiagTiles:
"""
Generate the tile map and the other objects which may be useful.
The tiles generated here are based of square tiles along the diagonal. The size of these
tiles along the diagonal dictate the divisions across all processes. If
``gshape[0]>>gshape[1]`` then there will be extra tiles generated below the diagonal.
If ``gshape[0]`` is close to ``gshape[1]``, then the last tile (as well as the other tiles which
correspond with said tile) will be extended to cover the whole array. However, extra tiles
are not generated above the diagonal in the case that ``gshape[0]<<gshape[1]``.
Parameters
----------
arr : DNDarray
The array to be tiled
tiles_per_proc : int, optional
The number of divisions per process
Default: 2
Attributes
----------
__col_per_proc_list : List
List is length of the number of processes, each element has the number of tile
columns on the process whos rank equals the index
__DNDarray: DNDarray
The whole DNDarray
__lshape_map : torch.Tensor
``unit -> [rank, row size, column size]``
Tensor filled with the shapes of the local tensors
__tile_map : torch.Tensor
``units -> row, column, start index in each direction, process``
Tensor filled with the global indices of the generated tiles
__row_per_proc_list : List
List is length of the number of processes, each element has the number of tile
rows on the process whos rank equals the index
Warnings
--------
The generation of these tiles may unbalance the original ``DNDarray``!
Notes
-----
This tiling scheme is intended for use with the :func:`~heat.core.linalg.qr.qr` function.
"""
def __init__(self, arr: DNDarray, tiles_per_proc: int = 2) -> None: # noqa: D107
# lshape_map -> rank (int), lshape (tuple of the local lshape, self.lshape)
if not isinstance(arr, DNDarray):
raise TypeError(f"arr must be a DNDarray, is currently a {type(self)}")
if not isinstance(tiles_per_proc, int):
raise TypeError(f"tiles_per_proc must be an int, is currently a {type(self)}")
if tiles_per_proc < 1:
raise ValueError(f"Tiles per process must be >= 1, currently: {tiles_per_proc}")
if len(arr.shape) != 2:
raise ValueError(f"Arr must be 2 dimensional, current shape {arr.shape}")
lshape_map = arr.create_lshape_map(force_check=True)
# if there is only one element of the diagonal on the next process
d = 1 if tiles_per_proc <= 2 else tiles_per_proc - 1
redist = torch.where(
torch.cumsum(lshape_map[..., arr.split], dim=0) >= arr.gshape[arr.split - 1] - d
)[0]
if redist.numel() > 0 and arr.gshape[0] > arr.gshape[1] and redist[0] != arr.comm.size - 1:
target_map = lshape_map.clone()
target_map[redist[0]] += d
target_map[redist[0] + 1] -= d
arr.redistribute_(lshape_map=lshape_map, target_map=target_map)
row_per_proc_list = [tiles_per_proc] * arr.comm.size
last_diag_pr, col_per_proc_list, col_inds, tile_columns = self.__create_cols(
arr, lshape_map, tiles_per_proc
)
# need to adjust the lshape if the splits overlap
if arr.split == 0 and tiles_per_proc == 1:
# if the split is 0 and the number of tiles per proc is 1
# then the local data needs to be redistributed to fit the full diagonal on as many
# processes as possible
# if any(lshape_map[..., arr.split] == 1):
(
last_diag_pr,
col_per_proc_list,
col_inds,
tile_columns,
) = self.__adjust_lshape_sp0_1tile(arr, col_inds, lshape_map, tiles_per_proc)
# re-test for empty processes and remove empty rows
empties = torch.where(lshape_map[..., 0] == 0)[0]
if empties.numel() > 0:
# need to remove the entry in the rows per process
for e in empties:
row_per_proc_list[e] = 0
row_inds = list(col_inds)
# set the row indices to be the same for all of the column indices
# (however many there are)
if arr.split == 0 and arr.gshape[0] < arr.gshape[1]:
# need to adjust the very last tile to be the remaining
col_inds[-1] = arr.gshape[1] - sum(col_inds[:-1])
# if there is too little data on the last tile then combine them
if arr.split == 0 and last_diag_pr < arr.comm.size - 1:
# these conditions imply that arr.gshape[0] > arr.gshape[1] (assuming balanced)
self.__adjust_last_row_sp0_m_ge_n(
arr, lshape_map, last_diag_pr, row_inds, row_per_proc_list, tile_columns
)
if arr.split == 0 and arr.gshape[0] > arr.gshape[1]:
# adjust the last row to have the
self.__def_end_row_inds_sp0_m_ge_n(
arr, row_inds, last_diag_pr, tiles_per_proc, lshape_map
)
if arr.split == 1 and arr.gshape[0] < arr.gshape[1]:
self.__adjust_cols_sp1_m_ls_n(
arr, col_per_proc_list, last_diag_pr, col_inds, lshape_map
)
if arr.split == 1 and arr.gshape[0] > arr.gshape[1]:
# add extra rows if there is place below the diagonal for split == 1
# adjust the very last tile to be the remaining
self.__last_tile_row_adjust_sp1(arr, row_inds)
# need to remove blank rows for arr.gshape[0] < arr.gshape[1]
if arr.gshape[0] < arr.gshape[1]:
row_inds_hold = []
for i in torch.nonzero(
input=torch.tensor(row_inds, device=arr.larray.device), as_tuple=False
).flatten():
row_inds_hold.append(row_inds[i.item()])
row_inds = row_inds_hold
tile_map = torch.zeros(
[len(row_inds), len(col_inds), 3], dtype=torch.int, device=arr.larray.device
)
# if arr.split == 0: # adjust the 1st dim to be the cumsum
col_inds = [0] + col_inds[:-1]
col_inds = torch.tensor(col_inds, device=arr.larray.device).cumsum(dim=0)
# if arr.split == 1: # adjust the 0th dim to be the cumsum
row_inds = [0] + row_inds[:-1]
row_inds = torch.tensor(row_inds, device=arr.larray.device).cumsum(dim=0)
for num, c in enumerate(col_inds): # set columns
tile_map[:, num, 1] = c
for num, r in enumerate(row_inds): # set rows
tile_map[num, :, 0] = r
for i in range(arr.comm.size):
st = sum(row_per_proc_list[:i])
sp = st + row_per_proc_list[i]
tile_map[..., 2][st:sp] = i
# to adjust if the last process has more tiles
i = arr.comm.size - 1
tile_map[..., 2][sum(row_per_proc_list[:i]) :] = i
if arr.split == 1:
st = 0
for pr, cols in enumerate(col_per_proc_list):
tile_map[:, st : st + cols, 2] = pr
st += cols
for c, i in enumerate(row_per_proc_list):
try:
row_per_proc_list[c] = i.item()
except AttributeError:
pass
for c, i in enumerate(col_per_proc_list):
try:
col_per_proc_list[c] = i.item()
except AttributeError:
pass
self.__DNDarray = arr
self.__col_per_proc_list = (
col_per_proc_list if arr.split == 1 else [len(col_inds)] * len(col_per_proc_list)
)
self.__lshape_map = lshape_map
self.__last_diag_pr = last_diag_pr.item()
self.__row_per_proc_list = (
row_per_proc_list if arr.split == 0 else [len(row_inds)] * len(row_per_proc_list)
)
self.__tile_map = tile_map
self.__row_inds = list(row_inds)
self.__col_inds = list(col_inds)
arr.__lshape_map = None
@staticmethod
def __adjust_cols_sp1_m_ls_n(
arr: DNDarray,
col_per_proc_list: List[int, ...],
last_diag_pr: int,
col_inds: List[int, ...],
lshape_map: torch.Tensor,
) -> None:
"""
Add more columns after the diagonal ends if ``m<n`` and ``arr.split==1``
"""
# need to add to col inds with the rest of the columns
tile_columns = sum(col_per_proc_list)
r = last_diag_pr + 1
for _ in range(len(col_inds), tile_columns):
col_inds.append(lshape_map[r, 1])
r += 1
# if the 1st dim is > 0th dim then in split=1 the cols need to be extended
col_proc_ind = torch.cumsum(
torch.tensor(col_per_proc_list, device=arr.larray.device), dim=0
)
for pr in range(arr.comm.size):
lshape_cumsum = torch.cumsum(lshape_map[..., 1], dim=0)
col_cumsum = torch.cumsum(torch.tensor(col_inds, device=arr.larray.device), dim=0)
diff = lshape_cumsum[pr] - col_cumsum[col_proc_ind[pr] - 1]
if diff > 0 and pr <= last_diag_pr:
col_per_proc_list[pr] += 1
col_inds.insert(col_proc_ind[pr], diff)
if pr > last_diag_pr and diff > 0:
col_inds.insert(col_proc_ind[pr], diff)
@staticmethod
def __adjust_last_row_sp0_m_ge_n(
arr: DNDarray,
lshape_map: torch.Tensor,
last_diag_pr: int,
row_inds: List[int, ...],
row_per_proc_list: List[int, ...],
tile_columns: int,
) -> None:
"""
Need to adjust the size of last row if ``arr.split==0`` and the diagonal ends before the
last tile. This should only be run if ``arr,split==0`` and ``last_diag_pr<arr.comm.size-1``.
"""
# need to find the amount of data after the diagonal
lshape_cumsum = torch.cumsum(lshape_map[..., 0], dim=0)
diff = lshape_cumsum[last_diag_pr] - arr.gshape[1]
if diff > torch.true_divide(lshape_map[last_diag_pr, 0], 2): # todo: tune this?
# if the shape diff is > half the data on the process
# then add a row after the diagonal, todo: is multiple rows faster?
row_inds.insert(tile_columns, diff)
row_per_proc_list[last_diag_pr] += 1
else:
# if the diff is < half the data on the process
# then extend the last row inds to be the end of the process
row_inds[tile_columns - 1] += diff
@staticmethod
def __adjust_lshape_sp0_1tile(
arr: DNDarray, col_inds: List[int, ...], lshape_map: torch.Tensor, tiles_per_proc: int
) -> None:
"""
If the split is 0 and the number of tiles per proc is 1 then the local data may need to be
redistributed to fit the full diagonal on as many processes as possible. If there is a
process where there is only 1 element, this function will adjust the ``lshape_map`` then
redistribute ``arr`` so that there is not a single diagonal element on one process
"""
def adjust_lshape(lshape_mapi, pri, cnti):
if lshape_mapi[..., 0][pri] < cnti:
h = cnti - lshape_mapi[..., 0][pri]
lshape_mapi[..., 0][pri] += h
lshape_mapi[..., 0][pri + 1] -= h
for cnt in col_inds[:-1]: # only need to loop until the second to last one
for pr in range(arr.comm.size - 1):
adjust_lshape(lshape_map, pr, cnt)
negs = torch.where(lshape_map[..., 0] < 0)[0]
if negs.numel() > 0:
for n in negs:
lshape_map[n - 1, 0] += lshape_map[n, 0]
lshape_map[n, 0] = 0
arr.redistribute_(target_map=lshape_map)
last_diag_pr, col_per_proc_list, col_inds, tile_columns = SquareDiagTiles.__create_cols(
arr, lshape_map, tiles_per_proc
)
return last_diag_pr, col_per_proc_list, col_inds, tile_columns
@staticmethod
def __create_cols(
arr: DNDarray, lshape_map: torch.Tensor, tiles_per_proc: int
) -> Tuple[torch.Tensor, List[int, ...], List[int, ...], torch.Tensor]:
"""
Calculates the last diagonal process, then creates a list of the number of tile columns per
process, then calculates the starting indices of the columns. Also returns the number of tile
columns.
Parameters
----------
arr : DNDarray
DNDarray for which to find the tile columns for
lshape_map : torch.Tensor
The map of the local shapes (for more info see: :func:`~heat.core.dndarray.DNDarray.create_lshape_map`)
tiles_per_proc : int
The number of divisions per process
"""
last_tile_cols = tiles_per_proc
last_dia_pr = torch.where(lshape_map[..., arr.split].cumsum(dim=0) >= min(arr.gshape))[0][0]
# adjust for small blocks on the last diag pr:
last_pr_minus1 = last_dia_pr - 1 if last_dia_pr > 0 else 0
rem_cols_last_pr = abs(
min(arr.gshape) - lshape_map[..., arr.split].cumsum(dim=0)[last_pr_minus1]
)
# this is the number of rows/columns after the last diagonal on the last diagonal pr
try:
num_after_diag = torch.div(rem_cols_last_pr, last_tile_cols, rounding_mode="floor")
except TypeError:
num_after_diag = torch.floor_divide(rem_cols_last_pr, last_tile_cols)
while 1 < num_after_diag < 2:
# todo: determine best value for this (prev at 2)
# if there cannot be tiles formed which are at list ten items larger than 2
# then need to reduce the number of tiles
last_tile_cols -= 1
if last_tile_cols == 1:
break
# create lists of columns and rows for each process
col_per_proc_list = [tiles_per_proc] * (last_dia_pr.item() + 1)
col_per_proc_list[-1] = last_tile_cols
if last_dia_pr < arr.comm.size - 1 and arr.split == 1:
# this is the case that the gshape[1] >> gshape[0]
col_per_proc_list.extend([1] * (arr.comm.size - last_dia_pr - 1).item())
# need to determine the proper number of tile rows/columns
tile_columns = tiles_per_proc * last_dia_pr + last_tile_cols
diag_crossings = lshape_map[..., arr.split].cumsum(dim=0)[: last_dia_pr + 1]
diag_crossings[-1] = (
diag_crossings[-1] if diag_crossings[-1] <= min(arr.gshape) else min(arr.gshape)
)
dev = arr.larray.device
diag_crossings = torch.cat((torch.tensor([0], device=dev), diag_crossings), dim=0).tolist()
# create the tile columns sizes, saved to list
col_inds = []
for col in range(tile_columns.item()):
try:
off = torch.div(col, tiles_per_proc, rounding_mode="floor").to(dev)
except TypeError:
off = torch.floor_divide(col, tiles_per_proc).to(dev)
_, lshape, _ = arr.comm.chunk(
[diag_crossings[off + 1] - diag_crossings[off]],
0,
rank=int(col % tiles_per_proc),
w_size=tiles_per_proc if off != last_dia_pr else last_tile_cols,
)
col_inds.append(lshape[0])
return last_dia_pr, col_per_proc_list, col_inds, tile_columns
@staticmethod
def __def_end_row_inds_sp0_m_ge_n(
arr: DNDarray,
row_inds: List[int, ...],
last_diag_pr: int,
tiles_per_proc: int,
lshape_map: torch.Tensor,
) -> None:
"""
Adjust the rows on the processes which are greater than the last diagonal processs to have
rows which are chunked evenly into ``tiles_per_proc`` rows.
"""
nz = torch.nonzero(
input=torch.tensor(row_inds, device=arr.larray.device) == 0, as_tuple=False
)
lp_map = lshape_map.tolist()
for i, t in itertools.product(
range(last_diag_pr.item() + 1, arr.comm.size), range(tiles_per_proc)
):
_, lshape, _ = arr.comm.chunk(lp_map[i], 0, rank=t, w_size=tiles_per_proc)
# row_inds[nz[0].item()] = lshape[0]
if row_inds[-1] == 0:
row_inds[-1] = lshape[0]
else:
row_inds.append(lshape[0])
nz = nz[1:]
@staticmethod
def __last_tile_row_adjust_sp1(arr: DNDarray, row_inds: List[int, ...]) -> None:
"""
Add extra row/s if there is space below the diagonal (``split=1``)
"""
if arr.gshape[0] - arr.gshape[1] > 10: # todo: determine best value for this
# use chunk and a loop over the however many tiles are desired
num_ex_row_tiles = 1 # todo: determine best value for this
while (arr.gshape[0] - arr.gshape[1]) // num_ex_row_tiles < 2:
num_ex_row_tiles -= 1
for i in range(num_ex_row_tiles):
_, lshape, _ = arr.comm.chunk(
(arr.gshape[0] - arr.gshape[1],), 0, rank=i, w_size=num_ex_row_tiles
)
row_inds.append(lshape[0])
else:
# if there is no place for multiple tiles, combine the remainder with the last row
row_inds[-1] = arr.gshape[0] - sum(row_inds[:-1])
@property
def arr(self) -> DNDarray:
"""
Returns the ``DNDarray`` for which the tiles are defined on
"""
return self.__DNDarray
@property
def col_indices(self) -> List[int, ...]:
"""
Returns a list containing the indices of the tile columns
"""
return self.__col_inds
@property
def lshape_map(self) -> torch.Tensor:
"""
Returns the map of the lshape tuples for the ``DNDarray`` given.
Units are ``(rank, lshape)`` (tuple of the local shape)
"""
return self.__lshape_map
@property
def last_diagonal_process(self) -> int:
"""
Returns the rank of the last process with diagonal elements
"""
return self.__last_diag_pr
@property
def row_indices(self) -> List[int, ...]:
"""
Returns a list containing the indices of the tile rows
"""
return self.__row_inds
@property
def tile_columns(self) -> int:
"""
Returns the number of tile columns
"""
return len(self.__col_inds)
@property
def tile_columns_per_process(self) -> List[int, ...]:
"""
Returns a list containing the number of columns on all processes
"""
return self.__col_per_proc_list
@property
def tile_map(self) -> torch.Tensor:
"""
Returns tile_map which contains the sizes of the tiles
units are ``(row, column, start index in each direction, process)``
Examples
--------
>>> a = ht.zeros((12, 10), split=0)
>>> a_tiles = tiling.SquareDiagTiles(a, tiles_per_proc=2)
>>> print(a_tiles.tile_map)
[(0 & 1)/1] tensor([[[0, 0, 0],
[(0 & 1)/1] [0, 3, 0],
[(0 & 1)/1] [0, 6, 0],
[(0 & 1)/1] [0, 8, 0]],
[(0 & 1)/1]
[(0 & 1)/1] [[3, 0, 0],
[(0 & 1)/1] [3, 3, 0],
[(0 & 1)/1] [3, 6, 0],
[(0 & 1)/1] [3, 8, 0]],
[(0 & 1)/1]
[(0 & 1)/1] [[6, 0, 1],
[(0 & 1)/1] [6, 3, 1],
[(0 & 1)/1] [6, 6, 1],
[(0 & 1)/1] [6, 8, 1]],
[(0 & 1)/1]
[(0 & 1)/1] [[8, 0, 1],
[(0 & 1)/1] [8, 3, 1],
[(0 & 1)/1] [8, 6, 1],
[(0 & 1)/1] [8, 8, 1]]], dtype=torch.int32)
>>> print(a_tiles.tile_map.shape)
[0/1] torch.Size([4, 4, 3])
[1/1] torch.Size([4, 4, 3])
"""
return self.__tile_map
@property
def tile_rows(self) -> int:
"""
Returns the number of tile rows
"""
return len(self.__row_inds)
@property
def tile_rows_per_process(self) -> List[int, ...]:
"""
Returns a list containing the number of rows on all processes
"""
return self.__row_per_proc_list
[docs]
def get_start_stop(
self, key: Union[int, slice, Tuple[int, slice, ...]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns the start and stop indices in form of ``(dim0 start, dim0 stop, dim1 start, dim1 stop)``
which correspond to the tile/s which corresponds to the given key. The key MUST use global indices.
Parameters
----------
key : int or Tuple or List or slice
Indices to select the tile
STRIDES ARE NOT ALLOWED, MUST BE GLOBAL INDICES
Examples
--------
>>> a = ht.zeros((12, 10), split=0)
>>> a_tiles = ht.tiling.SquareDiagTiles(a, tiles_per_proc=2) # type: tiling.SquareDiagTiles
>>> print(a_tiles.get_start_stop(key=(slice(0, 2), 2)))
[0/1] (tensor(0), tensor(6), tensor(6), tensor(8))
[1/1] (tensor(0), tensor(6), tensor(6), tensor(8))
>>> print(a_tiles.get_start_stop(key=(0, 2)))
[0/1] (tensor(0), tensor(3), tensor(6), tensor(8))
[1/1] (tensor(0), tensor(3), tensor(6), tensor(8))
>>> print(a_tiles.get_start_stop(key=2))
[0/1] (tensor(0), tensor(2), tensor(0), tensor(10))
[1/1] (tensor(0), tensor(2), tensor(0), tensor(10))
>>> print(a_tiles.get_start_stop(key=(3, 3)))
[0/1] (tensor(2), tensor(6), tensor(8), tensor(10))
[1/1] (tensor(2), tensor(6), tensor(8), tensor(10))
"""
split = self.__DNDarray.split
pr = self.tile_map[key][..., 2].unique()
if pr.numel() > 1:
raise ValueError(f"Tile/s must be located on one process. currently on: {pr}")
row_inds = self.row_indices + [self.__DNDarray.gshape[0]]
col_inds = self.col_indices + [self.__DNDarray.gshape[1]]
row_start = row_inds[sum(self.tile_rows_per_process[:pr]) if split == 0 else 0]
col_start = col_inds[sum(self.tile_columns_per_process[:pr]) if split == 1 else 0]
if isinstance(key, int):
key = [key]
else:
key = list(key)
if len(key) == 1:
key.append(slice(0, None))
key = list(key)
if isinstance(key[0], int):
st0 = row_inds[key[0]] - row_start
sp0 = row_inds[key[0] + 1] - row_start
elif isinstance(key[0], slice):
start = row_inds[key[0].start] if key[0].start is not None else 0
stop = row_inds[key[0].stop] if key[0].stop is not None else row_inds[-1]
st0, sp0 = start - row_start, stop - row_start
if isinstance(key[1], int):
st1 = col_inds[key[1]] - col_start
sp1 = col_inds[key[1] + 1] - col_start
elif isinstance(key[1], slice):
start = col_inds[key[1].start] if key[1].start is not None else 0
stop = col_inds[key[1].stop] if key[1].stop is not None else col_inds[-1]
st1, sp1 = start - col_start, stop - col_start
return st0, sp0, st1, sp1
[docs]
def __getitem__(self, key: Union[int, slice, Tuple[int, slice, ...]]) -> torch.Tensor:
"""
Returns a local selection of the DNDarray corresponding to the tile/s desired
Standard getitem function for the tiles. The returned item is a view of the original
DNDarray, operations which are done to this view will change the original array.
**STRIDES ARE NOT AVAILABLE, NOR ARE CROSS-SPLIT SLICES**
Parameters
----------
key : int, slice, tuple
indices of the tile/s desired
Examples
--------
>>> a = ht.zeros((12, 10), split=0)
>>> a_tiles = tiling.SquareDiagTiles(a, tiles_per_proc=2) # type: tiling.SquareDiagTiles
>>> print(a_tiles[2, 3])
[0/1] None
[1/1] tensor([[0., 0.],
[1/1] [0., 0.]])
>>> print(a_tiles[2])
[0/1] None
[1/1] tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1/1] [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
>>> print(a_tiles[0:2, 1])
[0/1] tensor([[0., 0., 0.],
[0/1] [0., 0., 0.],
[0/1] [0., 0., 0.],
[0/1] [0., 0., 0.],
[0/1] [0., 0., 0.],
[0/1] [0., 0., 0.]])
[1/1] None
"""
arr = self.__DNDarray
tile_map = self.__tile_map
local_arr = arr.larray
if not isinstance(key, (int, tuple, slice)):
raise TypeError(f"key must be an int, tuple, or slice, is currently {type(key)}")
involved_procs = tile_map[key][..., 2].unique()
if involved_procs.nelement() == 1 and involved_procs == arr.comm.rank:
st0, sp0, st1, sp1 = self.get_start_stop(key=key)
return local_arr[st0:sp0, st1:sp1]
elif involved_procs.nelement() > 1:
raise ValueError("Slicing across splits is not allowed")
else:
return None
[docs]
def local_get(self, key: Union[int, slice, Tuple[int, slice, ...]]) -> torch.Tensor:
"""
Returns the local tile/s corresponding to the key given
Getitem routing using local indices, converts to global indices then uses getitem
Parameters
----------
key : int, slice, tuple, list
Indices of the tile/s desired.
If the stop index of a slice is larger than the end will be adjusted to the maximum
allowed
Examples
--------
See local_set function.
"""
rank = self.__DNDarray.comm.rank
key = self.local_to_global(key=key, rank=rank)
return self.__getitem__(key)
[docs]
def local_set(
self, key: Union[int, slice, Tuple[int, slice, ...]], value: Union[int, float, torch.Tensor]
):
"""
Setitem routing to set data to a local tile (using local indices)
Parameters
----------
key : int or slice or Tuple[int,...]
Indices of the tile/s desired
If the stop index of a slice is larger than the end will be adjusted to the maximum
allowed
value : torch.Tensor or int or float
Data to be written to the tile
Examples
--------
>>> a = ht.zeros((11, 10), split=0)
>>> a_tiles = tiling.SquareDiagTiles(a, tiles_per_proc=2) # type: tiling.SquareDiagTiles
>>> local = a_tiles.local_get(key=slice(None))
>>> a_tiles.local_set(
... key=slice(None), value=torch.arange(local.numel()).reshape(local.shape)
... )
>>> print(a.larray)
[0/1] tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
[0/1] [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
[0/1] [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
[0/1] [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],
[0/1] [40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],
[0/1] [50., 51., 52., 53., 54., 55., 56., 57., 58., 59.]])
[1/1] tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
[1/1] [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
[1/1] [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
[1/1] [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],
[1/1] [40., 41., 42., 43., 44., 45., 46., 47., 48., 49.]])
>>> a.lloc[:] = 0
>>> a_tiles.local_set(key=(0, 2), value=10)
[0/1] tensor([[ 0., 0., 0., 0., 0., 0., 10., 10., 0., 0.],
[0/1] [ 0., 0., 0., 0., 0., 0., 10., 10., 0., 0.],
[0/1] [ 0., 0., 0., 0., 0., 0., 10., 10., 0., 0.],
[0/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
[1/1] tensor([[ 0., 0., 0., 0., 0., 0., 10., 10., 0., 0.],
[1/1] [ 0., 0., 0., 0., 0., 0., 10., 10., 0., 0.],
[1/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
>>> a_tiles.local_set(key=(slice(None), 1), value=10)
[0/1] tensor([[ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[0/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[0/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[0/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[0/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[0/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.]])
[1/1] tensor([[ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[1/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[1/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[1/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.],
[1/1] [ 0., 0., 0., 10., 10., 10., 0., 0., 0., 0.]])
"""
rank = self.__DNDarray.comm.rank
key = self.local_to_global(key=key, rank=rank)
self.__getitem__(tuple(key)).__setitem__(slice(0, None), value)
[docs]
def local_to_global(
self, key: Union[int, slice, Tuple[int, slice, ...]], rank: int
) -> Tuple[int, slice, ...]:
"""
Convert local indices to global indices
Parameters
----------
key : int or slice or Tuple or List
Indices of the tile/s desired.
If the stop index of a slice is larger than the end will be adjusted to the maximum
allowed
rank : int
Process rank
Examples
--------
>>> a = ht.zeros((11, 10), split=0)
>>> a_tiles = tiling.SquareDiagTiles(a, tiles_per_proc=2) # type: tiling.SquareDiagTiles
>>> rank = a.comm.rank
>>> print(a_tiles.local_to_global(key=(slice(None), 1), rank=rank))
[0/1] (slice(0, 2, None), 1)
[1/1] (slice(2, 4, None), 1)
>>> print(a_tiles.local_to_global(key=(0, 2), rank=0))
[0/1] (0, 2)
[1/1] (0, 2)
>>> print(a_tiles.local_to_global(key=(0, 2), rank=1))
[0/1] (2, 2)
[1/1] (2, 2)
"""
arr = self.__DNDarray
if isinstance(key, (int, slice)):
key = [key, slice(0, None)]
else:
key = list(key)
if arr.split == 0:
# need to adjust key[0] to be only on the local tensor
prev_rows = sum(self.__row_per_proc_list[:rank])
loc_rows = self.__row_per_proc_list[rank]
if isinstance(key[0], int):
key[0] += prev_rows
elif isinstance(key[0], slice):
start = key[0].start + prev_rows if key[0].start is not None else prev_rows
stop = key[0].stop + prev_rows if key[0].stop is not None else prev_rows + loc_rows
stop = stop if stop - start < loc_rows else start + loc_rows
key[0] = slice(start, stop)
if arr.split == 1:
loc_cols = self.__col_per_proc_list[rank]
prev_cols = sum(self.__col_per_proc_list[:rank])
# need to adjust key[0] to be only on the local tensor
# need the number of columns *before* the process
if isinstance(key[1], int):
key[1] += prev_cols
elif isinstance(key[1], slice):
start = key[1].start + prev_cols if key[1].start is not None else prev_cols
stop = key[1].stop + prev_cols if key[1].stop is not None else prev_cols + loc_cols
stop = stop if stop - start < loc_cols else start + loc_cols
key[1] = slice(start, stop)
return tuple(key)
[docs]
def match_tiles(self, tiles_to_match: SquareDiagTiles) -> None:
"""
Function to match the tile sizes of another tile map
Parameters
----------
tiles_to_match : SquareDiagTiles
The tiles which should be matched by the current tiling scheme
Notes
-----
This function overwrites most, if not all, of the elements of this class. Intended for use with the Q matrix,
to match the tiling of a/R. For this to work properly it is required that the 0th dim of both matrices is equal
"""
if not isinstance(tiles_to_match, SquareDiagTiles):
raise TypeError(
f"tiles_to_match must be a SquareDiagTiles object, currently: {type(tiles_to_match)}"
)
base_dnd = self.__DNDarray
match_dnd = tiles_to_match.__DNDarray
# this map will take the same tile row and column sizes up to the last diagonal row/column
# the last row/column is determined by the number of rows/columns on the non-split dimension
if base_dnd.split == match_dnd.split == 0:
# this implies that the gshape[0]'s are equal
# rows are the exact same, and the cols are also equal to the rows (square matrix)
base_dnd.redistribute_(lshape_map=self.lshape_map, target_map=tiles_to_match.lshape_map)
self.__row_per_proc_list = tiles_to_match.__row_per_proc_list.copy()
self.__col_per_proc_list = [tiles_to_match.tile_rows] * len(self.__row_per_proc_list)
self.__row_inds = (
tiles_to_match.__row_inds.copy()
if base_dnd.gshape[0] >= base_dnd.gshape[1]
else tiles_to_match.__col_inds.copy()
)
self.__col_inds = (
tiles_to_match.__row_inds.copy()
if base_dnd.gshape[0] >= base_dnd.gshape[1]
else tiles_to_match.__col_inds.copy()
)
# self.__tile_rows = tiles_to_match.__tile_rows
# self.__tile_columns = tiles_to_match.__tile_rows
self.__tile_map = torch.zeros(
(self.tile_rows, self.tile_columns, 3),
dtype=torch.int,
device=match_dnd.larray.device,
)
for i in range(self.tile_rows):
self.__tile_map[..., 0][i] = self.__row_inds[i]
for i in range(self.tile_columns):
self.__tile_map[..., 1][:, i] = self.__col_inds[i]
for i in range(self.arr.comm.size - 1):
st = sum(self.__row_per_proc_list[:i])
sp = st + self.__row_per_proc_list[i]
self.__tile_map[..., 2][st:sp] = i
# to adjust if the last process has more tiles
i = self.arr.comm.size - 1
self.__tile_map[..., 2][sum(self.__row_per_proc_list[:i]) :] = i
if base_dnd.split == 0 and match_dnd.split == 1:
# rows determine the q sizes -> cols = rows
self.__col_inds = (
tiles_to_match.__row_inds.copy()
if base_dnd.gshape[0] <= base_dnd.gshape[1]
else tiles_to_match.__col_inds.copy()
)
self.__row_inds = (
tiles_to_match.__row_inds.copy()
if base_dnd.gshape[0] <= base_dnd.gshape[1]
else tiles_to_match.__col_inds.copy()
)
rows_per = [x for x in self.__col_inds if x < base_dnd.shape[0]]
# self.__tile_rows = len(rows_per)
# self.__tile_columns = self.tile_rows
target_0 = tiles_to_match.lshape_map[..., 1][: tiles_to_match.last_diagonal_process]
end_tag0 = base_dnd.shape[0] - sum(target_0[: tiles_to_match.last_diagonal_process])
end_tag0 = [end_tag0] + [0] * (
base_dnd.comm.size - 1 - tiles_to_match.last_diagonal_process
)
target_0 = torch.cat(
(target_0, torch.tensor(end_tag0, device=target_0.device, dtype=target_0.dtype)),
dim=0,
)
targe_map = self.lshape_map.clone()
targe_map[..., 0] = target_0
target_0_c = torch.cumsum(target_0, dim=0)
self.__row_per_proc_list = []
st = 0
rows_per = torch.tensor(
rows_per + [base_dnd.shape[0]], device=tiles_to_match.arr.larray.device
)
for i in range(base_dnd.comm.size):
# get the amount of data on each process, get the number of rows with
# indices which are between the start and stop
self.__row_per_proc_list.append(
torch.where((st < rows_per) & (rows_per <= target_0_c[i]))[0].numel()
)
st = target_0_c[i]
base_dnd.redistribute_(lshape_map=self.lshape_map, target_map=targe_map)
self.__tile_map = torch.zeros(
(self.tile_rows, self.tile_columns, 3),
dtype=torch.int,
device=tiles_to_match.arr.larray.device,
)
for i in range(self.tile_rows):
self.__tile_map[..., 0][i] = self.__row_inds[i]
for i in range(self.tile_columns):
self.__tile_map[..., 1][:, i] = self.__col_inds[i]
for i in range(self.arr.comm.size):
st = sum(self.__row_per_proc_list[:i])
sp = st + self.__row_per_proc_list[i]
self.__tile_map[..., 2][st:sp] = i
# to adjust if the last process has more tiles
i = self.arr.comm.size - 1
self.__tile_map[..., 2][sum(self.__row_per_proc_list[:i]) :] = i
self.__col_per_proc_list = [self.tile_columns] * base_dnd.comm.size
self.__last_diag_pr = base_dnd.comm.size - 1
self.__DNDarray.__lshape_map = None
tiles_to_match.__DNDarray.__lshape_map = None
[docs]
def __setitem__(
self, key: Union[int, slice, Tuple[int, slice, ...]], value: Union[int, float, torch.Tensor]
) -> None:
"""
Item setter,
uses the torch item setter and the getitem routines to set the values of the original array
(arr in __init__)
Parameters
----------
key : int or slice or Tuple[int,...]
Tile indices to identify the target tiles
value : int or torch.Tensor
Values to be set
Example
-------
>>> a = ht.zeros((12, 10), split=0)
>>> a_tiles = tiling.SquareDiagTiles(a, tiles_per_proc=2) # type: tiling.SquareDiagTiles
>>> a_tiles[0:2, 2] = 11
>>> a_tiles[0, 0] = 22
>>> a_tiles[2] = 33
>>> a_tiles[3, 3] = 44
>>> print(a.larray)
[0/1] tensor([[22., 22., 22., 0., 0., 0., 11., 11., 0., 0.],
[0/1] [22., 22., 22., 0., 0., 0., 11., 11., 0., 0.],
[0/1] [22., 22., 22., 0., 0., 0., 11., 11., 0., 0.],
[0/1] [ 0., 0., 0., 0., 0., 0., 11., 11., 0., 0.],
[0/1] [ 0., 0., 0., 0., 0., 0., 11., 11., 0., 0.],
[0/1] [ 0., 0., 0., 0., 0., 0., 11., 11., 0., 0.]])
[1/1] tensor([[33., 33., 33., 33., 33., 33., 33., 33., 33., 33.],
[1/1] [33., 33., 33., 33., 33., 33., 33., 33., 33., 33.],
[1/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 44., 44.],
[1/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 44., 44.],
[1/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 44., 44.],
[1/1] [ 0., 0., 0., 0., 0., 0., 0., 0., 44., 44.]])
"""
arr = self.__DNDarray
tile_map = self.__tile_map
if tile_map[key][..., 2].unique().numel() > 1:
raise ValueError("setting across splits is not allowed")
if arr.comm.rank == tile_map[key][..., 2].unique():
# this will set the tile values using the torch setitem function
arr = self.__getitem__(key)
arr.__setitem__(slice(0, None), value)