mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] removed unused flag (#5242)
parent
d202cc28c0
commit
9102d655ab
|
@ -7,11 +7,8 @@ import torch.distributed as dist
|
|||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
IS_NPU_AVAILABLE = False
|
||||
try:
|
||||
import torch_npu # noqa
|
||||
|
||||
IS_NPU_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
get_model_base_filenames,
|
||||
|
@ -362,7 +362,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
) -> None:
|
||||
super().__init__()
|
||||
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
|
||||
if IS_NPU_AVAILABLE:
|
||||
if get_accelerator().name == "npu":
|
||||
assert placement_policy == "static", "NPU only supports static placement policy"
|
||||
self.gemini_config = dict(
|
||||
chunk_config_dict=chunk_config_dict,
|
||||
|
|
Loading…
Reference in New Issue