"""Provides a collection of signal-processing operations"""
import torch
import numpy as np
from typing import Union
import warnings
from .dndarray import DNDarray
from .types import promote_types, heat_type_is_exact, heat_type_is_realfloating, issubdtype
from .types import unsignedinteger, float64, float32
from .manipulations import pad, flip
from .factories import array, zeros
from .sanitation import sanitize_in_min_max_nd
import torch.nn.functional as fc
__all__ = ["convolve", "convolve2d"]
def _sanitize_conv_input(
a: DNDarray,
v: DNDarray,
stride: Union[int, tuple[int, int]],
mode: str,
convolution_dim: int = 1,
) -> tuple[DNDarray, DNDarray]:
"""
Check and preprocess input data.
Parameters
----------
a : scalar, array_like, DNDarray
Input signal data.
v : scalar, array_like, DNDarray
Input filter mask.
stride : scalar, tuple
Stride along each axis convolution is applied.
mode : str
Convolution mode "full", "same" or "valid"
convolution_dim : int, optional
Number of dimension along which convolution will be applied, affects what input_check looks for. Default 1
Returns
-------
tuple
A tuple containing the processed input signal 'a' and filter mask 'v'.
Raises
------
TypeError
If 'a' or 'v' have unsupported data types.
ValueError
If 'v' is larger than 'a' in only one dimension for convolutions other than 1D
If mode not supported
If mode==same and stride > 1
Description
-----------
This function takes two inputs, 'a' (signal data) and 'v' (filter mask), and performs the following checks and
preprocessing steps:
1. Check if 'a' and 'v' are scalars. If they are, convert them into DNDarray arrays.
2. Check if 'a' and 'v' are instances of the 'DNDarray' class. If not, attempt to convert them into DNDarray arrays.
If conversion is not possible, raise a TypeError.
3. Determine the promoted data type for 'a' and 'v' based on their existing data types. Convert 'a' and 'v' to this
promoted data type to ensure consistent data types.
4. Check if data type is supported in torch given the device
5. Check if filter is smaller or equal signal, flip if necessary
6. Check mode and check mode "same" against even sized kernels
7. Check stride for negative entries and against mode
8. Return a tuple containing the processed 'a' and 'v'.
"""
# Check if 'a' is a scalar and convert to a DNDarray if necessary
if np.isscalar(a):
a = array([[a]])
while a.ndim > convolution_dim:
a = a.squeeze(-1)
# Check if 'v' is a scalar and convert to a DNDarray if necessary
if np.isscalar(v):
v = array([[v]])
while v.ndim > convolution_dim:
v = v.squeeze(-1)
# Check if 'a' is not an instance of DNDarray and try to convert it to a DNDarray array
if not isinstance(a, DNDarray):
try:
a = array(a)
except TypeError:
raise TypeError(f"non-supported type for signal: {type(a)}")
# Check if 'v' is not an instance of DNDarray and try to convert it to a NumPy array
if not isinstance(v, DNDarray):
try:
v = array(v)
except TypeError:
raise TypeError(f"non-supported type for filter: {type(v)}")
# Check if sufficient number of dimensions available
sanitize_in_min_max_nd(a, convolution_dim)
sanitize_in_min_max_nd(v, convolution_dim)
# Determine the promoted data type for 'a' and 'v' and convert them to this data type
promoted_type = promote_types(a.dtype, v.dtype)
# promoted_type must be of integer or floating for convolution
if not (
heat_type_is_exact(promoted_type) or heat_type_is_realfloating(promoted_type)
) or issubdtype(promoted_type, unsignedinteger):
raise TypeError(
f"Data type supported for convolution. Signal type {a.dtype}, Kernel type {v.dtype}, Promoted type {promoted_type}"
)
# cast to float32 if on mps or cuda in certain circumstances
if a.larray.is_mps and promoted_type == float64:
# cannot cast to float64 on MPS
promoted_type = float32
warnings.warn(
f"Promoted type float64 is not supported on MPS. Signal and kernel will be cast to {promoted_type} instead."
)
elif a.larray.is_cuda and not heat_type_is_realfloating(promoted_type):
promoted_type = promote_types(promoted_type, float32)
warnings.warn(
f"Only floating operations supported on CUDA. Signal and kernel will be cast to {promoted_type}",
RuntimeWarning,
)
# cast
a = a.astype(promoted_type)
v = v.astype(promoted_type)
# check if the filter is longer than the signal and swap them if necessary
v_shape = v.shape[-convolution_dim:]
a_shape = a.shape[-convolution_dim:]
if all(v_s >= a_s for v_s, a_s in zip(v_shape, a_shape)):
if not all(v_s == a_s for v_s, a_s in zip(v_shape, a_shape)):
a, v = v, a
v_shape = v.shape[-convolution_dim:]
a_shape = a.shape[-convolution_dim:]
if any(v_s > a_s for v_s, a_s in zip(v_shape, a_shape)):
raise ValueError(
f"Filter size must not be larger in one convolved dimension and smaller in the other. Signal: {a.shape}, Filter: {v.shape}"
)
# check mode against even kernel
if mode not in ("full", "valid", "same"):
raise ValueError(f"Only 'full', 'valid' or 'same' as mode are allowed, got {mode}.")
if mode == "same" and any(v_s % 2 == 0 for v_s in v_shape):
raise ValueError("Mode 'same' cannot be used with even-sized kernel.")
# check mode and stride for value errors
if convolution_dim == 1:
if stride < 1:
raise ValueError("Stride must be positive")
if stride > 1 and mode == "same":
raise ValueError("Stride must be 1 for mode 'same'")
else:
if any(s < 1 for s in stride):
raise ValueError("Stride must be positive for all convolution dimensions")
if any(s > 1 for s in stride) and mode == "same":
raise ValueError(f"Stride must be {tuple([1] * convolution_dim)} for mode 'same'")
# Return the processed 'a' and 'v' as a tuple
return a, v
def _conv_batchprocessing_check(a: DNDarray, v: DNDarray, convolution_dim: int) -> bool:
"""
Check if batch proccessing applies, default is False (no batch processing)
Parameters
----------
a : scalar, array_like, DNDarray
Input signal data.
v : scalar, array_like, DNDarray
Input filter mask.
convolution_dim : int
Number of dimension along which convolution will be applied, affects what input_check looks for. Default 1
Returns
-------
bool
Boolean if batch processing applies
"""
batch_processing = False
if a.ndim > convolution_dim:
# batch processing requires 1D filter OR matching batch dimensions for signal and filter
batch_dims = a.shape[:-convolution_dim]
# verify that the filter shape is consistent with the signal
if v.ndim > convolution_dim:
v_batch = v.shape[:-convolution_dim]
if any(v_s != b_s for v_s, b_s in zip(batch_dims, v_batch)):
raise ValueError(
f"Batch dimensions of signal and filter must match. Signal: {a.shape}, Filter: {v.shape}"
)
if a.is_distributed():
if any(a.split == a.ndim - forbidden for forbidden in range(1, convolution_dim + 1)):
raise ValueError(
"Please distribute the signal along the batch dimension, not the signal dimension. For in-place redistribution use the `DNDarray.resplit_()` method with `axis=0`"
)
batch_processing = True
if (not batch_processing) and (v.ndim > convolution_dim):
raise ValueError(
f"{convolution_dim}-D convolution without batch processing only supported for {convolution_dim}-dimensional signal and kernel. Signal: {a.shape}, Filter: {v.shape}"
)
return batch_processing
[docs]
def convolve(a: DNDarray, v: DNDarray, mode: str = "full", stride: int = 1) -> DNDarray:
"""
Returns the discrete, linear convolution of two one-dimensional `DNDarray`s or scalars.
Unlike `numpy.signal.convolve`, if ``a`` and/or ``v`` have more than one dimension, batch-convolution along the last dimension will be attempted. See `Examples` below.
Parameters
----------
a : DNDarray or scalar
One- or N-dimensional signal ``DNDarray`` of shape (..., N), or scalar. If ``a`` has more than one dimension, it will be treated as a batch of 1D signals.
Distribution along the batch dimension is required for distributed batch processing. See the examples for details.
v : DNDarray or scalar
One- or N-dimensional filter weight `DNDarray` of shape (..., M), or scalar. If ``v`` has more than one dimension, it will be treated as a batch of 1D filter weights.
The batch dimension(s) of ``v`` must match the batch dimension(s) of ``a``.
mode : str
Can be 'full', 'valid', or 'same'. Default is 'full'.
'full':
Returns the convolution at
each point of overlap, with a length of '(N+M-2)//stride+1'. At
the end-points of the convolution, the signals do not overlap
completely, and boundary effects may be seen.
'same':
Mode 'same' returns output of length 'N'. Boundary
effects are still visible. This mode is not supported for
even-sized filter weights
'valid':
Mode 'valid' returns output of length '(N-M)//stride+1'. The
convolution product is only given for points where the signals
overlap completely. Values outside the signal boundary have no
effect.
stride : int
Stride of the convolution. Must be a positive integer. Default is 1.
Stride must be 1 for mode 'same'.
Examples
--------
Note how the convolution operator flips the second array
before "sliding" the two across one another:
>>> a = ht.ones(5)
>>> v = ht.arange(3).astype(ht.float)
>>> ht.convolve(a, v, mode="full")
DNDarray([0., 1., 3., 3., 3., 3., 2.])
>>> ht.convolve(a, v, mode="same")
DNDarray([1., 3., 3., 3., 3.])
>>> ht.convolve(a, v, mode="valid")
DNDarray([3., 3., 3.])
>>> ht.convolve(a, v, stride=2)
DNDarray([0., 3., 3., 2.])
>>> ht.convolve(a, v, mode="valid", stride=2)
DNDarray([3., 3.])
>>> a = ht.ones(10, split=0)
>>> v = ht.arange(3, split=0).astype(ht.float)
>>> ht.convolve(a, v, mode="valid")
DNDarray([3., 3., 3., 3., 3., 3., 3., 3.])
[0/3] DNDarray([3., 3., 3.])
[1/3] DNDarray([3., 3., 3.])
[2/3] DNDarray([3., 3.])
>>> a = ht.ones(10, split=0)
>>> v = ht.arange(3, split=0)
>>> ht.convolve(a, v)
DNDarray([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.], dtype=ht.float32, device=cpu:0, split=0)
[0/3] DNDarray([0., 1., 3., 3.])
[1/3] DNDarray([3., 3., 3., 3.])
[2/3] DNDarray([3., 3., 3., 2.])
>>> a = ht.arange(50, dtype=ht.float64, split=0)
>>> a = a.reshape(10, 5) # 10 signals of length 5
>>> v = ht.arange(3)
>>> ht.convolve(a, v) # batch processing: 10 signals convolved with filter v
DNDarray([[ 0., 0., 1., 4., 7., 10., 8.],
[ 0., 5., 16., 19., 22., 25., 18.],
[ 0., 10., 31., 34., 37., 40., 28.],
[ 0., 15., 46., 49., 52., 55., 38.],
[ 0., 20., 61., 64., 67., 70., 48.],
[ 0., 25., 76., 79., 82., 85., 58.],
[ 0., 30., 91., 94., 97., 100., 68.],
[ 0., 35., 106., 109., 112., 115., 78.],
[ 0., 40., 121., 124., 127., 130., 88.],
[ 0., 45., 136., 139., 142., 145., 98.]], dtype=ht.float64, device=cpu:0, split=0)
>>> v = ht.random.randint(0, 3, (10, 3), split=0) # 10 filters of length 3
>>> ht.convolve(a, v) # batch processing: 10 signals convolved with 10 filters
DNDarray([[ 0., 0., 2., 4., 6., 8., 0.],
[ 5., 6., 7., 8., 9., 0., 0.],
[ 20., 42., 56., 61., 66., 41., 14.],
[ 0., 15., 16., 17., 18., 19., 0.],
[ 20., 61., 64., 67., 70., 48., 0.],
[ 50., 52., 104., 108., 112., 56., 58.],
[ 0., 30., 61., 63., 65., 67., 34.],
[ 35., 106., 109., 112., 115., 78., 0.],
[ 0., 40., 81., 83., 85., 87., 44.],
[ 0., 0., 45., 46., 47., 48., 49.]], dtype=ht.float64, device=cpu:0, split=0)
"""
a, v = _sanitize_conv_input(a, v, stride, mode, 1)
# assess whether to perform batch processing, default is False (no batch processing)
batch_processing = _conv_batchprocessing_check(a, v, 1)
if batch_processing and a.is_distributed() and v.is_distributed():
if v.ndim == 1:
# gather filter to all ranks
v.resplit_(axis=None)
else:
v.resplit_(axis=a.split)
# ensure balanced kernel
if not (v.is_balanced()):
raise ValueError("Only balanced kernel weights are allowed")
# calculate pad size according to mode
if mode == "full":
pad_size = v.shape[-1] - 1
elif mode == "same":
pad_size = v.shape[-1] // 2
elif mode == "valid":
pad_size = 0
gshape = (a.shape[-1] + 2 * pad_size - v.shape[-1]) // stride + 1
if v.is_distributed() and stride > 1:
gshape_stride_1 = a.shape[-1] + 2 * pad_size - v.shape[-1] + 1
if batch_processing:
# all operations are local torch operations, only the last dimension is convolved
local_a = a.larray
local_v = v.larray
# flip filter for convolution, as Pytorch conv1d computes correlations
local_v = torch.flip(local_v, [-1])
local_batch_dims = tuple(local_a.shape[:-1])
# reshape signal and filter to 3D for Pytorch conv1d function
# see https://pytorch.org/docs/stable/generated/torch.nn.functional.conv1d.html
local_a = local_a.reshape(
torch.prod(torch.tensor(local_batch_dims, device=local_a.device), dim=0).item(),
local_a.shape[-1],
)
channels = local_a.shape[0]
if v.ndim > 1:
local_v = local_v.reshape(
torch.prod(torch.tensor(local_batch_dims, device=local_v.device), dim=0).item(),
local_v.shape[-1],
)
local_v = local_v.unsqueeze(1)
else:
local_v = local_v.unsqueeze(0).unsqueeze(0).expand(local_a.shape[0], 1, -1)
# add batch dimension to signal
local_a = local_a.unsqueeze(0)
# apply torch convolution operator if local signal isn't empty
if torch.prod(torch.tensor(local_a.shape, device=local_a.device)) > 0:
local_convolved = fc.conv1d(
local_a, local_v, padding=pad_size, groups=channels, stride=stride
)
else:
empty_shape = tuple(local_a.shape[:-1] + (gshape,))
local_convolved = torch.empty(empty_shape, dtype=local_a.dtype, device=local_a.device)
# unpack 3D result into original shape
local_convolved = local_convolved.squeeze(0)
local_convolved = local_convolved.reshape(local_batch_dims + (gshape,))
# wrap result in DNDarray
convolved = array(local_convolved, is_split=a.split, device=a.device, comm=a.comm)
return convolved
# pad signal with zeros
a = pad(a, pad_size, "constant", 0)
# compute halo size
halo_size = torch.max(v.lshape_map[:, -1]).item() // 2
if a.is_distributed():
if (v.lshape_map[:, a.split] > a.lshape_map[:, a.split]).any():
raise ValueError(
"Local chunk of filter weight is larger than the local chunks of signal"
)
# fetch halos and store them in a.halo_next/a.halo_prev
a.get_halo(halo_size)
# apply halos to local array
signal = a.array_with_halos
# shift signal based on global kernel starts for any rank but first
if stride > 1 and not v.is_distributed():
if a.comm.rank == 0:
local_index = 0
else:
local_index = torch.sum(a.lshape_map[: a.comm.rank, 0]).item() - halo_size
local_index = local_index % stride
if local_index != 0:
local_index = stride - local_index
# even kernels can produces doubles
if v.shape[-1] % 2 == 0 and local_index == 0:
local_index = stride
signal = signal[local_index:]
else:
signal = a.larray
# flip filter for convolution as Pytorch conv1d computes correlations
v = flip(v, [0])
if v.larray.shape[0] != v.lshape_map[0, 0]:
# pads weights if input kernel is uneven
target = torch.zeros(v.lshape_map[0][0], dtype=v.larray.dtype, device=v.larray.device)
pad_size = v.lshape_map[0][0] - v.larray.shape[0]
target[pad_size:] = v.larray
weight = target
else:
weight = v.larray
t_v = weight # stores temporary weight
# make signal and filter weight 3D for Pytorch conv1d function
signal = signal.reshape(1, 1, signal.shape[0])
weight = weight.reshape(1, 1, weight.shape[0])
if v.is_distributed():
size = v.comm.size
# any stride is a subset of stride 1
if stride > 1:
gshape = gshape_stride_1
for r in range(size):
rec_v = t_v.clone()
v.comm.Bcast(rec_v, root=r)
t_v1 = rec_v.reshape(1, 1, rec_v.shape[0])
local_signal_filtered = fc.conv1d(signal, t_v1, stride=1)
# unpack 3D result into 1D
local_signal_filtered = local_signal_filtered[0, 0, :]
if a.comm.rank != 0 and v.lshape_map[0][0] % 2 == 0:
local_signal_filtered = local_signal_filtered[1:]
# accumulate filtered signal on the fly
global_signal_filtered = array(
local_signal_filtered, is_split=0, device=a.device, comm=a.comm
)
if r == 0:
# initialize signal_filtered, starting point of slice
signal_filtered = zeros(
gshape, dtype=a.dtype, split=a.split, device=a.device, comm=a.comm
)
start_idx = 0
# accumulate relevant slice of filtered signal
# note, this is a binary operation between unevenly distributed dndarrays and will require communication, check out _operations.__binary_op()
try:
signal_filtered += global_signal_filtered[start_idx : start_idx + gshape]
except (ValueError, TypeError):
signal_filtered = (
signal_filtered + global_signal_filtered[start_idx : start_idx + gshape]
)
if r != size - 1:
start_idx += v.lshape_map[r + 1][0].item()
# any stride is a subset of arrays of stride 1
if stride > 1:
signal_filtered = signal_filtered[::stride]
return signal_filtered
else:
# apply torch convolution operator
if signal.shape[-1] >= weight.shape[-1]:
signal_filtered = fc.conv1d(signal, weight, stride=stride)
# unpack 3D result into 1D
signal_filtered = signal_filtered[0, 0, :]
else:
signal_filtered = torch.tensor([], device=str(signal.device))
# if kernel shape along split axis is even we need to get rid of duplicated values
if a.comm.rank != 0 and v.shape[0] % 2 == 0 and stride == 1:
signal_filtered = signal_filtered[1:]
return DNDarray(
signal_filtered,
(gshape,),
signal_filtered.dtype,
a.split,
a.device,
a.comm,
balanced=False,
).astype(a.dtype.torch_type())
[docs]
def convolve2d(
a: DNDarray,
v: DNDarray,
mode: str = "full",
stride: tuple[int, int] = (1, 1),
) -> DNDarray:
"""
Returns the discrete, linear convolution of two two-dimensional HeAT tensors.
Recommendation:
For better memory consumption, it is recommended to pad the array before running the convolution in mode "valid"
Parameters
----------
a : scalar, array_like, DNDarray
Two-dimensional signal, float precision required on gpu
v : scalar, array_like, DNDarray
Two-dimensional filter mask. float precision required on gpu
mode : {'full', 'valid', 'same'}, optional
'full':
By default, mode is 'full'. This returns the convolution at
each point of overlap, with an output shape of (N+M-1,). At
the end-points of the convolution, the signals do not overlap
completely, and boundary effects may be seen.
'same':
Mode 'same' returns output of length 'N'. Boundary
effects are still visible. This mode is not supported for
even sized filter weights
'valid':
Mode 'valid' returns output of length 'N-M+1'. The
convolution product is only given for points where the signals
overlap completely. Values outside the signal boundary have no
effect.
stride: Tuple(int,int), optional
Stride of the convolution in (x,y) direction. Default is (1,1).
Returns
-------
out : DNDarray
Discrete, linear convolution of 'a' and 'v', balanced
Note : If the filter weight is larger
than fitting into memory, using the FFT for convolution is recommended.
Example
--------
>>> a = ht.ones((5, 5))
>>> v = ht.ones((3, 3))
>>> ht.convolve2d(a, v, mode="valid")
DNDarray([[9., 9., 9.],
[9., 9., 9.],
[9., 9., 9.]], dtype=ht.float32, device=cpu:0, split=None)
>>> a = ht.ones((5, 5), split=1))
>>> v = ht.ones((3, 3), split=1))
>>> ht.convolve2d(a, v)
DNDarray([[1., 2., 3., 3., 3., 2., 1.],
[2., 4., 6., 6., 6., 4., 2.],
[3., 6., 9., 9., 9., 6., 3.],
[3., 6., 9., 9., 9., 6., 3.],
[3., 6., 9., 9., 9., 6., 3.],
[2., 4., 6., 6., 6., 4., 2.],
[1., 2., 3., 3., 3., 2., 1.]], dtype=ht.float32, device=cpu:0, split=1)
>>> a = ht.ones((5, 5), split=0)).astype(ht.float)
>>> v = ht.ones((3, 3), split=0)).astype(ht.float)
>>> stride = (1, 2)
>>> ht.convolve2d(a, v, stride=stride)
DNDarray([[1., 3., 3., 1.],
[2., 6., 6., 2.],
[3., 9., 9., 3.],
[3., 9., 9., 3.],
[3., 9., 9., 3.],
[2., 6., 6., 2.],
[1., 3., 3., 1.]], dtype=ht.float32, device=cpu:0, split=0)
"""
# check type and size of input
a, v = _sanitize_conv_input(a, v, stride, mode, 2)
# assess whether to perform batch processing, default is False (no batch processing)
batch_processing = _conv_batchprocessing_check(a, v, 2)
if a.is_distributed() and v.is_distributed():
if batch_processing and v.ndim == 2:
# gather filter to all ranks
v.resplit_(axis=None)
else:
v.resplit_(axis=a.split)
# ensure balanced kernel
if not (v.is_balanced()):
raise ValueError("Only balanced kernel weights are allowed")
# calculate pad size according to mode
if mode == "full":
pad_size = [v.shape[i] - 1 for i in range(-2, 0)]
elif mode == "same":
pad_size = [v.shape[i] // 2 for i in range(-2, 0)]
elif mode == "valid":
pad_size = [0, 0]
gshape = tuple(
[(a.shape[i] + 2 * pad_size[i] - v.shape[i]) // stride[i] + 1 for i in range(-2, 0)]
)
if v.is_distributed() and any(s > 1 for s in stride):
gshape_stride_1 = tuple(
[(a.shape[i] + 2 * pad_size[i] - v.shape[i]) + 1 for i in range(-2, 0)]
)
if batch_processing:
# all operations are local torch operations, only the last dimension is convolved
local_a = a.larray
local_v = v.larray
# flip filter for convolution, as Pytorch conv1d computes correlations
local_v = torch.flip(local_v, [-2, -1])
local_batch_dims = tuple(local_a.shape[:-2])
# reshape signal and filter to 3D for Pytorch conv1d function
# see https://pytorch.org/docs/stable/generated/torch.nn.functional.conv1d.html
local_a = local_a.reshape(
torch.prod(torch.tensor(local_batch_dims, device=local_a.device), dim=0).item(),
local_a.shape[-2],
local_a.shape[-1],
)
channels = local_a.shape[0]
if v.ndim > 2:
local_v = local_v.reshape(
torch.prod(torch.tensor(local_batch_dims, device=local_v.device), dim=0).item(),
local_v.shape[-2],
local_v.shape[-1],
)
local_v = local_v.unsqueeze(1)
else:
local_v = (
local_v.unsqueeze(0)
.unsqueeze(0)
.expand(local_a.shape[0], 1, local_v.shape[-2], local_v.shape[-1])
)
# add batch dimension to signal
local_a = local_a.unsqueeze(0)
# apply torch convolution operator if local signal isn't empty, add zero padding
if torch.prod(torch.tensor(local_a.shape, device=local_a.device)) > 0:
local_convolved = fc.conv2d(
local_a, local_v, groups=channels, stride=stride, padding=tuple(pad_size)
)
else:
empty_shape = tuple(local_a.shape[:-1] + (gshape[-2],) + (gshape[-1],))
local_convolved = torch.empty(empty_shape, dtype=local_a.dtype, device=local_a.device)
# unpack 3D result into original shape
local_convolved = local_convolved.squeeze(0)
local_convolved = local_convolved.reshape(local_batch_dims + (gshape[-2],) + (gshape[-1],))
# wrap result in DNDarray
convolved = array(local_convolved, is_split=a.split, device=a.device, comm=a.comm)
return convolved
# pad signal with zeros
if not mode == "valid":
pad_array = ((pad_size[0], pad_size[0]), (pad_size[1], pad_size[1]))
a = pad(a, pad_array)
# no batch processing
if a.is_distributed():
if (v.lshape_map[:, a.split] > a.lshape_map[:, a.split]).any():
raise ValueError(
"Local chunk of filter weight is larger than the local chunks of signal"
)
# compute halo size
halo_size = int(v.lshape_map[0][a.split]) // 2
# fetch halos and store them in a.halo_next/a.halo_prev
a.get_halo(halo_size)
# apply halos to local array
signal = a.array_with_halos
else:
# get local array in case of non-distributed a
signal = a.larray
# flip filter for convolution as PyTorch conv2d computes correlation
no_dims = len(v.shape)
v = flip(v, [no_dims - 2, no_dims - 1])
# compute weight size
if v.is_distributed() and v.larray.shape[v.split] != v.lshape_map[0, v.split]:
# pads weights if input kernel is uneven
target = torch.zeros(tuple(v.lshape_map[0]), dtype=v.larray.dtype, device=v.larray.device)
v_pad_size = v.lshape_map[0][v.split] - v.larray.shape[v.split]
if v.split == 0:
target[v_pad_size:, :] = v.larray
else:
target[:, v_pad_size:] = v.larray
weight = target
else:
weight = v.larray
t_v = weight
# make signal and filter weight 4D for Pytorch conv2d function
signal = signal.reshape(1, 1, signal.shape[-2], signal.shape[-1])
weight = weight.reshape(1, 1, weight.shape[-2], weight.shape[-1])
if v.is_distributed():
size = v.comm.size
split_axis = v.split
# any stride is a subset of stride 1
if any(s > 1 for s in stride):
gshape = gshape_stride_1
# convoluted signal
signal_filtered = zeros(gshape, dtype=a.dtype, split=a.split, device=a.device, comm=a.comm)
for r in range(size):
rec_v = t_v.clone()
v.comm.Bcast(rec_v, root=r)
t_v1 = rec_v.reshape(1, 1, rec_v.shape[0], rec_v.shape[1])
# apply torch convolution operator
local_signal_filtered = fc.conv2d(signal, t_v1, stride=1)
# unpack 3D result into 2D
local_signal_filtered = local_signal_filtered[0, 0, :, :]
# if kernel shape along split axis is even we need to get rid of duplicated values
if a.is_distributed() and v.comm.rank != 0 and v.lshape_map[0][split_axis] % 2 == 0:
if split_axis == 0:
local_signal_filtered = local_signal_filtered[1:, :]
else:
local_signal_filtered = local_signal_filtered[:, 1:]
# compute offset for local_signal_filtered
if r > 0:
v_pad_size = v.lshape_map[0][v.split] - v.lshape_map[r, v.split]
start_idx = int(torch.sum(v.lshape_map[:r, split_axis]).item() - v_pad_size)
else:
start_idx = 0
# if a is distributed, results have to be communicated across ranks
if a.is_distributed():
filter_results = array(
local_signal_filtered, is_split=a.split, device=a.device, comm=a.comm
)
else:
filter_results = local_signal_filtered
# apply start_idx
if split_axis == 0:
filter_results = filter_results[start_idx : start_idx + gshape[0], :]
else:
filter_results = filter_results[:, start_idx : start_idx + gshape[1]]
# add results
try:
if a.is_distributed():
signal_filtered += filter_results
else:
signal_filtered.larray += filter_results
except (ValueError, TypeError):
if a.is_distributed():
signal_filtered = signal_filtered + filter_results
else:
signal_filtered.larray = signal_filtered.larray + filter_results
if any(s > 1 for s in stride):
signal_filtered = signal_filtered[:: stride[0], :: stride[1]]
if a.is_distributed():
signal_filtered.balance_()
return signal_filtered
else:
# shift signal based on global kernel starts for any rank but first if stride > 1
if a.is_distributed() and stride[a.split] > 1:
if a.comm.rank == 0:
local_index = 0
else:
# lshape map does not know about padding, compute pad_offset for last rank
# pad_offset = pad_size[a.split] if a.comm.rank == a.comm.size - 1 else 0
local_index = torch.sum(a.lshape_map[: a.comm.rank, a.split]).item() - halo_size
local_index = local_index % stride[a.split]
if local_index != 0:
local_index = stride[a.split] - local_index
# even kernels can produces doubles
if v.shape[a.split] % 2 == 0 and local_index == 0:
local_index = stride[a.split]
if a.split == 0:
signal = signal[:, :, local_index:, :]
else:
signal = signal[:, :, :, local_index:]
if all(a_s >= v_s for v_s, a_s in zip(weight.shape[-2:], signal.shape[-2:])):
# apply torch convolution operator
signal_filtered = fc.conv2d(signal, weight, stride=stride)
# unpack 4D result into 2D
signal_filtered = signal_filtered[0, 0, :, :]
else:
empty_shape = list(signal.shape[-2:])
empty_shape[a.split] = 0
signal_filtered = torch.empty(tuple(empty_shape), device=str(signal.device))
# if kernel shape along split axis is even we need to get rid of duplicated values
if (
a.is_distributed()
and a.comm.rank != 0
and stride[a.split] == 1
and v.shape[a.split] % 2 == 0
):
if a.split == 0:
signal_filtered = signal_filtered[1:, :]
elif a.split == 1:
signal_filtered = signal_filtered[:, 1:]
result = DNDarray(
signal_filtered.contiguous(),
gshape,
a.dtype,
a.split,
a.device,
a.comm,
balanced=False,
).astype(a.dtype.torch_type())
if result.is_distributed():
result.balance_()
return result