@@ -12,7 +12,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
-from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator
+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
@@ -362,7 +362,7 @@ class GeminiPlugin(DPPluginBase):
     ) -> None:
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
-        if IS_NPU_AVAILABLE:
+        if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
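
The change swaps the import-time IS_NPU_AVAILABLE flag for a runtime query of the active accelerator. A minimal sketch of the new check, assuming get_accelerator() returns an accelerator object whose name is "npu" on Ascend devices (the only name this diff confirms; others such as "cuda" are assumptions):

from colossalai.accelerator import get_accelerator

# Ask the runtime-selected accelerator for its name instead of relying on a
# module-level availability flag resolved at import time.
accelerator = get_accelerator()
if accelerator.name == "npu":
    # Per the assert in this diff, Gemini on NPU only supports the static
    # placement policy.
    placement_policy = "static"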