"""Enables parallel I/O with data on disk."""
from __future__ import annotations
from functools import reduce
import glob
import operator
import os.path
from math import log10
from pathlib import Path
import numpy as np
import torch
import warnings
import fnmatch
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from . import devices
from . import factories
from . import types
from .communication import Communication, MPI, MPI_WORLD, sanitize_comm
from .dndarray import DNDarray
from .manipulations import hsplit, vsplit
from .statistics import max as smax, min as smin
from .stride_tricks import sanitize_axis
from .types import datatype
# File access modes accepted by the save_* functions in this module.
__VALID_WRITE_MODES = frozenset(["w", "a", "r+"])
# Lower-case file extensions used for format auto-detection. They are compared against
# the result of ``os.path.splitext``, which always carries the leading dot.
__CSV_EXTENSION = frozenset([".csv"])
__HDF5_EXTENSIONS = frozenset([".h5", ".hdf5"])
# ".netcdf" (with dot) is the entry that can actually match a splitext() result; the
# historical dot-less "netcdf" entry is kept so nothing relying on it breaks.
__NETCDF_EXTENSIONS = frozenset([".nc", ".nc4", ".netcdf", "netcdf"])
# Template for auto-generated netCDF dimension names: "<variable>_dim_<axis index>".
__NETCDF_DIM_TEMPLATE = "{}_dim_{}"
__ZARR_EXTENSIONS = frozenset([".zarr"])
# Names exported unconditionally. The optional netCDF4 and HDF5 sections of this module
# extend this list when their backing library can be imported.
__all__ = [
"load",
"load_csv",
"save_csv",
"save",
"supports_hdf5",
"supports_netcdf",
"load_npy_from_path",
"supports_zarr",
]
def size_from_slice(size: int, s: slice) -> Tuple[int, int]:
    """
    Computes how many elements a slice selects from an axis and where the selection begins.

    Parameters
    ----------
    size: int
        Length of the axis the slice is applied to.
    s : slice
        The slice whose extent is requested.

    Returns
    -------
    int
        Number of elements selected by the slice.
    int
        Index of the first selected element, or ``0`` when the slice selects nothing.
    """
    # Let ``range`` resolve negative/None bounds and steps exactly like sequence slicing does.
    selected = range(size)[s]
    count = len(selected)
    if count > 0:
        return count, selected.start
    return count, 0
# netCDF4 is an optional dependency: probe for it at import time. When the import fails only
# a ``supports_netcdf`` stub returning ``False`` is defined; otherwise the netCDF I/O
# functions are exported and the library's parallel I/O capability is detected.
try:
import netCDF4 as nc
except ImportError:
# netCDF4 support is optional
def supports_netcdf() -> bool:
"""
Returns ``True`` if Heat supports reading from and writing to netCDF4 files, ``False`` otherwise.
"""
return False
else:
# add functions to visible exports
__all__.extend(["load_netcdf", "save_netcdf"])
# determine netCDF's parallel I/O support: a truthy value for any of these module
# attributes is treated as "parallel I/O available"
__nc_has_par = (
nc.__dict__.get("__has_parallel4_support__", False)
or nc.__dict__.get("__has_pnetcdf_support__", False)
or nc.__dict__.get("__has_nc_par__", False)
)
# warn the user about serial netcdf (on rank 0 only, to avoid one warning per process)
if not __nc_has_par and MPI_WORLD.rank == 0:
warnings.warn(
"netCDF4 does not support parallel I/O, falling back to slower serial I/O",
ImportWarning,
)
[docs]
def supports_netcdf() -> bool:
    """
    Reports whether Heat can read from and write to netCDF4 files.

    Returns
    -------
    bool
        ``True``; this definition is the one used when the ``netCDF4`` import succeeded.
    """
    return True
def load_netcdf(
    path: str,
    variable: str,
    dtype: datatype = types.float32,
    split: Optional[int] = None,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Loads data from a NetCDF4 file. The data may be distributed among multiple processing nodes via the split flag.

    Parameters
    ----------
    path : str
        Path to the NetCDF4 file to be read.
    variable : str
        Name of the variable to be read.
    dtype : datatype, optional
        Data type of the resulting array
    split : int or None, optional
        The axis along which the data is distributed among the processing cores.
    device : str, optional
        The device id on which to place the data, defaults to globally set default device.
    comm : Communication, optional
        The communication to use for the data distribution. Defaults to MPI_COMM_WORLD.

    Returns
    -------
    DNDarray
        The loaded variable, with the variable's global shape and the requested split.

    Raises
    ------
    TypeError
        If any of the input parameters are not of correct type.

    Examples
    --------
    >>> a = ht.load_netcdf("data.nc", variable="DATA")
    >>> a.shape
    [0/2] (5,)
    [1/2] (5,)
    >>> a.lshape
    [0/2] (5,)
    [1/2] (5,)
    >>> b = ht.load_netcdf("data.nc", variable="DATA", split=0)
    >>> b.shape
    [0/2] (5,)
    [1/2] (5,)
    >>> b.lshape
    [0/2] (3,)
    [1/2] (2,)
    """
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    if not isinstance(variable, str):
        # the message names the parameter that is actually being checked
        raise TypeError(f"variable must be str, not {type(variable)}")
    if split is not None and not isinstance(split, int):
        raise TypeError(f"split must be None or int, not {type(split)}")

    # infer the canonical heat datatype
    dtype = types.canonical_heat_type(dtype)
    # determine the device and comm the data will be placed on
    device = devices.sanitize_device(device)
    comm = sanitize_comm(comm)

    # actually load the data
    with nc.Dataset(path, "r", parallel=__nc_has_par, comm=comm.handle) as handle:
        data = handle[variable]

        # prepare meta information
        gshape = tuple(data.shape)
        split = sanitize_axis(gshape, split)

        # chunk up the data portion
        _, local_shape, indices = comm.chunk(gshape, split)
        balanced = True
        if split is None or local_shape[split] > 0:
            data = torch.tensor(
                data[indices], dtype=dtype.torch_type(), device=device.torch_device
            )
        else:
            # this process owns no elements along the split axis: skip the file read
            # and allocate an (empty) tensor of the local shape instead
            data = torch.empty(
                local_shape, dtype=dtype.torch_type(), device=device.torch_device
            )

        return DNDarray(data, gshape, dtype, split, device, comm, balanced)
def save_netcdf(
    data: DNDarray,
    path: str,
    variable: str,
    mode: str = "w",
    dimension_names: Union[list, tuple, str] = None,
    is_unlimited: bool = False,
    file_slices: Union[Iterable[int], slice, bool] = slice(None),
    **kwargs: Dict[str, object],
):
    """
    Saves data to a netCDF4 file. Attempts to utilize parallel I/O if possible.

    Parameters
    ----------
    data : DNDarray
        The data to be saved on disk.
    path : str
        Path to the netCDF4 file to be written.
    variable : str
        Name of the variable the data is saved to.
    mode : str, optional
        File access mode, one of ``'w', 'a', 'r+'``.
    dimension_names : list or tuple or string
        Specifies the netCDF Dimensions used by the variable. Ignored if Variable already exists.
        Defaults to one generated name per dimension of ``data``.
    is_unlimited : bool, optional
        If True, every dimension created for this variable (i.e. doesn't already exist) is unlimited. Already
        existing limited dimensions cannot be changed to unlimited and vice versa.
    file_slices : integer iterable, slice, ellipsis or bool
        Keys used to slice the netCDF Variable, as given in the nc.utils._StartCountStride method.
    kwargs : dict, optional
        additional arguments passed to the created dataset.

    Raises
    ------
    TypeError
        If any of the input parameters are not of correct type.
    ValueError
        If the access mode is not understood or if the number of dimension names does not match the number of
        dimensions.

    Examples
    --------
    >>> x = ht.arange(100, split=0)
    >>> ht.save_netcdf(x, "data.nc", variable="DATA")
    """
    if not isinstance(data, DNDarray):
        raise TypeError(f"data must be heat tensor, not {type(data)}")
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    if not isinstance(variable, str):
        raise TypeError(f"variable must be str, not {type(variable)}")

    # normalise dimension_names to a list first ...
    if dimension_names is None:
        dimension_names = [
            __NETCDF_DIM_TEMPLATE.format(variable, dim) for dim, _ in enumerate(data.shape)
        ]
    elif isinstance(dimension_names, str):
        dimension_names = [dimension_names]
    elif isinstance(dimension_names, tuple):
        dimension_names = list(dimension_names)
    elif not isinstance(dimension_names, list):
        raise TypeError(
            "dimension_names must be list or tuple or string, not {}".format(
                type(dimension_names)
            )
        )
    # ... and only then validate the length, so that str and tuple inputs are checked
    # exactly like lists (the check used to be skipped for them)
    if len(dimension_names) != len(data.shape):
        raise ValueError(f"{len(dimension_names)} names given for {len(data.shape)} dimensions")

    # we only support a subset of possible modes
    if mode not in __VALID_WRITE_MODES:
        raise ValueError(f"mode was {mode}, not in possible modes {__VALID_WRITE_MODES}")

    # ``failed`` holds 0 on success or (rank + 1) of a process that raised; ``excep`` the exception
    failed = 0
    excep = None
    # chunk the data, if no split is set maximize parallel I/O and chunk first axis
    is_split = data.split is not None
    _, _, slices = data.comm.chunk(data.gshape, data.split if is_split else 0)

    def __get_expanded_split(
        shape: Tuple[int], expanded_shape: Tuple[int], split: Optional[int]
    ) -> int:
        """
        Returns the hypothetical split-axis of a dndarray of shape=shape and
        split=split if it was expanded to expandedShape by adding empty dimensions.

        Parameters
        ----------
        shape : tuple[int]
            Shape of a DNDarray.
        expanded_shape : tuple[int]
            Shape of hypothetical expanded DNDarray.
        split : int or None
            split-axis of dndarray.

        Raises
        ------
        ValueError
            If resulting shapes do not match.
        """
        if np.prod(shape) != np.prod(expanded_shape):
            raise ValueError(f"Shapes {shape} and {expanded_shape} do not have the same size")
        if np.prod(shape) == 1:  # size 1 array
            return split
        if len(shape) == len(expanded_shape):  # actually not expanded at all
            return split
        if split is None:  # not split at all
            return None

        # Get indices of non-empty dimensions and squeezed shapes
        enumerated = [[i, v] for i, v in enumerate(shape) if v != 1]
        ind_nonempty, sq_shape = list(zip(*enumerated))  # transpose
        enumerated = [[i, v] for i, v in enumerate(expanded_shape) if v != 1]
        ex_ind_nonempty, sq_ex = list(zip(*enumerated))  # transpose

        if not sq_shape == sq_ex:
            raise ValueError(
                f"Shapes {shape} and {expanded_shape} differ in non-empty dimensions"
            )
        if split in ind_nonempty:  # split along non-empty dimension
            split_sq = ind_nonempty.index(split)  # split-axis in squeezed shape
            return ex_ind_nonempty[split_sq]

        # split along empty dimension: split doesnt matter, only one process contains data
        # return the last empty dimension (in expanded shape) before (the first nonempty dimension after split)
        # number of nonempty elems before split
        ne_before_split = split - shape[:split].count(1)
        ind_ne_after_split = ind_nonempty[
            ne_before_split
        ]  # index of (first nonempty element after split) in squeezed shape
        return max(
            i
            for i, v in enumerate(expanded_shape[: max(ex_ind_nonempty[:ind_ne_after_split])])
            if v == 1
        )

    def __merge_slices(
        var: nc.Variable,
        var_slices: Tuple[int, slice],
        data: DNDarray,
        data_slices: Optional[Tuple[int, slice]] = None,
    ) -> Tuple[Union[int, slice]]:
        """
        Allows replacing:
            ``var[var_slices][data_slices] = data``
        (a `netcdf4.Variable.__getitem__` and a `numpy.ndarray.__setitem__` call)
        with:
            ``var[ __merge_slices(var, var_slices, data, data_slices) ] = data``
        (a single `netcdf4.Variable.__setitem__` call).

        This is necessary because performing the former would, in the ``__getitem__``, load the global dataset onto
        every process in local ``np.ndarray``s. Then, the ``__setitem__`` would write the local `chunk` into the
        ``np.ndarray``.
        The latter allows the netcdf4 library to parallelize the write-operation by directly using the
        `netcdf4.Variable.__setitem__` method.

        Parameters
        ----------
        var : nc.Variable
            Variable to which data is to be saved.
        var_slices : tuple[int, slice]
            Keys to pass to the set-operator.
        data : DNDarray
            Data to be saved.
        data_slices: tuple[int, slice]
            As returned by the data.comm.chunk method.
        """
        slices = data_slices
        if slices is None:
            _, _, slices = data.comm.chunk(data.gshape, data.split if is_split else 0)
        start, count, stride, _ = nc.utils._StartCountStride(
            elem=var_slices,
            shape=var.shape,
            dimensions=var.dimensions,
            grp=var.group(),
            datashape=data.shape,
            put=True,
        )
        out_shape = nc._netCDF4._out_array_shape(count)
        out_split = __get_expanded_split(data.shape, out_shape, data.split)

        start, count, stride = start.T, count.T, stride.T  # transpose for iteration
        stop = start + stride * count
        new_slices = []
        for begin, end, step in zip(start, stop, stride):
            if begin.size == 1:
                begin, end, step = begin.item(), end.item(), step.item()
                new_slices.append(slice(begin, end, step))
            else:
                begin, end, step = begin.flatten(), end.flatten(), step.flatten()
                new_slices.append(
                    np.r_[
                        tuple(
                            slice(b.item(), e.item(), s.item())
                            for b, e, s in zip(begin, end, step)
                        )
                    ]
                )
        if out_split is not None:  # add split-slice
            if isinstance(new_slices[out_split], slice):
                start, stop, step = (
                    new_slices[out_split].start,
                    new_slices[out_split].stop,
                    new_slices[out_split].step,
                )
                sliced = range(start, stop, step)[slices[data.split]]
                a, b, c = sliced.start, sliced.stop, sliced.step
                # negative bounds are replaced by None when rebuilding the slice
                a = None if a < 0 else a
                b = None if b < 0 else b
                new_slices[out_split] = slice(a, b, c)
            elif isinstance(new_slices[out_split], np.ndarray):
                new_slices[out_split] = new_slices[out_split][slices[data.split]]
            else:
                new_slices[out_split] = np.r_[new_slices[out_split]][slices[data.split]]
        return tuple(new_slices)

    # attempt to perform parallel I/O if possible
    if __nc_has_par:
        try:
            with nc.Dataset(path, mode, parallel=True, comm=data.comm.handle) as handle:
                if variable in handle.variables:
                    var = handle.variables[variable]
                else:
                    for name, elements in zip(dimension_names, data.shape):
                        if name not in handle.dimensions:
                            handle.createDimension(name, elements if not is_unlimited else None)
                    var = handle.createVariable(
                        variable, data.dtype.char(), dimension_names, **kwargs
                    )
                merged_slices = __merge_slices(var, file_slices, data)
                try:
                    var[merged_slices] = (
                        data.larray.cpu() if is_split else data.larray[slices].cpu()
                    )
                except RuntimeError:
                    # retry the same write with collective access enabled
                    var.set_collective(True)
                    var[merged_slices] = (
                        data.larray.cpu() if is_split else data.larray[slices].cpu()
                    )
        except Exception as e:
            failed = data.comm.rank + 1
            excep = e

    # otherwise a single rank only write is performed in case of local data (i.e. no split)
    elif data.comm.rank == 0:
        try:
            with nc.Dataset(path, mode) as handle:
                if variable in handle.variables:
                    var = handle.variables[variable]
                else:
                    for name, elements in zip(dimension_names, data.shape):
                        if name not in handle.dimensions:
                            handle.createDimension(name, elements if not is_unlimited else None)
                    var = handle.createVariable(
                        variable, data.dtype.char(), dimension_names, **kwargs
                    )
                var.set_collective(False)  # not possible with non-parallel netcdf
                if is_split:
                    merged_slices = __merge_slices(var, file_slices, data)
                    var[merged_slices] = data.larray.cpu()
                else:
                    var[file_slices] = data.larray.cpu()
        except Exception as e:
            failed = 1
            excep = e
        finally:
            if data.comm.size > 1:
                # hand the failure status to rank 1, then wait for a message to come back
                data.comm.isend(failed, dest=1)
                data.comm.recv()

    # non-root
    else:
        # wait for the previous rank to finish writing its chunk, then write own part
        failed = data.comm.recv()
        try:
            # no MPI, but data is split, we have to serialize the writes
            if not failed and is_split:
                with nc.Dataset(path, "r+") as handle:
                    var = handle.variables[variable]
                    var.set_collective(False)  # not possible with non-parallel netcdf
                    merged_slices = __merge_slices(var, file_slices, data)
                    var[merged_slices] = data.larray.cpu()
        except Exception as e:
            failed = data.comm.rank + 1
            excep = e
        finally:
            # ping the next node in the communicator, wrap around to 0 to complete barrier behavior
            next_rank = (data.comm.rank + 1) % data.comm.size
            data.comm.isend(failed, dest=next_rank)

    # agree on a failing rank (the highest one, if several failed) and re-raise its
    # exception on every process
    failed = data.comm.allreduce(failed, op=MPI.MAX)
    if failed - 1 == data.comm.rank:
        data.comm.bcast(excep, root=failed - 1)
        raise excep
    elif failed:
        excep = data.comm.bcast(excep, root=failed - 1)
        excep.args = f"raised by process rank {failed - 1}", *excep.args
        raise excep from None  # raise the same error but without traceback
        # because that is on a different process
# Attach save_netcdf to DNDarray as a method. Arguments other than path, variable and mode
# are forwarded through **kwargs. The function's docstring is reused for the method.
DNDarray.save_netcdf = lambda self, path, variable, mode="w", **kwargs: save_netcdf(
self, path, variable, mode, **kwargs
)
DNDarray.save_netcdf.__doc__ = save_netcdf.__doc__
# h5py is an optional dependency: probe for it at import time. When the import fails only a
# ``supports_hdf5`` stub returning ``False`` is defined; otherwise the HDF5 I/O functions
# are exported.
try:
import h5py
except ImportError:
# HDF5 support is optional
def supports_hdf5() -> bool:
"""
Returns ``True`` if Heat supports reading from and writing to HDF5 files, ``False`` otherwise.
"""
return False
else:
# add functions to exports
__all__.extend(["load_hdf5", "save_hdf5", "load_multiple_hdf5"])
# warn the user about serial hdf5 (on rank 0 only, to avoid one warning per process)
if not h5py.get_config().mpi and MPI_WORLD.rank == 0:
warnings.warn(
"h5py does not support parallel I/O, falling back to slower serial I/O", ImportWarning
)
[docs]
def supports_hdf5() -> bool:
    """
    Reports whether Heat can read from and write to HDF5 files.

    Returns
    -------
    bool
        ``True``; this definition is the one used when the ``h5py`` import succeeded.
    """
    return True
def load_hdf5(
    path: str,
    dataset: str,
    dtype: Optional[datatype] = None,
    slices: Optional[Tuple[Optional[slice], ...]] = None,
    split: Optional[int] = None,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Reads a dataset from an HDF5 file into a DNDarray, optionally restricted to a sub-region and optionally
    distributed along one axis.

    Parameters
    ----------
    path : str
        Path to the HDF5 file to be read.
    dataset : str
        Name of the dataset to be read.
    dtype : datatype, optional
        Data type of the resulting array, defaults to the loaded datasets type.
    slices : tuple of slice objects, optional
        Load only the specified slices of the dataset. Entry ``i`` applies to axis ``i``; missing or ``None``
        entries keep the whole axis. Only unit steps are supported.
    split : int or None, optional
        The axis along which the data is distributed among the processing cores.
    device : str, optional
        The device id on which to place the data, defaults to globally set default device.
    comm : Communication, optional
        The communication to use for the data distribution.

    Raises
    ------
    TypeError
        If any of the input parameters are not of correct type
    ValueError
        If a slice with a step other than 1 is passed.

    Examples
    --------
    >>> a = ht.load_hdf5("data.h5", dataset="DATA")
    >>> a.shape
    [0/2] (5,)
    [1/2] (5,)
    >>> a.lshape
    [0/2] (5,)
    [1/2] (5,)
    >>> b = ht.load_hdf5("data.h5", dataset="DATA", split=0)
    >>> b.shape
    [0/2] (5,)
    [1/2] (5,)
    >>> b.lshape
    [0/2] (3,)
    [1/2] (2,)

    Using the slicing argument:

    >>> not_sliced = ht.load_hdf5("other_data.h5", dataset="DATA", split=0)
    >>> not_sliced.shape
    [0/2] (10,2)
    [1/2] (10,2)
    >>> not_sliced.lshape
    [0/2] (5,2)
    [1/2] (5,2)
    >>> not_sliced.larray
    [0/2] [[ 0,  1],
           [ 2,  3],
           [ 4,  5],
           [ 6,  7],
           [ 8,  9]]
    [1/2] [[10, 11],
           [12, 13],
           [14, 15],
           [16, 17],
           [18, 19]]

    >>> sliced = ht.load_hdf5("other_data.h5", dataset="DATA", split=0, slices=[slice(8)])
    >>> sliced.shape
    [0/2] (8,2)
    [1/2] (8,2)
    >>> sliced.lshape
    [0/2] (4,2)
    [1/2] (4,2)
    >>> sliced.larray
    [0/2] [[ 0,  1],
           [ 2,  3],
           [ 4,  5],
           [ 6,  7]]
    [1/2] [[ 8,  9],
           [10, 11],
           [12, 13],
           [14, 15]]

    >>> sliced = ht.load_hdf5(
    ...     "other_data.h5", dataset="DATA", split=0, slices=[slice(2, 8), slice(0, 1)]
    ... )
    >>> sliced.shape
    [0/2] (6,1)
    [1/2] (6,1)
    >>> sliced.lshape
    [0/2] (3,1)
    [1/2] (3,1)
    >>> sliced.larray
    [0/2] [[ 4, ],
           [ 6, ],
           [ 8, ]]
    [1/2] [[10, ],
           [12, ],
           [14, ]]
    """
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    if not isinstance(dataset, str):
        raise TypeError(f"dataset must be str, not {type(dataset)}")
    if not (split is None or isinstance(split, int)):
        raise TypeError(f"split must be None or int, not {type(split)}")

    # resolve where the data will live and which communicator distributes it
    device = devices.sanitize_device(device)
    comm = sanitize_comm(comm)

    with h5py.File(path, "r") as handle:
        data = handle[dataset]
        gshape = data.shape
        # per-axis start of the requested region inside the on-disk dataset
        offsets = [0] * len(gshape)

        if dtype is None:
            dtype = data.dtype
        dtype = types.canonical_heat_type(dtype)

        if slices is not None:
            # shrink the global shape to the requested region and remember where it starts
            region_shape = []
            for axis, axis_size in enumerate(gshape):
                selection = slices[axis] if axis < len(slices) else None
                if not selection:
                    region_shape.append(axis_size)
                    offsets[axis] = 0
                    continue
                if selection.step is not None and selection.step != 1:
                    raise ValueError("Slices with step != 1 are not supported")
                kept, first = size_from_slice(axis_size, selection)
                region_shape.append(kept)
                offsets[axis] = first
            gshape = tuple(region_shape)

        dims = len(gshape)
        split = sanitize_axis(gshape, split)
        _, _, indices = comm.chunk(gshape, split)

        if slices is not None:
            # chunk indices are relative to the region; shift them to dataset coordinates
            indices = tuple(
                slice(index.start + offset, index.stop + offset)
                for offset, index in zip(offsets, indices)
            )

        balanced = True
        if split is None or indices[split].stop > indices[split].start:
            data = torch.tensor(
                data[indices], dtype=dtype.torch_type(), device=device.torch_device
            )
        else:
            # this process receives no elements along the split axis: read a single
            # hyperslab and cut it down to length zero along that axis
            warnings.warn("More MPI ranks are used then the length of splitting dimension!")
            one_wide = tuple(
                slice(0, 1) if axis == split else slice(0, gshape[axis]) for axis in range(dims)
            )
            zero_wide = tuple(
                slice(0, 0) if axis == split else slice(0, gshape[axis]) for axis in range(dims)
            )
            data = torch.tensor(
                data[one_wide], dtype=dtype.torch_type(), device=device.torch_device
            )
            data = data[zero_wide]

        return DNDarray(data, gshape, dtype, split, device, comm, balanced)
def save_hdf5(
    data: DNDarray,
    path: str,
    dataset: str,
    mode: str = "w",
    dtype: Optional[datatype] = None,
    **kwargs: Dict[str, object],
):
    """
    Saves ``data`` to an HDF5 file. Attempts to utilize parallel I/O if possible.

    When h5py has no MPI support, rank 0 creates the dataset and, for split data, the remaining
    ranks append their chunks one after another.

    Parameters
    ----------
    data : DNDarray
        The data to be saved on disk.
    path : str
        Path to the HDF5 file to be written.
    dataset : str
        Name of the dataset the data is saved to.
    mode : str, optional
        File access mode, one of ``'w', 'a', 'r+'``
    dtype : datatype, optional
        Data type of the saved data, defaults to the data type of ``data``.
    kwargs : dict, optional
        Additional arguments passed to the created dataset.

    Raises
    ------
    TypeError
        If any of the input parameters are not of correct type.
    ValueError
        If the access mode is not understood.

    Examples
    --------
    >>> x = ht.arange(100, split=0)
    >>> ht.save_hdf5(x, "data.h5", dataset="DATA")
    """
    if not isinstance(data, DNDarray):
        raise TypeError(f"data must be heat tensor, not {type(data)}")
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    if not isinstance(dataset, str):
        raise TypeError(f"dataset must be str, not {type(dataset)}")

    # we only support a subset of possible modes
    if mode not in __VALID_WRITE_MODES:
        raise ValueError(f"mode was {mode}, not in possible modes {__VALID_WRITE_MODES}")

    # chunk the data, if no split is set maximize parallel I/O and chunk first axis
    is_split = data.split is not None
    _, _, slices = data.comm.chunk(data.gshape, data.split if is_split else 0)

    # h5py is handed the dtype by name: torch dtypes print as "torch.<name>", anything
    # else that is not already a string is expected to expose ``__name__``
    if dtype is None:
        dtype = data.dtype
    elif isinstance(dtype, torch.dtype):
        dtype = str(dtype).split(".")[-1]
    if not isinstance(dtype, str):
        dtype = dtype.__name__

    # attempt to perform parallel I/O if possible
    if h5py.get_config().mpi:
        with h5py.File(path, mode, driver="mpio", comm=data.comm.handle) as handle:
            dset = handle.create_dataset(dataset, data.shape, dtype=dtype, **kwargs)
            dset[slices] = data.larray.cpu() if is_split else data.larray[slices].cpu()

    # otherwise a single rank only write is performed in case of local data (i.e. no split)
    elif data.comm.rank == 0:
        with h5py.File(path, mode) as handle:
            dset = handle.create_dataset(dataset, data.shape, dtype=dtype, **kwargs)
            if is_split:
                dset[slices] = data.larray.cpu()
            else:
                dset[...] = data.larray.cpu()

        # ping next rank if it exists
        if is_split and data.comm.size > 1:
            data.comm.Isend([None, 0, MPI.INT], dest=1)
            data.comm.Recv([None, 0, MPI.INT], source=data.comm.size - 1)

    # no MPI, but split data is more tricky, we have to serialize the writes
    elif is_split:
        # wait for the previous rank to finish writing its chunk, then write own part
        data.comm.Recv([None, 0, MPI.INT], source=data.comm.rank - 1)
        with h5py.File(path, "r+") as handle:
            handle[dataset][slices] = data.larray.cpu()

        # ping the next node in the communicator, wrap around to 0 to complete barrier behavior
        next_rank = (data.comm.rank + 1) % data.comm.size
        data.comm.Isend([None, 0, MPI.INT], dest=next_rank)
# Attach save_hdf5 to DNDarray as a method. ``dtype`` is part of the method signature and
# must be forwarded explicitly, otherwise a caller-supplied dtype is silently dropped.
DNDarray.save_hdf5 = lambda self, path, dataset, mode="w", dtype=None, **kwargs: save_hdf5(
    self, path, dataset, mode, dtype=dtype, **kwargs
)
DNDarray.save_hdf5.__doc__ = save_hdf5.__doc__
def load_multiple_hdf5(
    folder: str | Path,
    dataset: str,
    dtype=None,
    sorting_func: Callable[[list[Path]], list[Path]] | None = None,
) -> DNDarray:
    """Loads all .hdf5 or .h5 files inside the given folder into a single DNDarray that is split along axis 0; the arrays
    from the different files are concatenated along axis 0.

    The files are sorted by the sorting_func, if given, or by the default sorted function.
    All files are closed again before the function returns or raises.

    Parameters
    ----------
    folder : str | Path
        folder containing all .h5 or .hdf5 files
    dataset : str
        dataset to load
    dtype : _type_, optional
        dtype to create array with, by default None, in which case the dtype with the largest
        item size among all files is used
    sorting_func : Callable[[list[Path]], list[Path]] | None, optional
        how to sort the files, by default None, ie the sorted function

    Returns
    -------
    DNDarray
        The concatenated data, split along axis 0.

    Raises
    ------
    TypeError
        If ``folder``, ``dataset`` or ``sorting_func`` are not of correct type.
    ValueError
        If ``folder`` is not a directory, contains no HDF5 files, or the datasets differ in their
        number of dimensions or in the size of any dimension other than the first.
    """
    # local import: keeps this block self-contained without touching the module imports
    from contextlib import ExitStack

    if not isinstance(folder, (str, Path)):
        raise TypeError(f"path must be an String or Path, not {type(folder)}")
    if not isinstance(dataset, str):
        raise TypeError(f"dataset must be a string not, {type(dataset)}")
    if not isinstance(sorting_func, Callable) and sorting_func is not None:
        raise TypeError(f"sorting_func must be None or Callable, not {type(sorting_func)}")
    if not Path(folder).is_dir():
        raise ValueError("Path must be a Folder")

    files: list[Path] = [p for p in Path(folder).iterdir() if p.suffix in (".h5", ".hdf5")]
    files = sorted(files) if sorting_func is None else sorting_func(files)

    if len(files) == 0:
        raise ValueError("No files inside the directory.")

    # the ExitStack closes every opened file on exit, including when an error is raised
    with ExitStack() as stack:
        h5_files: list[h5py.File] = [stack.enter_context(h5py.File(x)) for x in files]
        shapes: list[tuple[int, ...]] = [x[dataset].shape for x in h5_files]
        dtypes: list[np.dtype] = [x[dataset].dtype for x in h5_files]
        itemsizes: list[int] = [x[dataset].dtype.itemsize for x in h5_files]

        if len(set(len(x) for x in shapes)) != 1:
            raise ValueError("Amount of dimensions of all hdf5 files must be the same")

        if dtype is None:
            warnings.warn("No explicit dtype given, will use biggest dtype across all files")
            dtype = dtypes[np.argmax(itemsizes)]

        # global shape: first dimension is the sum over all files, all others must agree
        shapes_arr = np.array(shapes, dtype=int)
        gshape = np.zeros(shapes_arr.shape[1], dtype=int)
        for i in range(1, len(gshape)):  # rows can be different
            if len(set(shapes_arr[:, i])) != 1:
                raise ValueError(f"Dimension missmatch on ndim {i + 1}")
            gshape[i] = shapes_arr[0][i]
        gshape[0] = shapes_arr[:, 0].sum()
        gshape = tuple(gshape.astype(int))

        # half-open global row range [start, end) covered by each file
        ranges: list[tuple[int, int]] = []
        n = 0
        for shape in shapes:
            rows = shape[0]
            ranges.append((n, n + rows))
            n += rows

        def find_file_ids_for_range(rslice: slice) -> list[int]:
            """Returns the indices of all files whose row range overlaps ``rslice``."""
            matches = []
            for i, (start, end) in enumerate(ranges):
                if start <= rslice.start and rslice.start < end:
                    matches.append(i)
                elif start < rslice.stop and rslice.stop <= end:
                    matches.append(i)
                elif rslice.start <= start and end <= rslice.stop:
                    matches.append(i)
            return matches

        _, _, slices = MPI_WORLD.chunk(gshape, split=0)
        local_slice: slice = slices[0]  # global row slice of local rank
        local_h5_file_ids: list[int] = find_file_ids_for_range(local_slice)

        data: DNDarray = factories.empty(gshape, dtype=dtype, split=0)

        n = 0
        # row inside the first relevant file at which the local chunk starts
        start_h5_index = local_slice.start - shapes_arr[: local_h5_file_ids[0], 0].sum()
        max_n: int = data.larray.shape[0]
        for i, idx in enumerate(local_h5_file_ids):
            start_index = start_h5_index if i == 0 else 0
            h5_file: h5py.File = h5_files[idx]
            rows: int = shapes[idx][0] - start_index
            diff = n + rows - min(n + rows, max_n)  # amount of overshoot
            data.larray[n : (n + rows - diff)] = torch.from_numpy(
                h5_file[dataset][start_index : (start_index + rows - diff)]
            )
            n += rows - diff

    return data
[docs]
def load(
    path: str, *args: Optional[List[object]], **kwargs: Optional[Dict[str, object]]
) -> DNDarray:
    """
    Loads data from a file on disk, choosing the loader from the file extension. CSV is always handled;
    HDF5, netCDF4 and zarr are handled when the corresponding support check succeeds.

    Parameters
    ----------
    path : str
        Path to the file to be read.
    args : list, optional
        Additional options passed to the particular functions.
    kwargs : dict, optional
        Additional options passed to the particular functions.

    Raises
    ------
    TypeError
        If ``path`` is not a string.
    ValueError
        If the file extension is not understood or known.
    RuntimeError
        If the optional dependency for a file extension is not available.

    Examples
    --------
    >>> ht.load("data.h5", dataset="DATA")
    DNDarray([ 1.0000,  2.7183,  7.3891, 20.0855, 54.5981], dtype=ht.float32, device=cpu:0, split=None)
    >>> ht.load("data.nc", variable="DATA")
    DNDarray([ 1.0000,  2.7183,  7.3891, 20.0855, 54.5981], dtype=ht.float32, device=cpu:0, split=None)
    >>> ht.load("my_data.zarr", variable="RECEIVER_1/DATA")
    DNDarray([ 1.0000,  2.7183,  7.3891, 20.0855, 54.5981], dtype=ht.float32, device=cpu:0, split=0)
    >>> ht.load("my_data.zarr", variable="RECEIVER_*/DATA")
    DNDarray([[ 1.0000,  2.7183,  7.3891, 20.0855, 54.5981],
              [ 1.0000,  2.7183,  7.3891, 20.0855, 54.5981],
              [ 1.0000,  2.7183,  7.3891, 20.0855, 54.5981]], dtype=ht.float32, device=cpu:0, split=0)

    See Also
    --------
    :func:`load_csv` : Loads data from a CSV file.
    :func:`load_csv_from_folder` : Loads multiple .csv files into one DNDarray which will be returned.
    :func:`load_hdf5` : Loads data from an HDF5 file.
    :func:`load_netcdf` : Loads data from a NetCDF4 file.
    :func:`load_npy_from_path` : Loads multiple .npy files into one DNDarray which will be returned.
    :func:`load_zarr` : Loads zarr-Format into DNDarray which will be returned.
    """
    if not isinstance(path, str):
        raise TypeError(f"Expected path to be str, but was {type(path)}")

    # normalise the extension so the lookup is case- and whitespace-insensitive
    _, suffix = os.path.splitext(path)
    extension = suffix.strip().lower()

    if extension in __CSV_EXTENSION:
        return load_csv(path, *args, **kwargs)

    if extension in __HDF5_EXTENSIONS:
        if not supports_hdf5():
            raise RuntimeError(f"hdf5 is required for file extension {extension}")
        return load_hdf5(path, *args, **kwargs)

    if extension in __NETCDF_EXTENSIONS:
        if not supports_netcdf():
            raise RuntimeError(f"netcdf is required for file extension {extension}")
        return load_netcdf(path, *args, **kwargs)

    if extension in __ZARR_EXTENSIONS:
        if not supports_zarr():
            raise RuntimeError(f"Package zarr is required for file extension {extension}")
        return load_zarr(path, *args, **kwargs)

    raise ValueError(f"Unsupported file extension {extension}")
[docs]
def load_csv(
    path: str,
    header_lines: int = 0,
    sep: str = ",",
    dtype: datatype = types.float32,
    encoding: str = "utf-8",
    split: Optional[int] = None,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Loads data from a CSV file. The resulting array is replicated on every process (``split=None``) or
    distributed along axis 0 or 1.

    Parameters
    ----------
    path : str
        Path to the CSV file to be read.
    header_lines : int, optional
        The number of lines at the beginning of the file that should not be considered as data.
    sep : str, optional
        The single ``char`` or ``str`` that separates the values in each row.
    dtype : datatype, optional
        Data type of the resulting array.
    encoding : str, optional
        The type of encoding which will be used to interpret the lines of the csv file as strings.
    split : int or None : optional
        Along which axis the resulting array should be split.
        Default is ``None`` which means each node will have the full array.
    device : str, optional
        The device id on which to place the data, defaults to globally set default device.
    comm : Communication, optional
        The communication to use for the data distribution, defaults to global default

    Returns
    -------
    DNDarray
        The parsed values as a two-dimensional array.

    Raises
    ------
    TypeError
        If any of the input parameters are not of correct type.
    ValueError
        If ``split`` is not one of ``None``, ``0`` or ``1``.

    Examples
    --------
    >>> import heat as ht
    >>> a = ht.load_csv("data.csv", split=0)
    >>> a.shape
    [0/4] (150, 4)
    [1/4] (150, 4)
    [2/4] (150, 4)
    [3/4] (150, 4)
    >>> a.lshape
    [0/4] (38, 4)
    [1/4] (38, 4)
    [2/4] (37, 4)
    [3/4] (37, 4)
    >>> b = ht.load_csv("data.csv", header_lines=10, split=0)
    >>> b.shape
    [0/4] (140, 4)
    [1/4] (140, 4)
    [2/4] (140, 4)
    [3/4] (140, 4)
    >>> b.lshape
    [0/4] (35, 4)
    [1/4] (35, 4)
    [2/4] (35, 4)
    [3/4] (35, 4)
    """
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    if not isinstance(sep, str):
        raise TypeError(f"separator must be str, not {type(sep)}")
    if not isinstance(header_lines, int):
        raise TypeError(f"header_lines must int, not {type(header_lines)}")
    if split not in [None, 0, 1]:
        raise ValueError(f"split must be in [None, 0, 1], but is {split}")

    # infer the type and communicator for the loaded array
    dtype = types.canonical_heat_type(dtype)
    # determine the comm and device the data will be placed on
    device = devices.sanitize_device(device)
    comm = sanitize_comm(comm)

    file_size = os.stat(path).st_size
    rank = comm.rank
    size = comm.size

    if split is None:
        # every process parses the whole file; honour the requested encoding
        with open(path, encoding=encoding) as f:
            data = f.readlines()
            data = data[header_lines:]
            result = []
            for line in data:
                values = line.replace("\n", "").replace("\r", "").split(sep)
                values = [float(val) for val in values]
                result.append(values)
            resulting_tensor = factories.array(
                result, dtype=dtype, split=split, device=device, comm=comm
            )

    elif split == 0:
        # each process reads one contiguous byte range of the file
        counts, displs, _ = comm.counts_displs_shape((file_size, 1), 0)
        # in case lines are terminated with '\r\n' we need to skip 2 bytes later
        lineter_len = 1
        # Read a chunk of bytes and count the linebreaks
        with open(path, "rb") as f:
            f.seek(displs[rank], 0)
            line_starts = []
            r = f.read(counts[rank])
            for pos, line in enumerate(r):
                if chr(line) == "\n":
                    # Check if it is part of '\r\n'
                    # NOTE(review): for pos == 0 this reads r[-1], the last byte of this chunk,
                    # not the byte preceding the chunk in the file - confirm this is intended.
                    if chr(r[pos - 1]) != "\r":
                        line_starts.append(pos + 1)
                elif chr(line) == "\r":
                    # check if file line is terminated by '\r\n'
                    if pos + 1 < len(r) and chr(r[pos + 1]) == "\n":
                        line_starts.append(pos + 2)
                        lineter_len = 2
                    else:
                        line_starts.append(pos + 1)

            if rank == 0:
                # the very first line of the file has no preceding line break
                line_starts = [0] + line_starts

            # Find the correct starting point
            total_lines = torch.empty(size, dtype=torch.int32)
            comm.Allgather(torch.tensor([len(line_starts)], dtype=torch.int32), total_lines)

            cumsum = total_lines.cumsum(dim=0).tolist()
            # first rank whose chunk reaches past the header lines
            start = next(i for i in range(size) if cumsum[i] > header_lines)
            if rank < start:
                line_starts = []
            if rank == start:
                rem = header_lines - (0 if start == 0 else cumsum[start - 1])
                line_starts = line_starts[rem:]

            # Determine the number of columns that each line consists of
            if len(line_starts) > 1:
                columns = 1
                for li in r[line_starts[0] : line_starts[1]]:
                    if chr(li) == sep:
                        columns += 1
            else:
                columns = 0

            columns = torch.tensor([columns], dtype=torch.int32)
            comm.Allreduce(MPI.IN_PLACE, columns, MPI.MAX)

            # Share how far the processes need to read in their last line
            # NOTE(review): the byte offsets are exchanged as int32 - presumably this
            # overflows for files of 2 GiB and more; verify.
            last_line = file_size
            if size - start > 1:
                if rank == start:
                    last_line = torch.empty(1, dtype=torch.int32)
                    comm.Recv(last_line, source=rank + 1)
                    last_line = last_line.item()
                elif rank == size - 1:
                    first_line = torch.tensor(displs[rank] + line_starts[0] - 1, dtype=torch.int32)
                    comm.Send(first_line, dest=rank - 1)
                elif start < rank < size - 1:
                    last_line = torch.empty(1, dtype=torch.int32)
                    first_line = torch.tensor(displs[rank] + line_starts[0] - 1, dtype=torch.int32)
                    comm.Send(first_line, dest=rank - 1)
                    comm.Recv(last_line, source=rank + 1)
                    last_line = last_line.item()

            # Create empty tensor and iteratively fill it with the values
            local_shape = (len(line_starts), columns)
            actual_length = 0
            local_tensor = torch.empty(
                local_shape, dtype=dtype.torch_type(), device=device.torch_device
            )
            for ind, start in enumerate(line_starts):
                if ind == len(line_starts) - 1:
                    # the last local line may continue into the next chunk: re-read it from
                    # the file up to the position received from the neighbouring process
                    f.seek(displs[rank] + start, 0)
                    line = f.read(last_line - displs[rank] - start)
                else:
                    line = r[start : line_starts[ind + 1] - lineter_len]
                # Decode byte array
                line = line.decode(encoding)
                if len(line) > 0:
                    sep_values = [float(val) for val in line.split(sep)]
                    local_tensor[actual_length] = torch.tensor(sep_values, dtype=dtype.torch_type())
                    actual_length += 1

        # In case there are some empty lines in the csv file
        local_tensor = local_tensor[:actual_length]

        total_actual_lines = torch.tensor(
            actual_length, dtype=torch.int64, device=local_tensor.device
        )
        comm.Allreduce(MPI.IN_PLACE, total_actual_lines, MPI.SUM)

        gshape = (total_actual_lines.item(), columns[0].item())

        resulting_tensor = DNDarray(
            local_tensor,
            gshape=gshape,
            dtype=dtype,
            split=0,
            device=device,
            comm=comm,
            balanced=None,
        )
        resulting_tensor.balance_()

    elif split == 1:
        data = []

        # every process reads all lines but keeps only its own column range
        with open(path, encoding=encoding) as f:
            for i in range(header_lines):
                f.readline()
            line = f.readline()
            values = line.replace("\n", "").replace("\r", "").split(sep)
            values = [float(val) for val in values]
            # number of values in the first data line, i.e. the global column count
            rows = len(values)

            chunk, displs, _ = comm.counts_displs_shape((1, rows), 1)
            data.append(values[displs[rank] : displs[rank] + chunk[rank]])
            # Read file line by line till EOF reached
            for line in iter(f.readline, ""):
                values = line.replace("\n", "").replace("\r", "").split(sep)
                values = [float(val) for val in values]
                data.append(values[displs[rank] : displs[rank] + chunk[rank]])
        resulting_tensor = factories.array(data, dtype=dtype, is_split=1, device=device, comm=comm)

    return resulting_tensor
[docs]
def save_csv(
    data: DNDarray,
    path: str,
    header_lines: Iterable[str] = None,
    sep: str = ",",
    decimals: int = -1,
    encoding: str = "utf-8",
    comm: Optional[Communication] = None,
    truncate: bool = True,
):
    """
    Saves data to CSV files. Only 2D data, all split axes.

    Every value is written as a fixed-width cell so that each process can compute the
    byte offset of its rows (or row fragments) independently and write them with
    ``MPI.File.Write_at``.

    Parameters
    ----------
    data : DNDarray
        The DNDarray to be saved to CSV.
    path : str
        The path as a string.
    header_lines : Iterable[str]
        Optional iterable of str to prepend at the beginning of the file. No
        pound sign or any other comment marker will be inserted. A trailing
        newline is appended to every line that lacks one.
    sep : str
        The separator character used in this CSV. The offset computation reserves
        exactly one byte per separator.
    decimals: int
        Number of digits after decimal point. ``-1`` selects 7 digits for float32 and
        15 digits for other floating point types; ignored for integer types.
    encoding : str
        The encoding used for the header lines. The numeric rows are always written
        as UTF-8 encoded text.
    comm : Optional[Communication]
        An optional object of type Communication to be used. Currently unused, the
        communicator of ``data`` is used for all I/O.
    truncate : bool
        Whether to truncate an existing file before writing, i.e. fully overwrite it.
        The sane default is True. Setting it to False will not shorten files if
        needed and thus may leave garbage at the end of existing files.

    Raises
    ------
    TypeError
        If ``path`` or ``sep`` is not a str, or ``header_lines`` is neither iterable nor None.
    ValueError
        If ``data.split`` is not one of None, 0 or 1.
    """
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    if not isinstance(sep, str):
        raise TypeError(f"separator must be str, not {type(sep)}")
    # check this to allow None
    if not isinstance(header_lines, Iterable) and header_lines is not None:
        raise TypeError(f"header_lines must Iterable[str], not {type(header_lines)}")
    if data.split not in [None, 0, 1]:
        raise ValueError(f"split must be in [None, 0, 1], but is {data.split}")

    # Materialise and encode the header exactly once: one-shot iterables (generators)
    # must not be exhausted by the length computation, and file offsets are measured
    # in bytes, not characters.
    encoded_header = []
    if header_lines is not None:
        encoded_header = [
            (hl if hl.endswith("\n") else hl + "\n").encode(encoding) for hl in header_lines
        ]

    if os.path.exists(path) and truncate:
        if data.comm.rank == 0:
            os.truncate(path, 0)
        # avoid truncating and writing at the same time
        data.comm.handle.Barrier()

    amode = MPI.MODE_WRONLY | MPI.MODE_CREATE
    csv_out = MPI.File.Open(data.comm.handle, path, amode)

    # The header size is computed on every rank (it is the base offset of all data
    # rows), but only rank 0 writes it; this avoids a reduce operation.
    hl_displacement = sum(len(chunk) for chunk in encoded_header)
    if data.comm.rank == 0 and encoded_header:
        csv_out.Write(b"".join(encoded_header))

    # formatting and element width
    data_min = smin(data).item()  # at least min is used twice, so cache it here
    data_max = smax(data).item()
    sign = 1 if data_min < 0 else 0
    if abs(data_max) > 0 or abs(data_min) > 0:
        # Magnitudes below 1 are still printed with a leading "0", so never reserve
        # fewer than one digit in front of the decimal point.
        pre_point_digits = max(int(log10(max(abs(data_max), abs(data_min)))) + 1, 1)
    else:
        pre_point_digits = 1

    dec_sep = 1
    fmt = ""
    if types.issubdtype(data.dtype, types.integer):
        decimals = 0
        dec_sep = 0
        fmt = f"%{pre_point_digits + sign}d"
    elif types.issubdtype(data.dtype, types.floating):
        if decimals == -1:
            decimals = 7 if data.dtype is types.float32 else 15
        fmt = f"%{pre_point_digits + decimals + 1 + sign}.{decimals}f"

    # sign + decimal separator + pre separator digits + decimals (post separator)
    item_size = decimals + dec_sep + sign + pre_point_digits
    # each item is one position larger than its representation, either b/c of separator or line break
    row_width = item_size + 1
    if len(data.shape) > 1:
        row_width = data.shape[1] * (item_size + 1)

    offset = hl_displacement  # all splits
    if data.split == 0:
        _, displs = data.counts_displs()
        offset = offset + displs[data.comm.rank] * row_width
    elif data.split == 1:
        _, displs = data.counts_displs()
        offset = offset + displs[data.comm.rank] * (item_size + 1)

    for i in range(data.lshape[0]):
        # if lshape is of the form (x,), then there will only be a single element per row
        if len(data.lshape) == 1:
            row = fmt % (data.larray[i])
        else:
            if data.lshape[1] == 0:
                break
            row = sep.join(fmt % (item) for item in data.larray[i])

        # Only the process holding the last column of a row terminates the line; all
        # other column fragments end with a separator.
        if (
            data.split is None
            or data.split == 0
            or displs[data.comm.rank] + data.lshape[1] == data.shape[1]
        ):
            row = row + "\n"
        else:
            row = row + sep

        # replicated data (split=None) is written by rank 0 only
        if data.split is not None or data.comm.rank == 0:
            csv_out.Write_at(offset, row.encode("utf-8"))

        offset = offset + row_width

    csv_out.Close()
    data.comm.handle.Barrier()
def save(
    data: DNDarray, path: str, *args: Optional[List[object]], **kwargs: Optional[Dict[str, object]]
):
    """
    Writes a :class:`~heat.core.dndarray.DNDarray` to disk, choosing the backend from the
    (case-insensitive) file extension of ``path``.

    Parameters
    ----------
    data : DNDarray
        The array holding the data to be stored
    path : str
        Path to the file to be stored.
    args : list, optional
        Additional positional options forwarded unchanged to the format-specific function.
    kwargs : dict, optional
        Additional keyword options forwarded unchanged to the format-specific function.

    Raises
    ------
    TypeError
        If ``path`` is not a str.
    ValueError
        If the file extension is not understood or known.
    RuntimeError
        If the optional dependency for a file extension is not available.

    Examples
    --------
    >>> x = ht.arange(100, split=0)
    >>> ht.save(x, "data.h5", "DATA", mode="a")
    """
    if not isinstance(path, str):
        raise TypeError(f"Expected path to be str, but was {type(path)}")

    _, suffix = os.path.splitext(path)
    extension = suffix.strip().lower()

    if extension in __HDF5_EXTENSIONS:
        if not supports_hdf5():
            raise RuntimeError(f"hdf5 is required for file extension {extension}")
        save_hdf5(data, path, *args, **kwargs)
        return None

    if extension in __NETCDF_EXTENSIONS:
        if not supports_netcdf():
            raise RuntimeError(f"netcdf is required for file extension {extension}")
        save_netcdf(data, path, *args, **kwargs)
        return None

    if extension in __CSV_EXTENSION:
        save_csv(data, path, *args, **kwargs)
        return None

    if extension in __ZARR_EXTENSIONS:
        if not supports_zarr():
            raise RuntimeError(f"Package zarr is required for file extension {extension}")
        # the zarr writer's result is handed back to the caller
        return save_zarr(data, path, *args, **kwargs)

    raise ValueError(f"Unsupported file extension {extension}")
# Attach ``save`` to DNDarray as a method: ``x.save(path, ...)`` forwards to ``save(x, path, ...)``.
DNDarray.save = lambda self, path, *args, **kwargs: save(self, path, *args, **kwargs)
# The method shares the docstring of the module-level function.
DNDarray.save.__doc__ = save.__doc__
def load_npy_from_path(
    path: str,
    dtype: datatype = types.int32,
    split: int = 0,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Loads multiple .npy files into one DNDarray which will be returned. The data will be concatenated along the split axis provided as input.

    The files are taken in sorted name order and handed out to the processes in
    contiguous groups; every process loads and concatenates only its own group.

    Parameters
    ----------
    path : str
        Path to the directory in which .npy-files are located.
    dtype : datatype, optional
        Data type of the resulting array.
    split : int
        Along which axis the loaded arrays should be concatenated.
    device : str, optional
        The device id on which to place the data, defaults to globally set default device.
    comm : Communication, optional
        The communication to use for the data distribution, default is 'heat.MPI_WORLD'

    Raises
    ------
    TypeError
        If ``path`` is not a str or ``split`` is neither None nor an int.
    ValueError
        If the directory contains no .npy files.
    RuntimeError
        If there are more processes than files.
    """
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    elif split is not None and not isinstance(split, int):
        raise TypeError(f"split must be None or int, not {type(split)}")

    # Distribute the files over the same communicator the resulting array lives on.
    comm = sanitize_comm(comm)
    process_number = comm.size
    rank = comm.rank

    # os.listdir returns entries in arbitrary order; sort so that the concatenation
    # order is reproducible and identical on every process.
    file_list = sorted(name for name in os.listdir(path) if fnmatch.fnmatch(name, "*.npy"))
    n_files = len(file_list)

    if n_files == 0:
        raise ValueError("No .npy Files were found")
    if n_files < process_number:
        raise RuntimeError("Number of processes can't exceed number of files")

    # The first `remainder` ranks take one extra file each.
    per_rank, remainder = divmod(n_files, process_number)
    if rank < remainder:
        n_for_procs = per_rank + 1
        idx = rank * n_for_procs
    else:
        n_for_procs = per_rank
        idx = rank * n_for_procs + remainder

    array_list = [
        np.load(os.path.join(path, element)) for element in file_list[idx : idx + n_for_procs]
    ]
    larray = torch.from_numpy(np.concatenate(array_list, split))

    return factories.array(larray, dtype=dtype, device=device, is_split=split, comm=comm)
# Optional dependency guard: ``supports_pandas`` is defined in both branches so callers
# can query availability; pandas-based loaders are only exported when the import works.
try:
    import pandas as pd
except ModuleNotFoundError:
    # pandas support is optional
    def supports_pandas() -> bool:
        """
        Returns ``True`` if pandas is installed, ``False`` otherwise.
        """
        return False

else:
    # add functions to visible exports
    __all__.extend(["load_csv_from_folder"])

    def supports_pandas() -> bool:
        """
        Returns ``True`` if pandas is installed, ``False`` otherwise.
        """
        return True
def load_csv_from_folder(
    path: str,
    dtype: datatype = types.int32,
    split: int = 0,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
    func: Optional[Callable] = None,
) -> DNDarray:
    """
    Loads multiple .csv files into one DNDarray which will be returned. The data will be concatenated along the split axis provided as input.

    The files are taken in sorted name order and handed out to the processes in
    contiguous groups; every process reads its own group with ``pandas.read_csv``.

    Parameters
    ----------
    path : str
        Path to the directory in which .csv-files are located.
    dtype : datatype, optional
        Data type of the resulting array.
    split : int
        Along which axis the loaded arrays should be concatenated.
    device : str, optional
        The device id on which to place the data, defaults to globally set default device.
    comm : Communication, optional
        The communication to use for the data distribution, default is 'heat.MPI_WORLD'
    func : callable, optional
        Function applied to the ``pandas.DataFrame`` of every file before it is converted
        with ``to_numpy()``; its return value must therefore provide ``to_numpy()``.

    Raises
    ------
    TypeError
        If ``path`` is not a str, ``split`` is neither None nor an int, or ``func`` is
        neither None nor callable.
    ValueError
        If the directory contains no .csv files.
    RuntimeError
        If there are more processes than files.
    """
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    elif split is not None and not isinstance(split, int):
        raise TypeError(f"split must be None or int, not {type(split)}")
    elif (func is not None) and not callable(func):
        raise TypeError("func needs to be a callable function or None")

    # Distribute the files over the same communicator the resulting array lives on.
    comm = sanitize_comm(comm)
    process_number = comm.size
    rank = comm.rank

    # os.listdir returns entries in arbitrary order; sort so that the concatenation
    # order is reproducible and identical on every process.
    file_list = sorted(name for name in os.listdir(path) if fnmatch.fnmatch(name, "*.csv"))
    n_files = len(file_list)

    if n_files == 0:
        raise ValueError("No .csv Files were found")
    if n_files < process_number:
        raise RuntimeError("Number of processes can't exceed number of files")

    # The first `remainder` ranks take one extra file each.
    per_rank, remainder = divmod(n_files, process_number)
    if rank < remainder:
        n_for_procs = per_rank + 1
        idx = rank * n_for_procs
    else:
        n_for_procs = per_rank
        idx = rank * n_for_procs + remainder

    array_list = []
    for element in file_list[idx : idx + n_for_procs]:
        frame = pd.read_csv(os.path.join(path, element))
        if func is not None:
            frame = func(frame)
        array_list.append(frame.to_numpy())

    larray = torch.from_numpy(np.concatenate(array_list, split))

    return factories.array(larray, dtype=dtype, device=device, is_split=split, comm=comm)
# Optional dependency guard: ``supports_zarr`` is defined in both branches so callers
# can query availability; the zarr reader/writer are only exported when the import works.
try:
    import zarr
except ModuleNotFoundError:
    # zarr support is optional
    def supports_zarr() -> bool:
        """
        Returns ``True`` if zarr is installed, ``False`` otherwise.
        """
        return False

else:
    # add functions to visible exports
    __all__.extend(["load_zarr", "save_zarr"])

    def supports_zarr() -> bool:
        """
        Returns ``True`` if zarr is installed, ``False`` otherwise.
        """
        return True
def load_zarr(
    path: str,
    variable: str = None,
    split: int = 0,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
    slices: Union[None, slice, Iterable[Union[slice, None]]] = None,
    **kwargs,
) -> DNDarray:
    """
    Loads data from a zarr store into DNDarray. `path` can either point to a single zarr array or a zarr group. In the latter case, `variable` must be provided to specify which array in the group to load. If `variable` contains a wildcard pattern (e.g. `RECEIVER_*/DATA`), all matching arrays will be loaded and concatenated along the specified `split` axis.

    Parameters
    ----------
    path : str
        Path to the directory in which a .zarr-file is located.
    variable : str, optional
        If the zarr store is a group, the variable (or path to variable) to load from the group.
        Can contain a wildcard pattern to load and concatenate arrays stored in slices in different directories.
    split : int
        Along which axis the loaded arrays should be concatenated.
    device : str, optional
        The device id on which to place the data, defaults to globally set default device.
    comm : Communication, optional
        The communication to use for the data distribution, default is 'heat.MPI_WORLD'
    slices: Union[None, slice, Iterable[Union[slice, None]]]
        Load only a slice of the array instead of everything
    **kwargs : Any
        extra Arguments to pass to zarr.open_array when a single array is loaded. A ``dtype``
        entry is removed from the keyword arguments; it is applied as the output data type
        in the wildcard branch only.

    Raises
    ------
    TypeError
        If ``path`` is not a str or ``slices`` (or one of its entries) has an unsupported type.
    ValueError
        If ``path`` has no zarr extension or ``slices`` has more entries than the array has dimensions.
    NotImplementedError
        If ``slices`` is combined with a wildcard ``variable``.
    FileNotFoundError
        If a wildcard ``variable`` matches nothing in the store.
    """
    # sanitize inputs
    device = devices.sanitize_device(device)
    torch_device = device.torch_device
    comm = sanitize_comm(comm)

    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")
    if not isinstance(slices, (slice, Iterable)) and slices is not None:
        raise TypeError(f"Slices Argument must be slice, tuple or None and not {type(slices)}")
    if isinstance(slices, Iterable):
        # materialise first so that one-shot iterables are not consumed by the validation
        slices = list(slices)
        for elem in slices:
            if isinstance(elem, slice) or elem is None:
                continue
            raise TypeError(f"Tuple values of slices must be slice or None, not {type(elem)}")

    if not any(fnmatch.fnmatch(path, f"*{extension}") for extension in __ZARR_EXTENSIONS):
        raise ValueError("File has no zarr extension.")

    store_path = os.path.join(path, variable) if variable else path
    output_dtype = kwargs.pop("dtype", None)
    torch_output_dtype = output_dtype.torch_type() if output_dtype else None

    if variable and "*" in variable:
        # `variable` contains a wildcard pattern
        # e.g. data were chunked at write-out and stored in multiple directories
        if slices is not None:
            raise NotImplementedError("Slicing is not supported when loading with a wildcard.")

        base_paths = sorted(glob.glob(store_path))
        if not base_paths:
            raise FileNotFoundError(
                f"Zarr wildcard pattern '{variable}' did not match any arrays in store '{path}'"
            )
        variable_paths = [os.path.relpath(p, start=path) for p in base_paths]

        # each rank reads data from its assigned directories and concatenates locally
        # determine which directories to open on rank; all collectives below run on
        # `comm`, the communicator the resulting DNDarray is built on
        _, _, local_dir_slice = comm.chunk((len(base_paths),), rank=comm.rank, split=0)
        local_variable_paths = variable_paths[local_dir_slice[0]]

        # load data to torch tensors; the store is opened once per rank, not once per array
        local_tensors = []
        if local_variable_paths:
            store = zarr.open(path)
            for var_path in local_variable_paths:
                local_tensor = torch.from_numpy(store[var_path][:])
                if torch_output_dtype:
                    local_tensor = local_tensor.to(torch_output_dtype)
                local_tensors.append(local_tensor)

        # Have rank 0 determine the single-store shape and broadcast it to all ranks for sanitation
        target_ndims = torch.zeros(1, dtype=torch.int32)
        if comm.rank == 0:
            if len(local_tensors) == 0:
                raise ValueError(
                    f"Zarr wildcard pattern '{variable}' did not match any arrays in store '{path}'"
                )
            # broadcast shape of first local tensor to allow sanitation on empty ranks
            target_ndims = torch.tensor(local_tensors[0].ndim, dtype=torch.int32)
        comm.Bcast(target_ndims, root=0)

        # sanitize split axis
        proxy_shape = (1,) * target_ndims.item()
        split = sanitize_axis(proxy_shape, axis=split)

        # prepare sorted list of all heat datatypes to convert them to integer for communication with MPI
        all_heat_type_names = [heat_type.__name__ for heat_type in types.__get_all_heat_types()]
        all_heat_type_names.sort()

        # concatenate locally
        if len(local_tensors) >= 1:
            if len(local_tensors) == 1:
                local_tensor = local_tensors[0]
            else:
                local_tensor = torch.cat(local_tensors, dim=split if split is not None else 0)
            empty_ranks = torch.tensor([0], dtype=torch.int32)
            ht_type_code = all_heat_type_names.index(
                types.canonical_heat_type(local_tensor.dtype).__name__
            )
        else:
            # no local tensors i.e. no data assigned to rank
            local_tensor = torch.empty((0,))
            empty_ranks = torch.tensor([1], dtype=torch.int32)
            # dummy dtype code
            ht_type_code = -1

        # check for empty ranks
        comm.Allreduce(MPI.IN_PLACE, empty_ranks, op=MPI.SUM)

        if empty_ranks.item() > 0:
            # fix local shape and dtype of empty tensors, otherwise DNDarray construction will fail
            # Rank 0 broadcasts the info to all other ranks
            target_shape = torch.zeros(
                (
                    1,
                    target_ndims.item() + 1,
                ),
                dtype=torch.int64,
            )
            if local_tensor.numel() > 0:
                target_shape[0, :-1] = torch.tensor(local_tensor.shape, dtype=torch.int64)
            # encode dtype as last entry
            target_shape[0, -1] = ht_type_code

            # share info about target shape and dtype
            target_shapes = torch.zeros((comm.size, target_ndims.item() + 1), dtype=torch.int64)
            comm.Allgather(target_shape, target_shapes)

            if local_tensor.numel() == 0:
                # adopt shape and dtype of rank 0, with zero extent along the split axis
                ht_type_code = target_shapes[0, -1].item()
                target_shape = target_shapes[0, :-1].clone()
                target_shape[split] = 0
                ht_type = getattr(types, all_heat_type_names[ht_type_code])
                local_tensor = torch.empty(
                    tuple(target_shape.tolist()), dtype=ht_type.torch_type()
                )

            # discard dtype code column
            target_shapes = target_shapes[:, :-1]
            # calculate global array shape
            out_gshape = target_shapes[0, :].clone()
            out_gshape[split] = target_shapes[:, split].sum().item()

            # wrap local tensors in DNDarray
            dndarray = DNDarray(
                local_tensor.to(device=torch_device),
                gshape=tuple(out_gshape.tolist()),
                dtype=output_dtype
                if output_dtype
                else types.canonical_heat_type(local_tensor.dtype),
                split=split,
                device=device,
                comm=comm,
                balanced=False,
            )
        else:
            # all ranks are populated, create DNDarray directly
            dndarray = factories.array(local_tensor, is_split=split, device=device, comm=comm)

        dndarray.balance_()
        return dndarray

    # standard single zarr array
    arr: zarr.Array = zarr.open_array(store=store_path, **kwargs)
    shape = arr.shape

    if isinstance(slices, slice) or slices is None:
        slices = [slices]

    if len(shape) < len(slices):
        raise ValueError(
            f"slices Argument has more arguments than the length of the shape of the array. {len(shape)} < {len(slices)}"
        )

    # None selects a full axis; axes without an entry are selected fully as well
    slices = [elem if elem is not None else slice(None) for elem in slices]
    slices.extend([slice(None) for _ in range(abs(len(slices) - len(shape)))])

    dtype = types.canonical_heat_type(arr.dtype)
    split = sanitize_axis(shape, axis=split)

    slices = tuple(slices)
    # shape of the selection, from which the per-rank chunk is derived
    shape = [len(range(*tslice.indices(length))) for length, tslice in zip(shape, slices)]
    offset, local_shape, local_slices = comm.chunk(shape, split)

    return factories.array(
        arr[slices][local_slices], dtype=dtype, is_split=split, device=device, comm=comm
    )
def save_zarr(dndarray: DNDarray, path: str, overwrite: bool = False, **kwargs) -> None:
    """
    Writes the DNDArray into the zarr-format.

    Parameters
    ----------
    dndarray : DNDarray
        DNDArray to save.
    path : str
        path to save to.
    overwrite : bool
        Whether to overwrite an existing array.
    **kwargs : Any
        extra Arguments to pass to zarr.open and zarr.create

    Raises
    ------
    TypeError
        If ``path`` is not a str.
    ValueError
        If ``path`` does not end with a zarr extension.
    RuntimeError
        If ``path`` already exists and ``overwrite`` is not set.

    Notes
    -----
    Zarr works by chunking the data, where a chunk is a file inside the store.
    The problem is that only one process writes to a chunk at a time. Therefore, when two
    processes try to write to the same chunk one will fail, unless the other finishes before
    the other starts.

    To alleviate this we define the chunk sizes ourselves. To do this we just get the lowest size of
    the distributed axis, ex: split=0 with a (4,4) shape with a worldsize of 4 you would chunk it with (1,4).

    A problem arises when a process gets a bigger chunk and interferes with another process. Example:

    N_PROCS = 4
    SHAPE = (9,10)
    SPLIT = 0
    CHUNKS => (2,10)

    In this case one process will have a write region of 3 rows and therefore either not write
    or overwrite what another process does, destroying the parallel write as it would in the end load
    2 chunks to write 3 rows.

    To counteract this we set the chunk size in the split axis to 1 whenever the split axis is not
    evenly divisible by the number of processes. This rules out overwrites, but can affect write speed.

    Another problem with this approach is that we tell zarr to use full chunks, i.e. if the array has shape
    (10_000, 10_000) and we split it at axis=0 with 4 processes we have chunks of (2_500, 10_000). Zarr will
    load the whole chunk into memory, making it memory intensive and probably inefficient. A better approach
    would be a smaller chunk size, for example half of it, but that cannot be determined at all times so the
    current approach is a compromise. Chunks larger than the codec limit of ``2**31 - 1`` bytes are halved
    along the split axis (at most ten times) and fall back to a chunk size of 1 otherwise.

    Another problem is the split=None scenario. In this case every process has the same data, so only one
    needs to write; we ignore chunking, let zarr decide the chunk size and let only rank 0 write.

    To avoid errors when using NumPy arrays as chunk shape, the chunks argument is only passed to zarr.create
    if it is not None. This prevents issues with ambiguous truth values or attribute errors on None.
    """
    if not isinstance(path, str):
        raise TypeError(f"path must be str, not {type(path)}")

    if not any(fnmatch.fnmatch(path, f"*{extension}") for extension in __ZARR_EXTENSIONS):
        raise ValueError("path does not end on an Zarr extension.")

    if os.path.exists(path) and not overwrite:
        raise RuntimeError("Given Path already exists.")

    # only rank 0 creates the store; everybody else opens it after the barrier
    if MPI_WORLD.rank == 0:
        if dndarray.split is None or MPI_WORLD.size == 1:
            chunks = None
        else:
            chunks = np.array(dndarray.gshape)
            axis = dndarray.split

            if chunks[axis] % MPI_WORLD.size != 0:
                chunks[axis] = 1
            else:
                chunks[axis] //= MPI_WORLD.size

            CODEC_LIMIT_BYTES = 2**31 - 1  # PR#1766

            # Use for loop instead of while true for better handling of edge cases
            for _ in range(10):
                byte_size = reduce(operator.mul, chunks, 1) * dndarray.larray.element_size()
                if byte_size <= CODEC_LIMIT_BYTES:
                    break
                if chunks[axis] % 2 != 0:
                    chunks[axis] = 1
                    break
                # chunk sizes are integers, so halve with floor division
                chunks[axis] //= 2
            else:
                chunks[axis] = 1
                warnings.warn(
                    "Calculation of chunk size for zarr format unexpectadly defaulted to 1 on the split axis"
                )

        dtype = dndarray.dtype.char()
        zarr_create_kwargs = {
            "store": path,
            "shape": dndarray.gshape,
            "dtype": dtype,
            "overwrite": overwrite,
            **kwargs,
        }
        if chunks is not None:
            zarr_create_kwargs["chunks"] = chunks.tolist()
        zarr_array = zarr.create(**zarr_create_kwargs)

    # Wait for the file creation to finish
    MPI_WORLD.Barrier()
    zarr_array = zarr.open(store=path, mode="r+", **kwargs)

    if dndarray.split is not None:
        _, _, slices = MPI_WORLD.chunk(dndarray.gshape, dndarray.split)
        # Numpy array needed as zarr can only understand numpy dtypes and infers it.
        zarr_array[slices] = dndarray.larray.cpu().numpy()
    else:
        if MPI_WORLD.rank == 0:
            zarr_array[:] = dndarray.larray.cpu().numpy()

    MPI_WORLD.Barrier()