|
|
|
@ -35,7 +35,7 @@ from colossalai.tensor.padded_tensor import (
|
|
|
|
|
to_unpadded_tensor, |
|
|
|
|
) |
|
|
|
|
from colossalai.tensor.param_op_hook import ColoParamOpHookManager |
|
|
|
|
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored |
|
|
|
|
from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored |
|
|
|
|
|
|
|
|
|
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager |
|
|
|
|
from .gemini_hook import GeminiZeROHook |
|
|
|
@ -187,7 +187,7 @@ class GeminiDDP(ModelWrapper):
|
|
|
|
|
pin_memory=pin_memory, |
|
|
|
|
) |
|
|
|
|
super().__init__(module) |
|
|
|
|
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module) |
|
|
|
|
self._non_persistent_buffers_set = get_non_persistent_buffers_set(module) |
|
|
|
|
self._cast_buffers() |
|
|
|
|
|
|
|
|
|
# register grad hook |
|
|
|
@ -257,36 +257,6 @@ class GeminiDDP(ModelWrapper):
|
|
|
|
|
for p in params_to_ignore: |
|
|
|
|
p._ddp_to_ignore = True |
|
|
|
|
|
|
|
|
|
def _get_non_persistent_buffers_set( |
|
|
|
|
self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True |
|
|
|
|
): |
|
|
|
|
r""" |
|
|
|
|
Args: |
|
|
|
|
memo: a memo to store the set of modules already added to the result |
|
|
|
|
prefix: a prefix that will be added to the name of the module |
|
|
|
|
remove_duplicate: whether to remove the duplicated module instances in the result |
|
|
|
|
or not |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if memo is None: |
|
|
|
|
memo = set() |
|
|
|
|
self_non_persistent_set = set() |
|
|
|
|
if module not in memo: |
|
|
|
|
if remove_duplicate: |
|
|
|
|
memo.add(module) |
|
|
|
|
self_non_persistent_set = set( |
|
|
|
|
map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set) |
|
|
|
|
) |
|
|
|
|
for name, sub_module in module._modules.items(): |
|
|
|
|
if sub_module is None: |
|
|
|
|
continue |
|
|
|
|
submodule_prefix = prefix + ("." if prefix else "") + name |
|
|
|
|
child_non_persistent_set = self._get_non_persistent_buffers_set( |
|
|
|
|
sub_module, memo, submodule_prefix, remove_duplicate |
|
|
|
|
) |
|
|
|
|
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) |
|
|
|
|
return self_non_persistent_set |
|
|
|
|
|
|
|
|
|
def _post_forward(self): |
|
|
|
|
"""This function is only triggered for inference.""" |
|
|
|
|
access_list = list(self.chunk_manager.accessed_chunks) |
|
|
|
|