mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] get named parameters for model using ColoTensors (#874)
parent 2883040286
commit e43f83aa5c
@@ -2,8 +2,10 @@ from .spec import ComputePattern, ParallelAction, TensorSpec
 from .op_wrapper import (
     colo_op_impl,)
 from .colo_tensor import ColoTensor
-from .utils import convert_parameter
+from .utils import convert_parameter, named_params_with_colotensor
 from ._ops import *
 
-__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern',
-           'TensorSpec', 'ParallelAction']
+__all__ = [
+    'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
+    'named_params_with_colotensor'
+]
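With the re-export in place, the helper is importable from the package root; a one-line sanity check (assuming colossalai is installed):

from colossalai.tensor import named_params_with_colotensor  # now listed in __all__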
@@ -2,6 +2,57 @@ import torch
 from colossalai.tensor.colo_tensor import ColoTensor
 
+from typing import Iterator, Tuple, Union
+
+import torch.nn as nn
+
+from colossalai.tensor import ColoTensor
+
+
+# The function is credited to PyTorch Team
+def named_params_with_colotensor(
+    module: nn.Module,
+    prefix: str = '',
+    recurse: bool = True,
+) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
+    r"""Returns an iterator over module parameters (together with the
+    ColoTensor parameters), yielding both the name of the parameter
+    as well as the parameter itself. This is typically passed to a
+    :class:`torchshard._shard.sharded_optim.ShardedOptimizer`.
+
+    Args:
+        module (nn.Module): module to iterate over.
+        prefix (str): prefix to prepend to all parameter names.
+        recurse (bool): if True, then yields parameters of this module
+            and all submodules. Otherwise, yields only parameters that
+            are direct members of this module.
+
+    Yields:
+        (str, Union[nn.Parameter, ColoTensor]): Tuple containing
+            the name and parameter (or ColoTensor parameter).
+
+    Example::
+
+        >>> model = torch.nn.Linear(*linear_size)
+        >>> delattr(model, 'weight')
+        >>> setattr(model, 'weight', ColoTensor(...))
+        >>> for name, param in named_params_with_colotensor(model):
+        >>>     if name in ['weight']:
+        >>>         print(param.size())
+
+    """
+    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
+
+    memo = set()
+    for mod_prefix, mod in modules:
+        # find all ColoTensor params stored as plain attributes
+        for name, val in vars(mod).items():
+            if isinstance(val, ColoTensor) and val not in memo:
+                memo.add(val)
+                name = mod_prefix + ('.' if mod_prefix else '') + name
+                yield name, val
+
+    # find all remaining nn.Parameters, honoring prefix and recurse
+    for name, val in module.named_parameters(prefix=prefix, recurse=recurse):
+        yield name, val
+
+
 def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
     return ColoTensor(tensor)
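For orientation, a minimal sketch of how the new iterator behaves on an ordinary module. The toy Net and the manual weight swap below are illustrative only, not part of the commit; the sketch assumes ColoTensor(tensor) construction as in _convert_tensor above.

import torch
import torch.nn as nn

from colossalai.tensor import ColoTensor, named_params_with_colotensor


# hypothetical toy module for illustration
class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 2)


model = Net()

# swap one parameter for a ColoTensor, mirroring the docstring example:
# the ColoTensor then lives in the submodule's __dict__, where vars(mod) finds it
weight = model.proj1.weight.detach()
delattr(model.proj1, 'weight')
setattr(model.proj1, 'weight', ColoTensor(weight))

# yields ('proj1.weight', <ColoTensor>) first, then the remaining
# nn.Parameters such as ('proj1.bias', ...) and ('proj2.weight', ...)
for name, param in named_params_with_colotensor(model):
    print(name, type(param).__name__)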
@@ -7,6 +7,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
+from colossalai.tensor import named_params_with_colotensor
 
 from functools import partial
@@ -19,6 +20,8 @@ def run_simple_net():
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
+    for param in named_params_with_colotensor(model):
+        print(param)
     # we set the Specs for weight of each linear.
     # model.proj1.weight.set_spec('1Drow')
     # model.proj2.weight.set_spec('1Drow')
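The commented-out set_spec calls point at the intended follow-up: once every parameter is reachable by name, tensor-parallel specs can be attached selectively. A hedged sketch of that pattern, assuming ColoTensor exposes a set_spec('1Drow') call as the comments suggest:

# hypothetical follow-up based on the commented-out lines above;
# assumes ColoTensor.set_spec accepts a layout string such as '1Drow'
for name, param in named_params_with_colotensor(model):
    if name in ('proj1.weight', 'proj2.weight'):
        param.set_spec('1Drow')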