Source code for heat.core.printing

"""Provides functions for printing DNDarrays to stdout and for configuring how they are printed."""

import builtins
import copy
import torch
from .communication import MPI_WORLD

from .dndarray import DNDarray

__all__ = ["get_printoptions", "global_printing", "local_printing", "print0", "set_printoptions"]


# default line width (in characters) used for printing; applied to torch's print options right away
_DEFAULT_LINEWIDTH = 120
torch.set_printoptions(profile="default", linewidth=_DEFAULT_LINEWIDTH)
# module-wide printing mode: False -> global printing (data is collected before printing),
# True -> every process formats only its local torch tensor; toggled by
# local_printing() and global_printing()
LOCAL_PRINT = False

# printing
# prefix put in front of the formatted data in the global string representation
__PREFIX = "DNDarray"
# length of the prefix, used to compute the indentation of the formatted tensor rows
__INDENT = len(__PREFIX)


def get_printoptions() -> dict:
    """
    Returns the currently configured printing options as key-value pairs.

    The result is a shallow copy of the attribute dictionary of torch's global print-option
    object, so modifying the returned dictionary does not alter the active configuration.
    """
    active_options = vars(torch._tensor_str.PRINT_OPTS)
    return copy.copy(active_options)
def local_printing() -> None:
    """
    Switches to local printing: the builtin `print` function will now print the local PyTorch
    Tensor values for `DNDarrays` given as arguments.

    The switch is announced through :func:`print0`, i.e. on rank 0 only.

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> ht.local_printing()
    [0/2] Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors
    >>> print(x)
    [0/2] [[ 0.,  1.,  2.,  3.,  4.],
    [0/2]  [ 5.,  6.,  7.,  8.,  9.],
    [0/2]  [10., 11., 12., 13., 14.],
    [0/2]  [15., 16., 17., 18., 19.],
    [0/2]  [20., 21., 22., 23., 24.]]
    [1/2] [[25., 26., 27., 28., 29.],
    [1/2]  [30., 31., 32., 33., 34.],
    [1/2]  [35., 36., 37., 38., 39.],
    [1/2]  [40., 41., 42., 43., 44.],
    [1/2]  [45., 46., 47., 48., 49.]]
    [2/2] [[50., 51., 52., 53., 54.],
    [2/2]  [55., 56., 57., 58., 59.],
    [2/2]  [60., 61., 62., 63., 64.],
    [2/2]  [65., 66., 67., 68., 69.],
    [2/2]  [70., 71., 72., 73., 74.]]
    """
    global LOCAL_PRINT

    notice = "Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors"
    LOCAL_PRINT = True
    print0(notice)
def global_printing() -> None:
    """
    Switches to global printing: for `DNDarray`s, the builtin `print` function will gather all of
    the data, format it and then print it on ONLY rank 0.

    Nothing happens (and nothing is announced) if global printing is already active.

    Returns
    -------
    None

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> print(x)
    [0] DNDarray([[ 0.,  1.,  2.,  3.,  4.],
                  [ 5.,  6.,  7.,  8.,  9.],
                  [10., 11., 12., 13., 14.],
                  [15., 16., 17., 18., 19.],
                  [20., 21., 22., 23., 24.],
                  [25., 26., 27., 28., 29.],
                  [30., 31., 32., 33., 34.],
                  [35., 36., 37., 38., 39.],
                  [40., 41., 42., 43., 44.],
                  [45., 46., 47., 48., 49.],
                  [50., 51., 52., 53., 54.],
                  [55., 56., 57., 58., 59.],
                  [60., 61., 62., 63., 64.],
                  [65., 66., 67., 68., 69.],
                  [70., 71., 72., 73., 74.]], dtype=ht.float32, device=cpu:0, split=0)
    """
    global LOCAL_PRINT

    # only act (and announce the change) when local printing is currently enabled
    if LOCAL_PRINT:
        LOCAL_PRINT = False
        print0(
            "Printing options set to GLOBAL. DNDarrays will be collected on process 0 before printing"
        )
def print0(*args, **kwargs) -> None:
    """
    Wraps the builtin `print` function in such a way that it will only run the command on
    rank 0. If this is called with DNDarrays and local printing, only the data local to
    process 0 is printed. For more information see the examples.

    This function is also available as a builtin when importing heat.

    Parameters
    ----------
    *args
        Objects to print. In global printing mode every `DNDarray` among them is converted to
        its string representation beforehand; this conversion runs on all processes.
    **kwargs
        Keyword arguments forwarded unchanged to the builtin `print`.

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> # GLOBAL PRINTING
    >>> ht.print0(x)
    [0] DNDarray([[ 0.,  1.,  2.,  3.,  4.],
                  [ 5.,  6.,  7.,  8.,  9.],
                  [10., 11., 12., 13., 14.],
                  [15., 16., 17., 18., 19.],
                  [20., 21., 22., 23., 24.],
                  [25., 26., 27., 28., 29.],
                  [30., 31., 32., 33., 34.],
                  [35., 36., 37., 38., 39.],
                  [40., 41., 42., 43., 44.],
                  [45., 46., 47., 48., 49.],
                  [50., 51., 52., 53., 54.],
                  [55., 56., 57., 58., 59.],
                  [60., 61., 62., 63., 64.],
                  [65., 66., 67., 68., 69.],
                  [70., 71., 72., 73., 74.]], dtype=ht.float32, device=cpu:0, split=0)
    >>> ht.local_printing()
    [0/2] Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors
    >>> print0(x)
    [0/2] [[ 0.,  1.,  2.,  3.,  4.],
    [0/2]  [ 5.,  6.,  7.,  8.,  9.],
    [0/2]  [10., 11., 12., 13., 14.],
    [0/2]  [15., 16., 17., 18., 19.],
    [0/2]  [20., 21., 22., 23., 24.]], device: cpu:0, split: 0
    """
    if not LOCAL_PRINT:
        # stringify DNDarrays on every process, before the rank check below
        args = tuple(
            __str__(argument) if isinstance(argument, DNDarray) else argument
            for argument in args
        )

    if MPI_WORLD.rank == 0:
        print(*args, **kwargs)
# register print0 on the builtins module so that it can be called without importing it explicitly
builtins.print0 = print0
def set_printoptions(
    precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None
):
    """
    Configures the printing options. List of items shamelessly taken from NumPy and PyTorch
    (thanks guys!).

    All arguments are handed to ``torch.set_printoptions``. If one of the profiles `default`,
    `short` or `full` is selected without an explicit `linewidth`, the line width is afterwards
    set to the module-wide default ``_DEFAULT_LINEWIDTH``.

    Parameters
    ----------
    precision : int, optional
        Number of digits of precision for floating point output (default=4).
    threshold : int, optional
        Total number of array elements which trigger summarization rather than full `repr`
        string (default=1000).
    edgeitems : int, optional
        Number of array items in summary at beginning and end of each dimension (default=3).
    linewidth : int, optional
        The number of characters per line for the purpose of inserting line breaks.
    profile : str, optional
        Sane defaults for pretty printing. Can override with any of the above options. Can be
        any one of `default`, `short`, `full`.
    sci_mode : bool, optional
        Enable (True) or disable (False) scientific notation. If None (default) is specified,
        the value is automatically inferred by HeAT.
    """
    torch.set_printoptions(
        precision=precision,
        threshold=threshold,
        edgeitems=edgeitems,
        linewidth=linewidth,
        profile=profile,
        sci_mode=sci_mode,
    )

    # HeAT profiles will print a bit wider than PyTorch does
    known_profiles = ("default", "short", "full")
    if linewidth is None and profile in known_profiles:
        torch._tensor_str.PRINT_OPTS.linewidth = _DEFAULT_LINEWIDTH
def __str__(dndarray) -> str:
    """
    Computes a printable representation of the passed DNDarray.

    In local printing mode only the process-local torch tensor is formatted, followed by the
    device and split information. In global printing mode the formatted data is wrapped into
    ``DNDarray(...)`` together with dtype, device and split; every process with a rank other
    than 0 returns an empty string.

    Parameters
    ----------
    dndarray : DNDarray
        The array for which to obtain the corresponding string
    """
    if LOCAL_PRINT:
        return (
            torch._tensor_str._tensor_str(dndarray.larray, 0)
            + f", device: {dndarray.device}, split: {dndarray.split}"
        )

    # the data is assembled before the rank check: _tensor_str may communicate (see
    # _torch_data), presumably collectively, so every process has to take part in it
    tensor_string = _tensor_str(dndarray, __INDENT + 1)
    if dndarray.comm.rank != 0:
        return ""

    return (
        f"{__PREFIX}({tensor_string}, dtype=ht.{dndarray.dtype.__name__}, "
        f"device={dndarray.device}, split={dndarray.split})"
    )


def __repr__(dndarray) -> str:
    """
    Returns a printable representation of the passed DNDarray.
    Unlike the __str__ method, which prints a representation targeted at users, this method
    targets developers by showing key internal parameters of the DNDarray.

    Only the process-local data is formatted; no communication is triggered here.

    Parameters
    ----------
    dndarray : DNDarray
        The array for which to obtain the corresponding string
    """
    tensor_string = torch._tensor_str._tensor_str(dndarray.larray, __INDENT + 1)
    return f"DNDarray(MPI-rank: {dndarray.comm.rank}, Shape: {dndarray.shape}, Split: {dndarray.split}, Local Shape: {dndarray.lshape}, Device: {dndarray.device}, Dtype: {dndarray.dtype.__name__}, Data:\n{' ' * __INDENT} {tensor_string})"


def _torch_data(dndarray, summarize) -> torch.Tensor:
    """
    Extracts the data to be printed from the DNDarray in form of a torch tensor and returns it.

    Note that an unbalanced array is balanced in place (``balance_()``) before its data is
    extracted.

    Parameters
    ----------
    dndarray : DNDarray
        The HeAT DNDarray to be printed.
    summarize : bool
        Flag indicating whether to print the full data or summarized, i.e. ellipsed, version of
        the data.

    Returns
    -------
    torch.Tensor
        The local tensor if the array is not split or only one process is involved; the fully
        collected data if no summary is requested; otherwise the edge slices, which are gathered
        on rank 0 only (all other ranks return their own, locally trimmed slice).
    """
    if not dndarray.is_balanced():
        dndarray.balance_()

    # data is not split, we can use it as is
    if dndarray.split is None or dndarray.comm.size == 1:
        data = dndarray.larray
    # split, but no summary required, we collect it
    elif not summarize:
        data = dndarray.copy().resplit_(None).larray
    # split, but summarized, collect the slices from all nodes and pass it on
    else:
        edgeitems = torch._tensor_str.PRINT_OPTS.edgeitems
        double_items = 2 * edgeitems
        data = dndarray.larray

        for i in range(dndarray.ndim):
            # skip over dimensions that are smaller than twice the number of edge items to display
            if dndarray.gshape[i] <= double_items:
                continue

            # non-split dimension, can slice locally
            if i != dndarray.split:
                # one item more than `edgeitems` is kept at the front, presumably so that the
                # trimmed dimension stays longer than 2 * edgeitems and torch still elides it
                start_tensor = torch.index_select(
                    data, i, torch.arange(edgeitems + 1, device=data.device)
                )
                end_tensor = torch.index_select(
                    data,
                    i,
                    torch.arange(
                        dndarray.lshape[i] - edgeitems, dndarray.lshape[i], device=data.device
                    ),
                )
                data = torch.cat([start_tensor, end_tensor], dim=i)
                continue

            # split-dimension, need to respect the global offset: the local chunk covers the
            # global index range [offset, offset + local_length)
            offset, _, _ = dndarray.comm.chunk(dndarray.gshape, i)
            local_length = dndarray.lshape[i]
            head_stop = edgeitems + 1  # global index one past the leading items to keep
            tail_start = dndarray.gshape[i] - edgeitems  # global index of first trailing item

            holds_head = offset < head_stop
            holds_tail = tail_start < offset + local_length
            if not (holds_head or holds_tail):
                # chunk lies strictly between both edges; it is passed on untrimmed
                continue

            # a chunk may contain leading AND trailing items at once (small chunks), hence both
            # parts are collected; they cannot overlap because gshape[i] > 2 * edgeitems here
            keep = []
            if holds_head:
                keep.append(
                    torch.arange(min(local_length, head_stop - offset), device=data.device)
                )
            if holds_tail:
                keep.append(
                    torch.arange(
                        max(0, tail_start - offset), local_length, device=data.device
                    )
                )
            data = torch.index_select(data, i, torch.cat(keep))

        # exchange data: rank 0 first learns the shape of every contribution ...
        exchange_sizes = dndarray.comm.gather(torch.tensor(data.shape), root=0)
        if dndarray.comm.rank == 0:
            counts = tuple([s[dndarray.split] for s in exchange_sizes])
            displs = (0,) + tuple(torch.cumsum(torch.tensor(counts), dim=0)[:-1])
            recv_size = exchange_sizes[0].clone()
            recv_size[dndarray.split] = sum(counts)
            recv_buf = torch.empty(tuple(recv_size), dtype=data.dtype, device=data.device)
            recv_buf = (recv_buf, counts, displs)
        else:
            # placeholder only, nothing is received on the other ranks
            recv_buf = torch.empty(0)
        # ... then the slices are concatenated along the split axis in the buffer of rank 0
        dndarray.comm.Gatherv(data, recv_buf, axis=dndarray.split, recv_axis=dndarray.split)
        if dndarray.comm.rank == 0:
            data = recv_buf[0]

    return data


def _tensor_str(dndarray, indent: int) -> str:
    """
    Computes a string representation of the passed DNDarray.

    Parameters
    ----------
    dndarray: DNDarray
        The array for which to obtain the corresponding string
    indent: int
        The number of spaces the array content is indented.
    """
    elements = dndarray.gnumel
    if elements == 0:
        return "[]"

    # we will recycle torch's printing features here
    # to do so, we slice up the torch data and forward it to torch internal printing mechanism
    summarize = elements > get_printoptions()["threshold"]
    torch_data = _torch_data(dndarray, summarize)
    if not dndarray.is_distributed():
        # let torch handle formatting on non-distributed data
        # formatter gets too slow for even moderately large tensors
        return torch._tensor_str._tensor_str(torch_data, indent)

    # the summarize flag is decided on the global element count above and therefore passed on
    # explicitly instead of letting torch derive it from the (already reduced) tensor
    formatter = torch._tensor_str._Formatter(torch_data)
    return torch._tensor_str._tensor_str_with_formatter(torch_data, indent, summarize, formatter)