diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index 143eeae58..9dead5b4b 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -2,8 +2,10 @@ from .spec import ComputePattern, ParallelAction, TensorSpec
 from .op_wrapper import (
     colo_op_impl,)
 from .colo_tensor import ColoTensor
-from .utils import convert_parameter
+from .utils import convert_parameter, named_params_with_colotensor
 from ._ops import *
 
-__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern',
-           'TensorSpec', 'ParallelAction']
+__all__ = [
+    'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
+    'named_params_with_colotensor'
+]
diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py
index 1430e5191..5abce8ca1 100644
--- a/colossalai/tensor/utils.py
+++ b/colossalai/tensor/utils.py
@@ -2,6 +2,57 @@ import torch
 
 from colossalai.tensor.colo_tensor import ColoTensor
 
+from typing import Iterator, Tuple, Union
+import torch.nn as nn
+from colossalai.tensor import ColoTensor
+
+
+# This function is adapted from the PyTorch team's implementation.
+def named_params_with_colotensor(
+    module: nn.Module,
+    prefix: str = '',
+    recurse: bool = True,
+) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
+    r"""Returns an iterator over module parameters (together with the
+    ColoTensor parameters), yielding both the name of the parameter
+    as well as the parameter itself. This is typically passed to a
+    :class:`torchshard._shard.sharded_optim.ShardedOptimizer`.
+
+    Args:
+        prefix (str): prefix to prepend to all parameter names.
+        recurse (bool): if True, then yields parameters of this module
+            and all submodules. Otherwise, yields only parameters that
+            are direct members of this module.
+
+    Yields:
+        (str, Union[nn.Parameter, ColoTensor]): Tuple containing
+            the name and parameter (or ColoTensor parameter)
+
+    Example::
+
+        >>> model = torch.nn.Linear(*linear_size)
+        >>> delattr(model, 'weight')
+        >>> setattr(model, 'weight', ColoTensor(...))
+        >>> for name, param in named_params_with_colotensor(model):
+        >>>     if name in ['weight']:
+        >>>         print(param.size())
+
+    """
+    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
+
+    memo = set()
+    for mod_prefix, mod in modules:
+        # find all ColoTensor params attached as plain module attributes
+        for name, val in vars(mod).items():
+            if isinstance(val, ColoTensor) and val not in memo:
+                memo.add(val)
+                name = mod_prefix + ('.' if mod_prefix else '') + name
+                yield name, val
+
+    # find all nn.Parameters, respecting the prefix and recurse arguments
+    for name, val in module.named_parameters(prefix=prefix, recurse=recurse):
+        yield name, val
+
 
 def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
     return ColoTensor(tensor)
diff --git a/tests/test_tensor/test_net_tp.py b/tests/test_tensor/test_net_tp.py
index 0d5ea848d..45b9902e9 100644
--- a/tests/test_tensor/test_net_tp.py
+++ b/tests/test_tensor/test_net_tp.py
@@ -7,6 +7,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
+from colossalai.tensor import named_params_with_colotensor
 from functools import partial
 
 
@@ -19,6 +20,8 @@ def run_simple_net():
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
+    for name, param in named_params_with_colotensor(model):
+        print(name, param)
     # we set the Specs for weight of each linear.
     # model.proj1.weight.set_spec('1Drow')
     # model.proj2.weight.set_spec('1Drow')
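
For reference, a minimal usage sketch of the new helper, assuming only what the diff above exports: named_params_with_colotensor accepts any nn.Module and yields (name, parameter) tuples, covering both ColoTensor attributes and ordinary nn.Parameters. The toy Sequential model below is illustrative only and is not part of the change.

    import torch.nn as nn
    from colossalai.tensor import named_params_with_colotensor

    # A toy model for illustration; any nn.Module works. Built without
    # ColoInitContext, its weights stay ordinary nn.Parameters, so the
    # iterator simply falls through to module.named_parameters().
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

    for name, param in named_params_with_colotensor(model):
        print(name, tuple(param.size()))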