diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py index 22c169577..44d784070 100644 --- a/colossalai/cli/check/check_installation.py +++ b/colossalai/cli/check/check_installation.py @@ -31,7 +31,7 @@ def check_installation(): found_aot_cuda_ext = _check_aot_built_cuda_extension_installed() cuda_version = _check_cuda_version() torch_version, torch_cuda_version = _check_torch_version() - colossalai_verison, torch_version_required, cuda_version_required = _parse_colossalai_version() + colossalai_verison, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version() # if cuda_version is None, that means either # CUDA_HOME is not found, thus cannot compare the version compatibility @@ -43,33 +43,36 @@ def check_installation(): # if cuda_version or cuda_version_required is None, that means either # CUDA_HOME is not found or AOT compilation is not enabled # thus, there is no need to compare the version compatibility at all - if not cuda_version or not cuda_version_required: + if not cuda_version or not prebuilt_cuda_version_required: sys_colossalai_cuda_compatibility = None else: - sys_colossalai_cuda_compatibility = _is_compatible([cuda_version, cuda_version_required]) + sys_colossalai_cuda_compatibility = _is_compatible([cuda_version, prebuilt_cuda_version_required]) # if torch_version_required is None, that means AOT compilation is not enabled # thus there is no need to compare the versions - if torch_version_required is None: + if prebuilt_torch_version_required is None: torch_compatibility = None else: - torch_compatibility = _is_compatible([torch_version, torch_version_required]) + torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required]) click.echo(f'#### Installation Report ####') click.echo(f'\n------------ Environment ------------') click.echo(f"Colossal-AI version: {to_click_output(colossalai_verison)}") click.echo(f"PyTorch version: {to_click_output(torch_version)}") - click.echo(f"CUDA version: {to_click_output(cuda_version)}") + click.echo(f"System CUDA version: {to_click_output(cuda_version)}") click.echo(f"CUDA version required by PyTorch: {to_click_output(torch_cuda_version)}") click.echo("") click.echo(f"Note:") click.echo(f"1. The table above checks the versions of the libraries/tools in the current environment") - click.echo(f"2. If the CUDA version is N/A, you can set the CUDA_HOME environment variable to locate it") + click.echo(f"2. If the System CUDA version is N/A, you can set the CUDA_HOME environment variable to locate it") + click.echo( + f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version." + ) click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------') click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}") - click.echo(f"PyTorch version used for AOT compilation: {to_click_output(torch_version_required)}") - click.echo(f"CUDA version used for AOT compilation: {to_click_output(cuda_version_required)}") + click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}") + click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}") click.echo("") click.echo(f"Note:") click.echo( @@ -169,12 +172,19 @@ def _check_torch_version(): torch_cuda_version: CUDA version required by PyTorch. """ # get torch version + # torch version can be of two formats + # - 1.13.1+cu113 + # - 1.13.1.devxxx torch_version = torch.__version__.split('+')[0] + torch_version = '.'.join(torch_version.split('.')[:3]) # get cuda version in pytorch build - torch_cuda_major = torch.version.cuda.split(".")[0] - torch_cuda_minor = torch.version.cuda.split(".")[1] - torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' + try: + torch_cuda_major = torch.version.cuda.split(".")[0] + torch_cuda_minor = torch.version.cuda.split(".")[1] + torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' + except: + torch_cuda_version = None return torch_version, torch_cuda_version @@ -186,15 +196,19 @@ def _check_cuda_version(): Returns: cuda_version: CUDA version found on the system. """ + # get cuda version if CUDA_HOME is None: cuda_version = CUDA_HOME else: - raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - cuda_version = f'{bare_metal_major}.{bare_metal_minor}' + try: + raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + cuda_version = f'{bare_metal_major}.{bare_metal_minor}' + except: + cuda_version = None return cuda_version