From 9102d655ab243ec5d3b2c2cf3bbfa62866373189 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 9 Jan 2024 14:57:07 +0800
Subject: [PATCH] [hotfix] removed unused flag (#5242)

---
 colossalai/accelerator/npu_accelerator.py  | 3 ---
 colossalai/booster/plugin/gemini_plugin.py | 4 ++--
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py
index b3575dbfe..ba3ddc552 100644
--- a/colossalai/accelerator/npu_accelerator.py
+++ b/colossalai/accelerator/npu_accelerator.py
@@ -7,11 +7,8 @@ import torch.distributed as dist
 
 from .base_accelerator import BaseAccelerator
 
-IS_NPU_AVAILABLE = False
 try:
     import torch_npu  # noqa
-
-    IS_NPU_AVAILABLE = True
 except ImportError:
     pass
 
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index d6610a3e1..0e1104455 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -12,7 +12,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
-from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator
+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
@@ -362,7 +362,7 @@ class GeminiPlugin(DPPluginBase):
     ) -> None:
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
-        if IS_NPU_AVAILABLE:
+        if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
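
For context, the patch replaces a module-level availability flag (`IS_NPU_AVAILABLE`, set as a side effect of a try/except import probe) with a runtime query of the active accelerator's `name`. Below is a minimal self-contained sketch of that pattern; the `BaseAccelerator`, `CpuAccelerator`, `NpuAccelerator`, and `_detect_accelerator` names here are simplified stand-ins for illustration, not ColossalAI's actual implementations.

```python
# Sketch of the pattern the patch adopts: instead of exporting a
# module-level IS_NPU_AVAILABLE flag, callers ask the active
# accelerator object for its name at runtime.
# NOTE: simplified stand-in classes, not ColossalAI's real code.


class BaseAccelerator:
    def __init__(self, name: str) -> None:
        self.name = name


class CpuAccelerator(BaseAccelerator):
    def __init__(self) -> None:
        super().__init__("cpu")


class NpuAccelerator(BaseAccelerator):
    def __init__(self) -> None:
        super().__init__("npu")


def _detect_accelerator() -> BaseAccelerator:
    # Mirrors the import probe in npu_accelerator.py: torch_npu is
    # only importable on machines with an Ascend NPU stack installed.
    try:
        import torch_npu  # noqa: F401

        return NpuAccelerator()
    except ImportError:
        return CpuAccelerator()


_ACCELERATOR = _detect_accelerator()


def get_accelerator() -> BaseAccelerator:
    return _ACCELERATOR


# Usage, as in the GeminiPlugin hunk: branch on the accelerator name.
if get_accelerator().name == "npu":
    print("NPU backend active: only the static placement policy is supported")
```

The design benefit is a single source of truth: with the flag removed, NPU availability is determined once by whichever accelerator object is active, so the flag and the accelerator selection can never disagree.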