Source code for heat.array_api._inspection

import heat.core.devices as ht_devices

from types import SimpleNamespace
from heat.core.types import (
    bool,
    complex64,
    complex128,
    float16,
    float32,
    float64,
    int8,
    int16,
    int32,
    int64,
    uint8,
)


[docs] def __array_namespace_info__(): """Returns a namespace with Array API namespace inspection utilities.""" info = SimpleNamespace() info.capabilities = capabilities info.default_device = default_device info.default_dtypes = default_dtypes info.devices = devices info.dtypes = dtypes return info
[docs] def capabilities(): """Returns a dictionary of array library capabilities.""" return {"boolean indexing": True, "data-dependent shapes": True, "max dimensions": 64}
[docs] def default_device(): """Returns the default device.""" return ht_devices.get_device()
[docs] def default_dtypes(*, device=None): """Returns a dictionary containing default data types.""" if device is None: device = default_device() if not isinstance(device, ht_devices.Device): raise ValueError(f"Device not understood: {device}") if "mps" in device.torch_device: return { "real floating": float32, "complex floating": complex64, "integral": int32, "indexing": int32, } return { "real floating": float32, "complex floating": complex64, "integral": int64, "indexing": int64, } raise ValueError(f"Unsupported device: {device}")
[docs] def devices(): """Returns a list of supported devices which are available at runtime.""" if hasattr(ht_devices, "gpu"): return (ht_devices.cpu, ht_devices.gpu) else: return (ht_devices.cpu,)
[docs] def dtypes(*, device=None, kind=None): """Returns a dictionary of supported Array API data types""" if device is None: device = default_device() if not isinstance(device, ht_devices.Device): raise ValueError(f"Device not understood: {device}") if "mps" in device.torch_device: if kind is None: return { "bool": bool, "int8": int8, "int16": int16, "int32": int32, "uint8": uint8, "float32": float32, "complex64": complex64, } if kind == "bool": return { "bool": bool, } if kind == "signed integer": return { "int8": int8, "int16": int16, "int32": int32, } if kind == "unsigned integer": return { "uint8": uint8, } if kind == "integral": return { "int8": int8, "int16": int16, "int32": int32, "uint8": uint8, } if kind == "real floating": return { "float32": float32, } if kind == "complex floating": return { "complex64": complex64, } if kind == "numeric": return { "int8": int8, "int16": int16, "int32": int32, "uint8": uint8, "float32": float32, "complex64": complex64, } if kind is None: return { "bool": bool, "int8": int8, "int16": int16, "int32": int32, "int64": int64, "uint8": uint8, "float32": float32, "float64": float64, "complex64": complex64, "complex128": complex128, } if kind == "bool": return { "bool": bool, } if kind == "signed integer": return { "int8": int8, "int16": int16, "int32": int32, "int64": int64, } if kind == "unsigned integer": return { "uint8": uint8, } if kind == "integral": return { "int8": int8, "int16": int16, "int32": int32, "int64": int64, "uint8": uint8, } if kind == "real floating": return { "float32": float32, "float64": float64, } if kind == "complex floating": return { "complex64": complex64, "complex128": complex128, } if kind == "numeric": return { "int8": int8, "int16": int16, "int32": int32, "int64": int64, "uint8": uint8, "float32": float32, "float64": float64, "complex64": complex64, "complex128": complex128, } if isinstance(kind, tuple): res = {} for k in kind: res |= dtypes(device=device, kind=k) return res raise ValueError(f"Unsupported kind: {kind}")