diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py index b3575dbfe..ba3ddc552 100644 --- a/colossalai/accelerator/npu_accelerator.py +++ b/colossalai/accelerator/npu_accelerator.py @@ -7,11 +7,8 @@ import torch.distributed as dist from .base_accelerator import BaseAccelerator -IS_NPU_AVAILABLE = False try: import torch_npu # noqa - - IS_NPU_AVAILABLE = True except ImportError: pass diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index d6610a3e1..0e1104455 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -12,7 +12,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( get_model_base_filenames, @@ -362,7 +362,7 @@ class GeminiPlugin(DPPluginBase): ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" - if IS_NPU_AVAILABLE: + if get_accelerator().name == "npu": assert placement_policy == "static", "NPU only supports static placement policy" self.gemini_config = dict( chunk_config_dict=chunk_config_dict,