# Source code for heat.core._config

"""
Everything you need to know about the configuration of Heat
"""

import torch
import platform
import mpi4py
import subprocess
import os
import warnings
import re
import dataclasses
from enum import Enum


class MPILibrary(Enum):
    """
    MPI implementations that Heat tells apart when inspecting the MPI installation.

    The value of each member is a short lowercase tag for the implementation.
    """

    OpenMPI = "ompi"
    IntelMPI = "impi"
    MVAPICH = "mvapich"
    MPICH = "mpich"
    CrayMPI = "craympi"
    ParaStationMPI = "psmpi"
    # Fallback member for any implementation that is not listed above.
    Other = "other"
@dataclasses.dataclass
class MPILibraryInfo:
    """
    Record that pairs an MPI implementation with its version string.
    """

    # The MPI implementation this record describes.
    name: MPILibrary
    # Version of that implementation, stored as free-form text.
    version: str
[docs] def _get_mpi_library() -> MPILibraryInfo: library = mpi4py.MPI.Get_library_version().split() match library: case ["Open", "MPI", *_]: return MPILibraryInfo(MPILibrary.OpenMPI, library[2]) case ["Intel(R)", "MPI", *_]: return MPILibraryInfo(MPILibrary.IntelMPI, library[3]) case ["MPICH", "Version:", *_]: return MPILibraryInfo(MPILibrary.MPICH, library[2]) case ["MVAPICH", "Version:", *_]: return MPILibraryInfo(MPILibrary.MVAPICH, library[2]) case ["===", "ParaStation", "MPI", *_]: return MPILibraryInfo(MPILibrary.ParaStationMPI, library[3]) case _: return MPILibraryInfo(MPILibrary.Other, "unknown")
[docs] def _check_gpu_aware_mpi(library: MPILibraryInfo) -> tuple[bool, bool]: match library.name: case MPILibrary.OpenMPI: try: parsable_ompi_info = subprocess.check_output( ["ompi_info", "--parsable", "--all"] ).decode("utf-8") ompi_info = subprocess.check_output(["ompi_info"]).decode("utf-8") # Check for CUDA support flag cuda_support_flag = "mpi_built_with_cuda_support:value:true" in parsable_ompi_info # Check for extensions match = re.search(r"MPI extensions: (.*)", ompi_info) extensions = [ext.strip() for ext in match.group(0).split(":")[1].split(",")] cuda = cuda_support_flag and "cuda" in extensions if library.version.startswith("v4."): rocm = cuda elif library.version.startswith("v5."): rocm = "rocm" in extensions or "hip" in extensions # Seems to be broken, disabled by default for now # return cuda, rocm return False, False except Exception as e: # noqa E722 return False, False case MPILibrary.IntelMPI: return False, False case MPILibrary.MVAPICH: cuda = os.environ.get("MV2_USE_CUDA") == "1" rocm = os.environ.get("MV2_USE_ROCM") == "1" return cuda, rocm case MPILibrary.MPICH: cuda = os.environ.get("MPIR_CVAR_ENABLE_HCOLL") == "1" rocm = False return cuda, rocm case MPILibrary.CrayMPI: cuda = os.environ.get("MPICH_GPU_SUPPORT_ENABLED") == "1" rocm = os.environ.get("MPICH_GPU_SUPPORT_ENABLED") == "1" return cuda, rocm case MPILibrary.ParaStationMPI: cuda = os.environ.get("PSP_CUDA") == "1" rocm = False return cuda, rocm case _: return False, False
# Snapshot of the runtime environment, taken once at import time.
PLATFORM = platform.platform()
TORCH_VERSION = torch.__version__
TORCH_CUDA_IS_AVAILABLE = torch.cuda.is_available()
# The check below treats a PyTorch version string containing "rocm" as a ROCm build.
CUDA_IS_ACTUALLY_ROCM = "rocm" in TORCH_VERSION

# MPI library detection and its GPU-awareness flags.
mpi_library = _get_mpi_library()
CUDA_AWARE_MPI, ROCM_AWARE_MPI = _check_gpu_aware_mpi(mpi_library)
GPU_AWARE_MPI = False

# If PyTorch can use GPUs, MPI must be aware of the matching backend (CUDA or ROCm);
# otherwise the user is warned and GPU_AWARE_MPI stays False.
if TORCH_CUDA_IS_AVAILABLE:
    if CUDA_IS_ACTUALLY_ROCM:
        if ROCM_AWARE_MPI:
            GPU_AWARE_MPI = True
        else:
            warnings.warn(
                f"Heat has ROCm GPU-support (PyTorch version {TORCH_VERSION} and `torch.cuda.is_available() = True`), but ROCm-awareness of MPI could not be detected. This may lead to performance degradation as direct MPI-communication between GPUs is not possible.",
                UserWarning,
            )
    else:
        if CUDA_AWARE_MPI:
            GPU_AWARE_MPI = True
        else:
            warnings.warn(
                f"Heat has CUDA GPU-support (PyTorch version {TORCH_VERSION} and `torch.cuda.is_available() = True`), but CUDA-awareness of MPI could not be detected. This may lead to performance degradation as direct MPI-communication between GPUs is not possible.",
                UserWarning,
            )