# Source code for heat.fft.fft

"""Provides a collection of Discrete Fast Fourier Transforms (DFFT) and their inverses."""

import torch

from ..core.communication import MPI
from ..core.dndarray import DNDarray
from ..core.stride_tricks import sanitize_axis
from ..core.types import heat_type_is_exact, heat_type_of, canonical_heat_type, float32
from ..core.factories import array, arange
from ..core.devices import Device

from typing import Type, Union, Tuple, Iterable, Optional

# Public API of this module; the names are kept in alphabetical order.
__all__ = [
    "fft",
    "fft2",
    "fftfreq",
    "fftn",
    "fftshift",
    "hfft",
    "hfft2",
    "hfftn",
    "ifft",
    "ifft2",
    "ifftn",
    "ifftshift",
    "ihfft",
    "ihfft2",
    "ihfftn",
    "irfft",
    "irfft2",
    "irfftn",
    "rfft",
    "rfft2",
    "rfftfreq",
    "rfftn",
]


def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray:
    """
    Helper function for 1-dimensional FFT.

    Parameters
    ----------
    x : DNDarray
        Input array.
    fft_op : callable
        One-dimensional ``torch.fft`` operation, called as ``fft_op(tensor, n=n, dim=axis, norm=norm)``.
    **kwargs
        ``axis`` (transformed axis), ``n`` (length of the transformed axis of the output, defaults to the
        input length along ``axis``) and ``norm`` (normalization mode, passed through to `fft_op`).

    Returns
    -------
    DNDarray
        The transformed array, distributed along the same axis as `x`.

    Raises
    ------
    TypeError
        If `x` is not a DNDarray, or if `axis` does not designate exactly one axis.
    IndexError
        If `axis` is out of bounds.

    Notes
    -----
    If `x` is distributed along `axis`, the data is redistributed so that `axis` becomes process-local
    (a 1-D array is copied entirely to every process), transformed locally, and redistributed back.
    """
    try:
        local_x = x.larray
    except AttributeError:
        raise TypeError(f"x must be a DNDarray, is {type(x)}")
    original_split = x.split

    # sanitize kwargs
    axis = kwargs.get("axis")
    try:
        axis = sanitize_axis(x.gshape, axis)
    except ValueError as e:
        raise IndexError(e) from e
    if isinstance(axis, tuple):
        # accept a single-element tuple as that axis; anything else is not a 1-D transform
        if len(axis) != 1:
            raise TypeError(f"axis must be an integer, got {axis}")
        axis = axis[0]
    n = kwargs.get("n", None)
    if n is None:
        n = x.shape[axis]
    norm = kwargs.get("norm", None)

    # output shape: real-input operations only return the frequencies up to and including
    # the Nyquist frequency along the transformed axis
    real_op = fft_op in (torch.fft.rfft, torch.fft.ihfft)
    output_shape = list(x.shape)
    output_shape[axis] = n // 2 + 1 if real_op else n

    # FFT along a process-local axis
    if not x.is_distributed() or original_split != axis:
        if local_x.numel() == 0:
            # empty local chunk: build the empty result directly. Its shape is the local shape with
            # the transformed axis resized; its dtype is obtained by applying `fft_op` to a single
            # zero sample so that it matches the chunks transformed on the other processes
            # (e.g. complex for a forward FFT of real input, real for irfft).
            local_shape = list(local_x.shape)
            local_shape[axis] = output_shape[axis]
            probe = torch.zeros(1, dtype=local_x.dtype, device=local_x.device)
            torch_result = torch.empty(
                tuple(local_shape), dtype=fft_op(probe, n=1).dtype, device=local_x.device
            )
        else:
            torch_result = fft_op(local_x, n=n, dim=axis, norm=norm)
        return DNDarray(
            torch_result,
            gshape=tuple(output_shape),
            dtype=heat_type_of(torch_result),
            split=original_split,
            device=x.device,
            comm=x.comm,
            balanced=x.balanced,
        )

    # FFT along the split axis: move the split axis to the front, redistribute so that it is
    # process-local, transform it, then restore the original distribution and axis order
    if original_split != 0:
        transpose_axes = list(range(x.ndim))
        transpose_axes[0], transpose_axes[original_split] = (
            transpose_axes[original_split],
            transpose_axes[0],
        )
        x = x.transpose(transpose_axes)

    if x.ndim > 1:
        redistributed_x = x.resplit(axis=1)
    else:
        # 1-D input: the entire array is copied to every process
        redistributed_x = x.resplit(axis=None)

    # axis 0 is now process-local, so real-input operations can be applied directly and
    # already return the Nyquist-truncated spectrum
    ht_result = __fft_op(redistributed_x, fft_op, n=n, axis=0, norm=norm)
    del redistributed_x
    ht_result.resplit_(axis=0)
    if original_split != 0:
        # swapping two axes is its own inverse
        ht_result = ht_result.transpose(transpose_axes)
    return ht_result


def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray:
    """
    Helper function for fftn

    Applies an N-dimensional ``torch.fft`` transform (e.g. ``fftn``, ``rfft2``, ``irfftn``) or an FFT
    shift (``fftshift``/``ifftshift``) to a DNDarray.

    Parameters
    ----------
    x : DNDarray
        Input array.
    fftn_op : callable
        The ``torch.fft`` operation to apply.
    **kwargs
        Shift operations only read ``axes``. Transforms read ``s`` (output lengths of the transformed
        axes), ``axes`` (transformed axes; default: the last ``len(s)`` axes if ``s`` is given, all axes
        otherwise) and ``norm`` (passed through to `fftn_op`).

    Returns
    -------
    DNDarray
        The result, distributed along the same axis as `x`.

    Raises
    ------
    TypeError
        If `x` is not a DNDarray.
    ValueError
        If `s` has more entries than `x` has dimensions.
    IndexError
        If an entry of `axes` is out of bounds.
    NotImplementedError
        If `axes` contains the same axis more than once.

    Notes
    -----
    If `x` is distributed along one of the transformed axes, the operation is decomposed into a pass
    along the split axis (after redistributing the data so that this axis is process-local) and a local
    pass along the remaining axes. For real-to-complex and complex-to-real transforms, the order of the
    passes and the operation used in each are chosen so that the real-valued step acts on the last
    transformed axis, as in the non-distributed transform.

    NOTE(review): if `s` is not given it defaults to the input lengths along `axes`, also for the
    complex-to-real transforms (numpy's default output length along the last axis is ``2 * (m - 1)``
    there); presumably the public wrappers pass `s` explicitly - confirm.
    """
    try:
        local_x = x.larray
    except AttributeError:
        raise TypeError(f"x must be a DNDarray, is {type(x)}")

    original_split = x.split
    output_shape = list(x.shape)
    shift_op = fftn_op in [torch.fft.fftshift, torch.fft.ifftshift]
    # complex-to-real transforms: real output along the last transformed axis
    inverse_real_op = fftn_op in [torch.fft.irfftn, torch.fft.irfft2]

    # real-to-complex transforms and their complex-to-complex counterparts
    real_to_generic_fftn_ops = {
        torch.fft.rfftn: torch.fft.fftn,
        torch.fft.rfft2: torch.fft.fft2,
    }
    if hasattr(torch.fft, "ihfftn"):
        real_to_generic_fftn_ops[torch.fft.ihfftn] = torch.fft.ifftn
        real_to_generic_fftn_ops[torch.fft.ihfft2] = torch.fft.ifft2
    real_op = fftn_op in real_to_generic_fftn_ops

    # sanitize kwargs; shift operations only support the keyword argument `axes`
    s = None
    norm = None
    if not shift_op:
        s = kwargs.get("s", None)
        if s is not None:
            s = tuple(s)
            if len(s) > x.ndim:
                raise ValueError(
                    f"Input is {x.ndim}-dimensional, so s can be at most {x.ndim} elements long. Got {len(s)} elements instead."
                )
        norm = kwargs.get("norm", None)

    axes = kwargs.get("axes", None)
    if axes is None:
        axes = tuple(range(x.ndim)[-len(s) :]) if s is not None else tuple(range(x.ndim))
    else:
        try:
            axes = sanitize_axis(x.gshape, axes)
        except ValueError as e:
            raise IndexError(e) from e
        # a single axis may come back as a plain integer
        axes = (axes,) if isinstance(axes, int) else tuple(axes)

    if shift_op:
        torch_kwargs = {"dim": axes}
    else:
        if s is None:
            s = tuple(x.shape[axis] for axis in axes)
        for i, axis in enumerate(axes):
            output_shape[axis] = s[i]
        torch_kwargs = {"s": s, "dim": axes, "norm": norm}

    if len(axes) != len(set(axes)):
        raise NotImplementedError("Multiple transforms over the same axis not implemented yet.")

    # real-input operations only return the frequencies up to and including the Nyquist
    # frequency along the last transformed axis
    if real_op:
        output_shape[axes[-1]] = s[-1] // 2 + 1

    # transform along process-local axes only
    if not x.is_distributed() or original_split not in axes:
        if local_x.numel() == 0:
            # empty local chunk: build the empty result directly, with the local shape resized
            # along the transformed axes and the dtype the other processes produce
            local_shape = list(local_x.shape)
            for axis in axes:
                local_shape[axis] = output_shape[axis]
            if shift_op:
                # shifts do not change the dtype
                output_dtype = local_x.dtype
            else:
                # probe the dtype promotion of a 1-D transform of the same kind on a single sample
                probe = torch.zeros(1, dtype=local_x.dtype, device=local_x.device)
                probe_op = torch.fft.irfft if inverse_real_op else torch.fft.fft
                output_dtype = probe_op(probe, n=1).dtype
            torch_result = torch.empty(
                tuple(local_shape), dtype=output_dtype, device=local_x.device
            )
        else:
            torch_result = fftn_op(local_x, **torch_kwargs)
        return DNDarray(
            torch_result,
            gshape=tuple(output_shape),
            dtype=heat_type_of(torch_result),
            split=original_split,
            device=x.device,
            comm=x.comm,
            balanced=x.balanced,
        )

    # transform along the split axis: decompose into a split-axis pass and a local pass
    split_index = axes.index(original_split)
    other_axes = axes[:split_index] + axes[split_index + 1 :]
    split_op = other_op = fftn_op
    split_first = True
    split_s = other_s = None
    if not shift_op:
        split_s = (s[split_index],)
        other_s = s[:split_index] + s[split_index + 1 :]
        split_is_last = original_split == axes[-1]
        if real_op and split_is_last:
            # real-to-complex along the split (= last transformed) axis, then complex-to-complex
            # along the remaining axes
            other_op = real_to_generic_fftn_ops[fftn_op]
        elif real_op:
            # real-to-complex over the remaining axes (which end with the last transformed axis)
            # first, then complex-to-complex along the split axis
            split_op = real_to_generic_fftn_ops[fftn_op]
            split_first = False
        elif inverse_real_op and split_is_last:
            # complex-to-complex over the remaining axes first, then complex-to-real along the
            # split (= last transformed) axis
            other_op = torch.fft.ifftn
            split_first = False
        elif inverse_real_op:
            # complex-to-complex along the split axis first, then complex-to-real over the
            # remaining axes (which end with the last transformed axis)
            split_op = torch.fft.ifftn

    transpose_axes = list(range(x.ndim))
    transpose_axes[0], transpose_axes[original_split] = (
        transpose_axes[original_split],
        transpose_axes[0],
    )

    def along_split_axis(arr: DNDarray, op: callable) -> DNDarray:
        # apply `op` along the split axis of `arr`: move it to the front, make it process-local,
        # transform, then restore the original distribution and axis order
        if original_split != 0:
            arr = arr.transpose(transpose_axes)
        if arr.ndim > 1:
            local_arr = arr.resplit(axis=1)
        else:
            # 1-D input: the entire array is copied to every process
            local_arr = arr.resplit(axis=None)
        if shift_op:
            result = __fftn_op(local_arr, op, axes=(0,))
        else:
            result = __fftn_op(local_arr, op, s=split_s, axes=(0,), norm=norm)
        del local_arr
        result.resplit_(axis=0)
        if original_split != 0:
            # swapping two axes is its own inverse
            result = result.transpose(transpose_axes)
        return result

    def along_other_axes(arr: DNDarray, op: callable) -> DNDarray:
        # apply `op` locally along the transformed axes other than the split axis, if any
        if not other_axes:
            return arr
        if shift_op:
            return __fftn_op(arr, op, axes=other_axes)
        return __fftn_op(arr, op, s=other_s, axes=other_axes, norm=norm)

    if split_first:
        return along_other_axes(along_split_axis(x, split_op), other_op)
    return along_split_axis(along_other_axes(x, other_op), split_op)


def __fftfreq_op(fftfreq_op: callable, **kwargs) -> DNDarray:
    """
    Helper function for ``fftfreq`` and ``rfftfreq`` operations.

    Parameters
    ----------
    fftfreq_op : callable
        ``torch.fft.fftfreq`` or ``torch.fft.rfftfreq``.
    **kwargs
        ``n`` (window length), ``d`` (sample spacing), ``dtype`` (heat dtype or None),
        ``split`` (0 or None), ``device`` and ``comm``.

    Raises
    ------
    ValueError
        If `n` is not an integer.
    NotImplementedError
        If `d` is complex.
    TypeError
        If `d` is not a real scalar, or if `dtype` is an exact type or not a heat dtype.
    IndexError
        If `split` is neither 0 nor None.
    """
    n = kwargs.get("n", None)
    d = kwargs.get("d", None)
    dtype = kwargs.get("dtype", None)
    split = kwargs.get("split", None)
    device = kwargs.get("device", None)
    comm = kwargs.get("comm", None)

    if not isinstance(n, int):
        raise ValueError(f"n must be an integer, is {type(n)}")
    if not isinstance(d, (int, float)):
        if isinstance(d, complex):
            # numpy supports complex d, torch doesn't
            raise NotImplementedError("Support for complex d not implemented yet.")
        raise TypeError(f"d must be a scalar, is {type(d)}")
    if dtype is not None:
        if heat_type_is_exact(dtype):
            raise TypeError(f"dtype must be a float or complex type, is {dtype}")
        # extract torch dtype from heat dtype
        try:
            torch_dtype = dtype.torch_type()
        except AttributeError:
            raise TypeError(f"dtype must be a heat dtype, is {type(dtype)}")
    else:
        torch_dtype = None

    # early out for non-distributed fftfreq
    if split is None:
        return array(fftfreq_op(n, d=d, dtype=torch_dtype), device=device, split=None, comm=comm)

    # distributed fftfreq
    if split != 0:
        raise IndexError(f"`fftfreq` returns a 1-D array, `split` must be 0 or None, is {split}")

    freq_dtype = (
        canonical_heat_type(torch.promote_types(torch_dtype, torch.float32))
        if torch_dtype is not None
        else float32
    )
    if fftfreq_op == torch.fft.rfftfreq:
        # non-negative frequencies only: [0, 1, ..., n // 2], i.e. n // 2 + 1 entries for
        # even and odd n alike (the Nyquist frequency is included for even n)
        freqs = arange(n // 2 + 1, dtype=freq_dtype, device=device, split=split, comm=comm)
    else:
        # [0, 1, ..., (n - 1) // 2, -(n // 2), ..., -1]: the negative frequencies start at
        # index (n + 1) // 2 and are stored in inverse order
        freqs = arange(n, dtype=freq_dtype, device=device, split=split, comm=comm)
        freqs[(n + 1) // 2 :] -= n

    # scale channel indices to frequencies
    freqs /= n * d
    return freqs


def __real_fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray:
    """
    Helper function for real 1-D FFTs.

    Checks that `x` is a real-valued DNDarray, then delegates to ``__fft_op`` with `fft_op`
    and `kwargs`.

    Raises
    ------
    TypeError
        If `x` is not a DNDarray, or if its data is complex.
    """
    try:
        is_complex = x.larray.is_complex()
    except AttributeError:
        # same error as the generic helpers raise for non-DNDarray input
        raise TypeError(f"x must be a DNDarray, is {type(x)}")
    if is_complex:
        raise TypeError(f"Input array must be real, is {x.dtype}.")
    return __fft_op(x, fft_op, **kwargs)


def __real_fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray:
    """
    Helper function for real N-D FFTs.

    Checks that ``x`` is a real-valued DNDarray, then delegates to ``__fftn_op``.

    Parameters
    ----------
    x : DNDarray
        Input array, must be real.
    fftn_op : callable
        The ``torch.fft`` function to apply (e.g. ``torch.fft.rfftn``).
    **kwargs
        Forwarded unchanged to ``__fftn_op``.

    Raises
    ------
    TypeError
        If ``x`` is not a DNDarray, or if its local data is complex.
    """
    # Reject non-DNDarray input with a TypeError (matching ``__fft_op``) instead of letting
    # the ``.larray`` access raise an AttributeError.
    try:
        local_x = x.larray
    except AttributeError:
        raise TypeError(f"x must be a DNDarray, is {type(x)}") from None
    if local_x.is_complex():
        raise TypeError(f"Input array must be real, is {x.dtype}.")
    return __fftn_op(x, fftn_op, **kwargs)


def fft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray:
    """
    Compute the one-dimensional discrete Fourier Transform.

    The transform is computed over the specified axis of an M-dimensional array by means of
    the Fast Fourier Transform (FFT). By default the last axis is transformed, while the
    remaining axes are left unchanged.

    Parameters
    ----------
    x : DNDarray
        Input array, can be complex. WARNING: If ``x`` is 1-D and distributed, the entire array
        is copied on each MPI process. See Notes.
    n : int, optional
        Length of the transformed axis of the output. If not given, the length of the input
        along ``axis`` is used. If ``n`` is smaller than that length, the input is truncated;
        if it is larger, the input is padded with zeros. Default: None.
    axis : int, optional
        Axis over which to compute the FFT. Default: -1, i.e. the last axis (or the only
        axis if ``x`` is 1-D).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`ifft` : inverse 1-dimensional FFT
    :func:`fft2` : 2-dimensional FFT
    :func:`fftn` : N-dimensional FFT
    :func:`rfft` : 1-dimensional FFT of a real signal
    :func:`hfft` : 1-dimensional FFT of a Hermitian symmetric sequence
    :func:`fftfreq` : frequency bins for given FFT parameters
    :func:`rfftfreq` : frequency bins for real FFT

    Notes
    -----
    This function requires MPI communication if the input array is transformed along the
    distribution axis.

    If the input array is 1-D and distributed, this function copies the entire array on each
    MPI process, so a very large array may exhaust memory. Hint: if you are looping through
    a batch of 1-D arrays, stack them into a 2-D DNDarray distributed along the batch axis
    and call this function once with ``axis`` set to the other (non-distributed) axis. Do not
    use :func:`fft2` for this: it would also transform across the stacked signals.
    """
    return __fft_op(x, torch.fft.fft, n=n, axis=axis, norm=norm)


def fft2(
    x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None
) -> DNDarray:
    """
    Compute the 2-dimensional discrete Fourier Transform.

    The transform is computed over two axes of an M-dimensional array by means of the Fast
    Fourier Transform (FFT). By default the last two axes are transformed, while the
    remaining axes are left unchanged.

    Parameters
    ----------
    x : DNDarray
        Input array, can be complex.
    s : Tuple[int, int], optional
        Shape of the output along the transformed axes. If not given, the input shape along
        ``axes`` is used.
    axes : Tuple[int, int], optional
        Axes over which to compute the FFT. Repeated transforms over an axis, i.e. repeated
        indices in ``axes``, are not supported yet. Default: (-2, -1).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`ifft2` : inverse 2-dimensional FFT
    :func:`fft` : 1-dimensional FFT
    :func:`fftn` : N-dimensional FFT
    :func:`rfft2` : 2-dimensional FFT of a real signal
    :func:`hfft2` : 2-dimensional FFT of a Hermitian symmetric sequence

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    transform = torch.fft.fft2
    return __fftn_op(x, transform, s=s, axes=axes, norm=norm)


def fftfreq(
    n: int,
    d: Union[int, float] = 1.0,
    dtype: Optional[Type] = None,
    split: Optional[int] = None,
    device: Optional[Union[str, Device]] = None,
    comm: Optional[MPI.Comm] = None,
) -> DNDarray:
    """
    Return the Discrete Fourier Transform sample frequencies for a signal of size ``n``.

    The result holds the frequency bin centers in cycles per unit of the sample spacing,
    starting at zero. For instance, if the sample spacing is in seconds, the frequency unit
    is cycles/second.

    Parameters
    ----------
    n : int
        Window length.
    d : Union[int, float], optional
        Sample spacing (inverse of the sampling rate). Defaults to 1.
    dtype : Type, optional
        Desired data type of the output; must be a float or complex heat type. Defaults to
        `ht.float32`.
    split : int, optional
        Axis along which to split the result. Since the output is 1-dimensional, only None
        and 0 are accepted. Defaults to None, i.e. non-distributed output.
    device : str or Device, optional
        Device on which to place the output. If not given, the current device is used.
    comm : MPI.Comm, optional
        MPI communicator used to distribute the output. If not given, the default
        communicator is used.

    Raises
    ------
    TypeError
        If ``d`` is not a scalar or ``dtype`` is an exact (integer/boolean) type.
    IndexError
        If ``split`` is neither None nor 0.

    See Also
    --------
    :func:`rfftfreq` : frequency bins for :func:`rfft`
    """
    options = dict(n=n, d=d, dtype=dtype, split=split, device=device, comm=comm)
    return __fftfreq_op(torch.fft.fftfreq, **options)


def fftn(
    x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None
) -> DNDarray:
    """
    Compute the N-dimensional discrete Fourier Transform.

    The transform is computed over any number of axes of an M-dimensional array by means of
    the Fast Fourier Transform (FFT).

    Parameters
    ----------
    x : DNDarray
        Input array, can be complex.
    s : Tuple[int, ...], optional
        Shape of the output along the transformed axes. If not given, the input shape along
        the transformed axes is used.
    axes : Tuple[int, ...], optional
        Axes over which to compute the FFT. If not given, the last ``len(s)`` axes are used,
        or all axes if ``s`` is also not specified. Repeated transforms over an axis, i.e.
        repeated indices in ``axes``, are not supported yet. Default: None.
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`ifftn` : inverse N-dimensional FFT
    :func:`fft` : 1-dimensional FFT
    :func:`fft2` : 2-dimensional FFT
    :func:`rfftn` : N-dimensional FFT of a real signal
    :func:`hfftn` : N-dimensional FFT of a Hermitian symmetric sequence

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    transform = torch.fft.fftn
    return __fftn_op(x, transform, s=s, axes=axes, norm=norm)


def fftshift(x: DNDarray, axes: Optional[Union[int, Iterable[int]]] = None) -> DNDarray:
    """
    Shift the zero-frequency component to the center of the spectrum.

    Half-spaces are swapped for every listed axis (all axes by default). Note that ``y[0]``
    is the Nyquist component only if ``len(x)`` is even.

    Parameters
    ----------
    x : DNDarray
        Input array.
    axes : int or Iterable[int], optional
        Axes over which to shift. Default is None, which shifts all axes.

    See Also
    --------
    :func:`ifftshift` : The inverse of `fftshift`.

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is shifted.
    """
    shift = torch.fft.fftshift
    return __fftn_op(x, shift, axes=axes)


def hfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray:
    """
    Compute the one-dimensional discrete Fourier Transform of a Hermitian symmetric signal.

    The transform is computed over the specified axis of an M-dimensional array by means of
    the Fast Fourier Transform (FFT). By default the last axis is transformed, while the
    remaining axes are left unchanged. The input signal is assumed to be Hermitian-symmetric,
    i.e. `x[..., i] = x[..., -i].conj()`.

    Parameters
    ----------
    x : DNDarray
        Input array.
    n : int, optional
        Length of the transformed axis of the output. If given, the input is zero-padded or
        trimmed accordingly before the transform. Default: `2 * (x.shape[axis] - 1)`.
    axis : int, optional
        Axis over which to compute the FFT. Default: -1, i.e. the last axis (or the only
        axis if ``x`` is 1-D).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`ihfft` : inverse 1-dimensional FFT of a Hermitian-symmetric sequence
    :func:`hfft2` : 2-dimensional FFT of a Hermitian-symmetric sequence
    :func:`hfftn` : N-dimensional FFT of a Hermitian-symmetric sequence
    :func:`fft` : 1-dimensional FFT
    :func:`rfft` : 1-dimensional FFT of a real signal

    Notes
    -----
    This function requires MPI communication if the input array is transformed along the
    distribution axis.
    """
    # The input holds only the non-negative half of the spectrum; by default reconstruct the
    # full even-length output signal.
    n = 2 * (x.shape[axis] - 1) if n is None else n
    return __fft_op(x, torch.fft.hfft, n=n, axis=axis, norm=norm)


def hfft2(
    x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None
) -> DNDarray:
    """
    Compute the 2-dimensional discrete Fourier Transform of a Hermitian symmetric signal.

    The transform is computed over two axes of an M-dimensional array by means of the Fast
    Fourier Transform (FFT). By default the last two axes are transformed, while the
    remaining axes are left unchanged. The input signal is assumed to be Hermitian-symmetric,
    i.e. `x[..., i] = x[..., -i].conj()`.

    Parameters
    ----------
    x : DNDarray
        Input array.
    s : Tuple[int, int], optional
        Shape of the signal along the transformed axes. If given, the input is zero-padded or
        trimmed accordingly before the transform. If not given, it defaults to
        `(x.shape[axes[0]], 2 * (x.shape[axes[1]] - 1))`, i.e. an even-length last axis.
    axes : Tuple[int, int], optional
        Axes over which to compute the FFT. Repeated transforms over an axis, i.e. repeated
        indices in ``axes``, are not supported yet. Default: (-2, -1).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`ihfft2` : inverse 2-dimensional FFT of a Hermitian-symmetric sequence
    :func:`hfft` : 1-dimensional FFT of a Hermitian-symmetric sequence
    :func:`hfftn` : N-dimensional FFT of a Hermitian-symmetric sequence
    :func:`fft2` : 2-dimensional FFT
    :func:`rfft2` : 2-dimensional FFT of a real signal

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    if s is None:
        # First axis keeps its length; the last one is expanded to the full even length.
        first_len = x.shape[axes[0]]
        last_len = 2 * (x.shape[axes[1]] - 1)
        s = (first_len, last_len)
    return __fftn_op(x, torch.fft.hfft2, s=s, axes=axes, norm=norm)


def hfftn(
    x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None
) -> DNDarray:
    """
    Compute the N-dimensional discrete Fourier Transform of a Hermitian symmetric signal.

    The transform is computed over any number of axes of an M-dimensional array by means of
    the Fast Fourier Transform (FFT). By default, all axes are transformed.

    Parameters
    ----------
    x : DNDarray
        Input array.
    s : Tuple[int, ...], optional
        Shape of the signal along the transformed axes. If given, the input is zero-padded or
        trimmed accordingly before the transform. If not given, the input shape along the
        transformed axes is used, except that the last transformed axis defaults to the even
        length `2 * (m - 1)`, where `m` is the input length along that axis.
    axes : Tuple[int, ...], optional
        Axes over which to compute the FFT. If not given, all dimensions are transformed.
        Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not
        supported yet. Default: None.
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`ihfftn` : inverse N-dimensional FFT of a Hermitian-symmetric sequence
    :func:`hfft` : 1-dimensional FFT of a Hermitian-symmetric sequence
    :func:`hfft2` : 2-dimensional FFT of a Hermitian-symmetric sequence
    :func:`fftn` : N-dimensional FFT
    :func:`rfftn` : N-dimensional FFT of a real signal

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    if s is None:
        # Input shape along the transformed axes (all axes when `axes` is None) ...
        s = list(x.shape) if axes is None else [x.shape[i] for i in axes]
        # ... with the half-spectrum last axis expanded to the full even signal length.
        s[-1] = 2 * (s[-1] - 1)
        s = tuple(s)
    return __fftn_op(x, torch.fft.hfftn, s=s, axes=axes, norm=norm)


def ifft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray:
    """
    Compute the one-dimensional inverse discrete Fourier Transform.

    Parameters
    ----------
    x : DNDarray
        Input array, can be complex.
    n : int, optional
        Length of the transformed axis of the output. If not given, the length of the input
        along ``axis`` is used. If ``n`` is smaller than that length, the input is cropped;
        if it is larger, the input is padded with zeros. Default: None.
    axis : int, optional
        Axis over which to compute the inverse FFT. Default: -1, i.e. the last axis (or the
        only axis if ``x`` is 1-D).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`fft` : forward 1-dimensional FFT
    :func:`ifft2` : inverse 2-dimensional FFT
    :func:`ifftn` : inverse N-dimensional FFT
    :func:`irfft` : inverse 1-dimensional FFT of a real sequence
    :func:`ihfft` : inverse 1-dimensional FFT of a Hermitian symmetric sequence

    Notes
    -----
    This function requires MPI communication if the input array is transformed along the
    distribution axis.

    If the input array is 1-D and distributed, this function copies the entire array on each
    MPI process, so a very large array may exhaust memory. Hint: if you are looping through
    a batch of 1-D arrays, stack them into a 2-D DNDarray distributed along the batch axis
    and call this function once with ``axis`` set to the other (non-distributed) axis. Do not
    use :func:`ifft2` for this: it would also transform across the stacked signals.
    """
    return __fft_op(x, torch.fft.ifft, n=n, axis=axis, norm=norm)


def ifft2(
    x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None
) -> DNDarray:
    """
    Compute the 2-dimensional inverse discrete Fourier Transform.

    Parameters
    ----------
    x : DNDarray
        Input array, can be complex.
    s : Tuple[int, int], optional
        Shape of the output along the transformed axes. If not given, the input shape along
        ``axes`` is used.
    axes : Tuple[int, int], optional
        Axes over which to compute the inverse FFT. Repeated transforms over an axis, i.e.
        repeated indices in ``axes``, are not supported yet. Default: (-2, -1).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`fft2` : forward 2-dimensional FFT
    :func:`ifft` : inverse 1-dimensional FFT
    :func:`ifftn` : inverse N-dimensional FFT
    :func:`irfft2` : inverse 2-dimensional FFT of a real sequence
    :func:`ihfft2` : inverse 2-dimensional FFT of a Hermitian symmetric sequence

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    transform = torch.fft.ifft2
    return __fftn_op(x, transform, s=s, axes=axes, norm=norm)


def ifftn(
    x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None
) -> DNDarray:
    """
    Compute the N-dimensional inverse discrete Fourier Transform.

    Parameters
    ----------
    x : DNDarray
        Input array, can be complex.
    s : Tuple[int, ...], optional
        Shape of the output along the transformed axes. If not given, the input shape along
        the transformed axes is used.
    axes : Tuple[int, ...], optional
        Axes over which to compute the inverse FFT. If not given, the last ``len(s)`` axes
        are used, or all axes if ``s`` is also not specified. Repeated transforms over an
        axis, i.e. repeated indices in ``axes``, are not supported yet. Default: None.
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`fftn` : forward N-dimensional FFT
    :func:`ifft` : inverse 1-dimensional FFT
    :func:`ifft2` : inverse 2-dimensional FFT
    :func:`irfftn` : inverse N-dimensional FFT of a real sequence
    :func:`ihfftn` : inverse N-dimensional FFT of a Hermitian symmetric sequence

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    return __fftn_op(x, torch.fft.ifftn, s=s, axes=axes, norm=norm)


def ifftshift(x: DNDarray, axes: Optional[Union[int, Iterable[int]]] = None) -> DNDarray:
    """
    The inverse of :func:`fftshift`.

    Parameters
    ----------
    x : DNDarray
        Input array.
    axes : int or Iterable[int], optional
        Axes over which to shift. Default is None, which shifts all axes.

    See Also
    --------
    :func:`fftshift` : Shift the zero-frequency component to the center of the spectrum.

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is shifted.
    """
    shift = torch.fft.ifftshift
    return __fftn_op(x, shift, axes=axes)


def ihfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray:
    """
    Compute the one-dimensional inverse discrete Fourier Transform of a real signal.

    The output is Hermitian-symmetric, so only its non-negative half is returned: the
    transformed axis of the output has length ``n // 2 + 1``.

    Parameters
    ----------
    x : DNDarray
        Input array, must be real.
    n : int, optional
        Signal length along ``axis`` used for the transform. If not given, the length of the
        input along ``axis`` is used. If ``n`` is smaller than that length, the input is
        cropped; if it is larger, the input is padded with zeros. Default: None.
    axis : int, optional
        Axis over which to compute the inverse FFT. Default: -1, i.e. the last axis (or the
        only axis if ``x`` is 1-D).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    Raises
    ------
    TypeError
        If ``x`` is complex.

    See Also
    --------
    :func:`hfft` : 1-dimensional FFT of a Hermitian-symmetric sequence
    :func:`ihfft2` : inverse 2-dimensional FFT of a Hermitian-symmetric sequence
    :func:`ihfftn` : inverse N-dimensional FFT of a Hermitian-symmetric sequence
    :func:`rfft` : 1-dimensional FFT of a real signal
    :func:`irfft` : inverse 1-dimensional FFT of a real sequence

    Notes
    -----
    This function requires MPI communication if the input array is transformed along the
    distribution axis.
    """
    return __real_fft_op(x, torch.fft.ihfft, n=n, axis=axis, norm=norm)


def ihfft2(
    x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None
) -> DNDarray:
    """
    Compute the 2-dimensional inverse discrete Fourier Transform of a real signal.

    The output is Hermitian-symmetric. Requires torch >= 1.11.0.

    Parameters
    ----------
    x : DNDarray
        Input array, must be real.
    s : Tuple[int, int], optional
        Shape of the signal along the transformed axes. If not given, the input shape along
        ``axes`` is used.
    axes : Tuple[int, int], optional
        Axes over which to compute the inverse FFT. Repeated transforms over an axis, i.e.
        repeated indices in ``axes``, are not supported yet. Default: (-2, -1).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    Raises
    ------
    TypeError
        If ``x`` is complex.

    See Also
    --------
    :func:`hfft2` : 2-dimensional FFT of a Hermitian-symmetric sequence
    :func:`ihfft` : inverse 1-dimensional FFT of a Hermitian-symmetric sequence
    :func:`ihfftn` : inverse N-dimensional FFT of a Hermitian-symmetric sequence
    :func:`rfft2` : 2-dimensional FFT of a real signal
    :func:`irfft2` : inverse 2-dimensional FFT of a real sequence

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    transform = torch.fft.ihfft2
    return __real_fftn_op(x, transform, s=s, axes=axes, norm=norm)


def ihfftn(
    x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None
) -> DNDarray:
    """
    Compute the N-dimensional inverse discrete Fourier Transform of a real signal.

    The output is Hermitian-symmetric. Requires torch >= 1.11.0.

    Parameters
    ----------
    x : DNDarray
        Input array, must be real.
    s : Tuple[int, ...], optional
        Shape of the signal along the transformed axes. If not given, the input shape along
        the transformed axes is used.
    axes : Tuple[int, ...], optional
        Axes over which to compute the inverse FFT. If not given, the last ``len(s)`` axes
        are used, or all axes if ``s`` is also not specified. Repeated transforms over an
        axis, i.e. repeated indices in ``axes``, are not supported yet. Default: None.
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    Raises
    ------
    TypeError
        If ``x`` is complex.

    See Also
    --------
    :func:`hfftn` : N-dimensional FFT of a Hermitian-symmetric sequence
    :func:`ihfft` : inverse 1-dimensional FFT of a Hermitian-symmetric sequence
    :func:`ihfft2` : inverse 2-dimensional FFT of a Hermitian-symmetric sequence
    :func:`rfftn` : N-dimensional FFT of a real signal
    :func:`irfftn` : inverse N-dimensional FFT of a real sequence

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    transform = torch.fft.ihfftn
    return __real_fftn_op(x, transform, s=s, axes=axes, norm=norm)


def irfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray:
    """
    Compute the inverse of a one-dimensional discrete Fourier Transform of a real signal.

    The input is the non-negative half of a Hermitian-symmetric spectrum; the output is real.

    Parameters
    ----------
    x : DNDarray
        Input array, can be complex.
    n : int, optional
        Length of the transformed axis of the output. ``n // 2 + 1`` input points are used
        along ``axis``: a longer input is cropped, a shorter one is padded with zeros.
        Default: `2 * (x.shape[axis] - 1)`, i.e. the full even-length signal.
    axis : int, optional
        Axis over which to compute the inverse FFT. Default: -1, i.e. the last axis (or the
        only axis if ``x`` is 1-D).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`irfft2` : inverse 2-dimensional FFT
    :func:`irfftn` : inverse N-dimensional FFT
    :func:`rfft` : 1-dimensional FFT of a real signal
    :func:`hfft` : 1-dimensional FFT of a Hermitian symmetric sequence
    :func:`fft` : 1-dimensional FFT

    Notes
    -----
    This function requires MPI communication if the input array is transformed along the
    distribution axis.

    If the input array is 1-D and distributed, this function copies the entire array on each
    MPI process, so a very large array may exhaust memory. Hint: if you are looping through
    a batch of 1-D arrays, stack them into a 2-D DNDarray distributed along the batch axis
    and call this function once with ``axis`` set to the other (non-distributed) axis. Do not
    use :func:`irfft2` for this: it would also transform across the stacked signals.
    """
    if n is None:
        # Reconstruct the full even-length signal from the half spectrum.
        n = 2 * (x.shape[axis] - 1)
    return __fft_op(x, torch.fft.irfft, n=n, axis=axis, norm=norm)


def irfft2(
    x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None
) -> DNDarray:
    """
    Compute the inverse of a 2-dimensional discrete real Fourier Transform.

    The output is real.

    Parameters
    ----------
    x : DNDarray
        Input array, can be complex.
    s : Tuple[int, int], optional
        Shape of the output along the transformed axes. If not given, it defaults to
        `(x.shape[axes[0]], 2 * (x.shape[axes[1]] - 1))`, i.e. an even-length last axis.
    axes : Tuple[int, int], optional
        Axes over which to compute the inverse FFT. Repeated transforms over an axis, i.e.
        repeated indices in ``axes``, are not supported yet. Default: (-2, -1).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`irfft` : inverse 1-dimensional FFT
    :func:`irfftn` : inverse N-dimensional FFT
    :func:`rfft2` : 2-dimensional FFT of a real signal
    :func:`hfft2` : 2-dimensional FFT of a Hermitian symmetric sequence
    :func:`fft2` : 2-dimensional FFT

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    if s is None:
        # First axis keeps its length; the half-spectrum last axis is expanded.
        first_len = x.shape[axes[0]]
        last_len = 2 * (x.shape[axes[1]] - 1)
        s = (first_len, last_len)
    return __fftn_op(x, torch.fft.irfft2, s=s, axes=axes, norm=norm)


def irfftn(
    x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None
) -> DNDarray:
    """
    Compute the inverse of an N-dimensional discrete Fourier Transform of a real signal.

    The output is real.

    Parameters
    ----------
    x : DNDarray
        Input array, assumed to be Hermitian-symmetric along the transformed axes, with the
        last transformed axis containing only the positive half of the frequencies.
    s : Tuple[int, ...], optional
        Shape of the output along the transformed axes. If not given, the input shape along
        the transformed axes is used, except that the last transformed axis is reconstructed
        in full, i.e. `s[-1] = 2 * (m - 1)`, where `m` is the input length along that axis.
    axes : Tuple[int, ...], optional
        Axes over which to compute the inverse FFT. If not given, the last ``len(s)`` axes
        are used, or all axes if ``s`` is also not specified. Repeated transforms over an
        axis, i.e. repeated indices in ``axes``, are not supported yet. Default: None.
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    See Also
    --------
    :func:`irfft` : inverse 1-dimensional FFT
    :func:`irfft2` : inverse 2-dimensional FFT
    :func:`rfftn` : N-dimensional FFT of a real signal

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    if s is None:
        # Input shape along the transformed axes (all axes when `axes` is None) ...
        s = list(x.shape) if axes is None else [x.shape[i] for i in axes]
        # ... with the half-spectrum last axis expanded to the full even signal length.
        s[-1] = 2 * (s[-1] - 1)
        s = tuple(s)
    return __fftn_op(x, torch.fft.irfftn, s=s, axes=axes, norm=norm)


def rfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray:
    """
    Compute the one-dimensional discrete Fourier Transform of real input.

    The spectrum of a real signal is Hermitian-symmetric, so only its non-negative half is
    returned: the transformed axis of the output has length ``n // 2 + 1``.

    Parameters
    ----------
    x : DNDarray
        Input array, must be real.
    n : int, optional
        Signal length along ``axis`` used for the transform. If not given, the length of the
        input along ``axis`` is used. If ``n`` is smaller than that length, the input is
        cropped; if it is larger, the input is padded with zeros. Default: None.
    axis : int, optional
        Axis over which to compute the FFT. Default: -1, i.e. the last axis (or the only
        axis if ``x`` is 1-D).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    Raises
    ------
    TypeError
        If ``x`` is complex.

    See Also
    --------
    :func:`irfft` : inverse 1-dimensional FFT of a real sequence
    :func:`rfft2` : 2-dimensional FFT of a real signal
    :func:`rfftn` : N-dimensional FFT of a real signal
    :func:`rfftfreq` : frequency bins for real FFT

    Notes
    -----
    This function requires MPI communication if the input array is transformed along the
    distribution axis.

    If the input array is 1-D and distributed, this function copies the entire array on each
    MPI process, so a very large array may exhaust memory. Hint: if you are looping through
    a batch of 1-D arrays, stack them into a 2-D DNDarray distributed along the batch axis
    and call this function once with ``axis`` set to the other (non-distributed) axis. Do not
    use :func:`rfft2` for this: it would also transform across the stacked signals.
    """
    return __real_fft_op(x, torch.fft.rfft, n=n, axis=axis, norm=norm)


def rfft2(
    x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None
) -> DNDarray:
    """
    Compute the 2-dimensional discrete Fourier Transform of real input.

    The output is Hermitian-symmetric.

    Parameters
    ----------
    x : DNDarray
        Input array, must be real.
    s : Tuple[int, int], optional
        Shape of the signal along the transformed axes. If not given, the input shape along
        ``axes`` is used.
    axes : Tuple[int, int], optional
        Axes over which to compute the FFT. Repeated transforms over an axis, i.e. repeated
        indices in ``axes``, are not supported yet. Default: (-2, -1).
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    Raises
    ------
    TypeError
        If ``x`` is complex.

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    transform = torch.fft.rfft2
    return __real_fftn_op(x, transform, s=s, axes=axes, norm=norm)


def rfftfreq(
    n: int,
    d: Union[int, float] = 1.0,
    dtype: Optional[Type] = None,
    split: Optional[int] = None,
    device: Optional[Union[str, Device]] = None,
    comm: Optional[MPI.Comm] = None,
) -> DNDarray:
    """
    Return the Discrete Fourier Transform sample frequencies for :func:`rfft`.

    The result is a float DNDarray holding the non-negative frequency bin centers in cycles
    per unit of the sample spacing, starting at zero. For instance, if the sample spacing is
    in seconds, the frequency unit is cycles/second.

    Parameters
    ----------
    n : int
        Window length.
    d : Union[int, float], optional
        Sample spacing (inverse of the sampling rate). Defaults to 1.
    dtype : Type, optional
        Desired data type of the output; must be a float or complex heat type. Defaults to
        `float32`.
    split : int, optional
        Axis along which to split the result; only None and 0 are accepted. If not given,
        the result is not split.
    device : str or Device, optional
        Device on which to place the output. If not given, the current device is used.
    comm : MPI.Comm, optional
        MPI communicator used to distribute the output. If not given, the default
        communicator is used.

    Raises
    ------
    TypeError
        If ``d`` is not a scalar or ``dtype`` is an exact (integer/boolean) type.
    IndexError
        If ``split`` is neither None nor 0.
    """
    # NOTE(review): in the distributed path (split=0), __fftfreq_op allocates
    # arange(n // 2) entries when n is even, one fewer than torch.fft.rfftfreq's n // 2 + 1,
    # so the Nyquist bin appears to be missing for even n. Verify and fix in __fftfreq_op.
    options = dict(n=n, d=d, dtype=dtype, split=split, device=device, comm=comm)
    return __fftfreq_op(torch.fft.rfftfreq, **options)


def rfftn(
    x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None
) -> DNDarray:
    """
    Compute the N-dimensional discrete Fourier Transform of real input.

    By default all axes are transformed. The real transform is performed over the last
    transformed axis, while the remaining transforms are complex. The output is
    Hermitian-symmetric: the last transformed axis has length `s[-1] // 2 + 1` (the
    non-negative part of the spectrum).

    Parameters
    ----------
    x : DNDarray
        Input array, must be real.
    s : Tuple[int, ...], optional
        Shape of the signal along the transformed axes. If not given, the input shape along
        the transformed axes is used.
    axes : Tuple[int, ...], optional
        Axes over which to compute the FFT. If not given, the last ``len(s)`` axes are used,
        or all axes if ``s`` is also not specified. Repeated transforms over an axis, i.e.
        repeated indices in ``axes``, are not supported yet. Default: None.
    norm : str, optional
        Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction
        the forward/backward pair of transforms is normalized. Default is "backward".

    Raises
    ------
    TypeError
        If ``x`` is complex.

    See Also
    --------
    :func:`irfftn` : inverse N-dimensional FFT of a real sequence
    :func:`rfft` : 1-dimensional FFT of a real signal
    :func:`rfft2` : 2-dimensional FFT of a real signal

    Notes
    -----
    This function requires MPI communication if the input array is distributed and the
    split axis is transformed.
    """
    return __real_fftn_op(x, torch.fft.rfftn, s=s, axes=axes, norm=norm)