[hotfix] removed unused flag (#5242)

pull/5278/head
Frank Lee committed 11 months ago (via GitHub)
parent d202cc28c0
commit 9102d655ab

@@ -7,11 +7,8 @@ import torch.distributed as dist
 from .base_accelerator import BaseAccelerator
-IS_NPU_AVAILABLE = False
 try:
     import torch_npu  # noqa
-    IS_NPU_AVAILABLE = True
 except ImportError:
     pass
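
With the module-level IS_NPU_AVAILABLE flag removed, NPU availability is now queried from the active accelerator at the call site instead of at import time. A minimal sketch of the replacement pattern, assuming only the get_accelerator API and the .name attribute shown in this diff; the is_npu_backend helper is illustrative and not part of the codebase:

from colossalai.accelerator import get_accelerator

def is_npu_backend() -> bool:
    # Hypothetical helper: reports True when the current accelerator is an NPU,
    # mirroring the new check used in GeminiPlugin below.
    return get_accelerator().name == "npu"

if is_npu_backend():
    print("NPU backend detected")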

@@ -12,7 +12,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
-from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator
+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
@@ -362,7 +362,7 @@ class GeminiPlugin(DPPluginBase):
     ) -> None:
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
-        if IS_NPU_AVAILABLE:
+        if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
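
A hedged usage sketch of the new guard from the caller's side. Only placement_policy, precision, and the "static"-on-NPU constraint are confirmed by this diff; the import path for GeminiPlugin, the "auto" placement policy, and the "fp16" precision value are assumptions:

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin import GeminiPlugin

# On NPU the plugin asserts placement_policy == "static", so select it up front.
placement_policy = "static" if get_accelerator().name == "npu" else "auto"
plugin = GeminiPlugin(precision="fp16", placement_policy=placement_policy)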
