You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/cli/check/check_installation.py

58 lines
2.0 KiB

import click
import subprocess
import torch
from torch.utils.cpp_extension import CUDA_HOME
def check_installation():
    """Print a short installation report for CUDA, PyTorch and the CUDA extension.

    The extension status is queried first via
    ``_check_cuda_extension_installed()``, then the version information via
    ``_check_cuda_torch()``. Every item is written with its own
    ``click.echo`` call, in a fixed order.
    """
    cuda_ext_installed = _check_cuda_extension_installed()
    cuda_version, torch_version, torch_cuda_version, cuda_torch_compatibility = _check_cuda_torch()

    # Build the whole report first so the output order is visible in one place.
    report_lines = (
        f"CUDA Version: {cuda_version}",
        f"PyTorch Version: {torch_version}",
        f"CUDA Version in PyTorch Build: {torch_cuda_version}",
        f"PyTorch CUDA Version Match: {cuda_torch_compatibility}",
        f"CUDA Extension: {cuda_ext_installed}",
    )
    for line in report_lines:
        click.echo(line)
def _check_cuda_extension_installed():
    """Report whether the ``colossal_C`` module can be imported.

    Returns:
        str: a check mark (U+2713) when ``import colossal_C`` succeeds,
        ``'x'`` when it raises ``ImportError``.
    """
    try:
        # Imported only to probe availability; the module itself is not used here.
        import colossal_C  # noqa: F401
    except ImportError:
        return 'x'
    return u'\u2713'
def _check_cuda_torch():
    """Collect CUDA / PyTorch version information and compare the two.

    Returns:
        tuple: ``(cuda_version, torch_version, torch_cuda_version,
        cuda_torch_compatibility)`` where

        * ``cuda_version`` is the ``major.minor`` release reported by
          ``$CUDA_HOME/bin/nvcc -V``, or an ``'N/A (...)'`` message when
          ``CUDA_HOME`` is not set or ``nvcc`` could not be queried;
        * ``torch_version`` is ``torch.__version__``;
        * ``torch_cuda_version`` is the ``major.minor`` part of
          ``torch.version.cuda``, or an ``'N/A (...)'`` message when PyTorch
          was built without CUDA (``torch.version.cuda`` is ``None``);
        * ``cuda_torch_compatibility`` is a check mark (U+2713) when major and
          minor versions match, a check mark followed by
          ``' (minor version mismatch)'`` when only the major versions match,
          and ``'x'`` otherwise (including when either version is unknown).
    """
    # Version components of the local toolkit; they stay None when undetermined
    # so the compatibility check below can never report a false match.
    bare_metal_major = None
    bare_metal_minor = None

    # get cuda version
    if CUDA_HOME is None:
        cuda_version = 'N/A (CUDA_HOME is not set)'
    else:
        try:
            raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True)
            output = raw_output.split()
            # nvcc typically prints "... release 11.3, V11.3.109": take the token
            # after "release" and drop the trailing comma before splitting,
            # so multi-digit minor versions are kept intact.
            release = output[output.index("release") + 1].rstrip(",").split(".")
            bare_metal_major, bare_metal_minor = release[0], release[1]
            cuda_version = f'{bare_metal_major}.{bare_metal_minor}'
        except (OSError, subprocess.CalledProcessError, ValueError, IndexError):
            # nvcc is missing, not executable, exited non-zero, or printed an
            # unexpected format: report it instead of crashing the diagnostic.
            cuda_version = 'N/A (failed to query nvcc)'

    # get torch version
    torch_version = torch.__version__

    # get cuda version in pytorch build
    build_cuda = torch.version.cuda
    if build_cuda is None:
        # CPU-only (or non-CUDA) PyTorch builds expose no CUDA version.
        torch_cuda_major = None
        torch_cuda_minor = None
        torch_cuda_version = 'N/A (PyTorch was built without CUDA)'
    else:
        build_parts = build_cuda.split(".")
        torch_cuda_major = build_parts[0]
        torch_cuda_minor = build_parts[1]
        torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}'

    # check version compatibility (only meaningful when both versions are known)
    cuda_torch_compatibility = 'x'
    if bare_metal_major is not None and torch_cuda_major == bare_metal_major:
        if torch_cuda_minor == bare_metal_minor:
            cuda_torch_compatibility = u'\u2713'
        else:
            cuda_torch_compatibility = u'\u2713 (minor version mismatch)'

    return cuda_version, torch_version, torch_cuda_version, cuda_torch_compatibility