Source code for heat.core.devices

"""
Handle the compute devices Heat can run on. Current options: CPU (default), GPU (CUDA or Apple MPS).
"""

from __future__ import annotations

import torch

from typing import Any, Optional, Union

from . import communication


# public API of this module; "gpu" is appended further down only when a CUDA or MPS device is detected
__all__ = ["Device", "cpu", "get_device", "sanitize_device", "use_device"]


class Device:
    """
    Implements a compute device. Heat can run computations on different compute devices or backends.
    A device describes the device type and id on which said computation should be carried out.

    Parameters
    ----------
    device_type : str
        Represents Heat's device name
    device_id : int
        The device id
    torch_device : str
        The corresponding PyTorch device type

    Examples
    --------
    >>> ht.Device("cpu", 0, "cpu:0")
    device(cpu:0)
    >>> ht.Device("gpu", 0, "cuda:0")
    device(gpu:0)
    >>> ht.Device("gpu", 0, "mps:0")  # on Apple M1/M2
    device(gpu:0)
    """

    def __init__(self, device_type: str, device_id: int, torch_device: str):
        self.__device_type = device_type
        self.__device_id = device_id
        self.__torch_device = torch_device

    @property
    def device_type(self) -> str:
        """
        Return the type of :class:`~heat.core.device.Device` as a string.
        """
        return self.__device_type

    @property
    def device_id(self) -> int:
        """
        Return the identification number of :class:`~heat.core.device.Device`.
        """
        return self.__device_id

    @property
    def torch_device(self) -> str:
        """
        Return the type and id of :class:`~heat.core.device.Device` as a PyTorch device string object.
        """
        return self.__torch_device

    def __repr__(self) -> str:
        """
        Return the unambiguous information of :class:`~heat.core.device.Device`.
        """
        return f"device({self.__str__()})"

    def __str__(self) -> str:
        """
        Return the descriptive information of :class:`~heat.core.device.Device`.
        """
        return f"{self.device_type}:{self.device_id}"

    def __eq__(self, other: Any) -> bool:
        """
        Overloads the `==` operator for local equal check.

        Two :class:`Device` objects are equal when type and id match. A ``torch.device`` is equal
        when its backend type matches the backend part of :attr:`torch_device` (``"cpu"``,
        ``"cuda"``, ``"mps"``) and its index matches :attr:`device_id`.

        Parameters
        ----------
        other : Any
            The object to compare with
        """
        if isinstance(other, Device):
            return self.device_type == other.device_type and self.device_id == other.device_id
        if isinstance(other, torch.device):
            # torch names the backend ("cuda", "mps", ...) rather than heat's generic "gpu",
            # so compare with the backend prefix of the torch device string
            backend = self.torch_device.partition(":")[0]
            # torch leaves the index unset for e.g. torch.device("cpu"); treat that as 0,
            # matching heat numbering its CPU device as 0
            index = 0 if other.index is None else other.index
            return backend == other.type and self.device_id == index
        return NotImplemented

    def __hash__(self) -> int:
        """
        Hash consistently with :meth:`__eq__` for :class:`Device` operands.
        """
        # defining __eq__ alone would set __hash__ to None and make devices unhashable
        return hash((self.device_type, self.device_id))
# create a CPU device singleton
cpu = Device("cpu", 0, "cpu")
"""
The standard CPU Device

Examples
--------
>>> ht.cpu
device(cpu:0)
>>> ht.ones((2, 3), device=ht.cpu)
DNDarray([[1., 1., 1.],
          [1., 1., 1.]], dtype=ht.float32, device=cpu:0, split=None)
"""

# define the default device to be the CPU
__default_device = cpu
# add a device string for the CPU device
__device_mapping = {cpu.device_type: cpu}

# add gpu support if available
if torch.cuda.device_count() > 0:
    # GPUs are assigned round-robin to the MPI processes
    gpu_id = communication.MPI_WORLD.rank % torch.cuda.device_count()
    # create a new GPU device
    gpu = Device("gpu", gpu_id, f"cuda:{gpu_id}")
    """
    The standard GPU Device

    Examples
    --------
    >>> ht.gpu  # on MPI rank 0
    device(gpu:0)
    >>> ht.ones((2, 3), device=ht.gpu)
    DNDarray([[1., 1., 1.],
              [1., 1., 1.]], dtype=ht.float32, device=gpu:0, split=None)
    """
    # add a GPU device string
    __device_mapping[gpu.device_type] = gpu
    __device_mapping["cuda"] = gpu
    # the GPU device should be exported as global symbol
    __all__.append("gpu")
elif (
    # guard for PyTorch builds that predate the MPS backend module
    hasattr(torch.backends, "mps")
    and torch.backends.mps.is_built()
    and torch.backends.mps.is_available()
):
    # Apple MPS available
    gpu_id = 0
    # create a new GPU device
    gpu = Device("gpu", gpu_id, f"mps:{gpu_id}")
    """
    The standard GPU Device on Apple M1/M2

    Examples
    --------
    >>> ht.gpu
    device(gpu:0)
    >>> ht.ones((2, 3), device=ht.gpu)
    DNDarray([[1., 1., 1.],
              [1., 1., 1.]], dtype=ht.float32, device=gpu:0, split=None)
    """
    # add a GPU device string
    __device_mapping[gpu.device_type] = gpu
    __device_mapping["mps"] = gpu
    # the GPU device should be exported as global symbol
    __all__.append("gpu")
def get_device() -> Device:
    """
    Return the :class:`~heat.core.device.Device` that is currently configured as Heat's global
    default.

    The default starts out as the CPU and is changed with :func:`use_device`.
    """
    current_default = __default_device
    return current_default
def sanitize_device(device: Optional[Union[str, Device]] = None) -> Device:
    """
    Sanitizes a device or device identifier, i.e. checks whether it is already an instance of
    :class:`~heat.core.device.Device` or a string with known device identifier and maps it to a
    proper :class:`~heat.core.device.Device`.

    Besides the plain identifiers (e.g. ``"cpu"``, ``"gpu"``, ``"cuda"``), the qualified
    ``"<type>:<id>"`` spellings produced by ``str(device)`` and :attr:`Device.torch_device`
    (e.g. ``"cpu:0"``, ``"gpu:0"``, ``"cuda:0"``) are accepted, provided they name the device
    registered on this process. Matching is case-insensitive and ignores surrounding whitespace.

    Parameters
    ----------
    device : str or Device, optional
        The device to be sanitized. ``None`` returns the current default device
        (see :func:`get_device`).

    Raises
    ------
    ValueError
        If the given device id is not recognized
    """
    if device is None:
        return get_device()
    if isinstance(device, Device):
        return device

    try:
        name = device.strip().lower()
        if name in __device_mapping:
            return __device_mapping[name]
        # qualified form: only accepted when it spells out the locally registered device,
        # so e.g. "gpu:1" is rejected on a process whose GPU is gpu:0
        device_type, separator, _ = name.partition(":")
        candidate = __device_mapping.get(device_type)
        if separator and candidate is not None and name in (str(candidate), candidate.torch_device):
            return candidate
    except (AttributeError, TypeError):
        # not a string-like identifier; reported as an unknown device below
        pass
    raise ValueError(f"Unknown device, must be one of {', '.join(__device_mapping.keys())}")
def use_device(device: Optional[Union[str, Device]] = None) -> None:
    """
    Set the :class:`~heat.core.device.Device` that Heat uses as its global default.

    Parameters
    ----------
    device : str or Device, optional
        Any device or identifier accepted by :func:`sanitize_device`; ``None`` re-selects the
        current default device.
    """
    global __default_device
    # resolve first, so an unknown identifier raises before the global default is touched
    resolved = sanitize_device(device)
    __default_device = resolved