Browse Source

[Tensor] overriding paramters() for Module using ColoTensor (#889)

pull/895/head
Jiarui Fang 3 years ago committed by GitHub
parent
commit
26c49639d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      colossalai/tensor/colo_tensor.py
  2. 71
      colossalai/utils/model/colo_init_context.py
  3. 2
      tests/test_tensor/test_model.py

7
colossalai/tensor/colo_tensor.py

@ -165,7 +165,12 @@ class ColoTensor(object):
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
def __add__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
if isinstance(o, ColoTensor):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
elif isinstance(o, torch.Tensor):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
else:
raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
def __truediv__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)

71
colossalai/utils/model/colo_init_context.py

@ -1,10 +1,68 @@
from colossalai.utils.cuda import get_current_device
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
# from colossalai.logging import get_dist_logger
from colossalai.tensor import ColoTensor
import types
# _orig_torch_empty = torch.empty
from torch import nn
from typing import Iterator, Tuple, Union
def ColoModulize(module):
"""
Replacing the parameters() and named_parameters() with our customized ones
"""
def named_params_with_colotensor(
module: nn.Module,
prefix: str = '',
recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
memo = set()
for mod_prefix, mod in modules:
# find all colotensors tensor params
for name, val in vars(mod).items():
if isinstance(val, ColoTensor) and val not in memo:
memo.add(val)
name = mod_prefix + ('.' if mod_prefix else '') + name
yield name, val
# find all nn.Parameters
for name, val in module.old_named_parameters(recurse=recurse):
yield name, val
def fake_parameters(self, *args, **kargs):
for name, p in named_params_with_colotensor(self, *args, **kargs):
if isinstance(p, ColoTensor):
yield p.torch_tensor()
elif isinstance(p, torch.Tensor):
yield p
def fake_named_parameters(self, *args, **kargs):
for name, p in named_params_with_colotensor(self, *args, **kargs):
if isinstance(p, ColoTensor):
yield name, p.torch_tensor()
elif isinstance(p, torch.Tensor):
yield name, p
def colo_parameters(self, *args, **kargs):
for _, p in named_params_with_colotensor(self, *args, **kargs):
yield p
def colo_named_parameters(self, *args, **kargs):
for name, p in named_params_with_colotensor(self, *args, **kargs):
yield name, p
module.old_named_parameters = module.named_parameters
module.old_parameters = module.parameters
funcType = types.MethodType
module.parameters = funcType(fake_parameters, module)
module.named_parameters = funcType(fake_named_parameters, module)
module.colo_parameters = funcType(colo_parameters, module)
module.colo_named_parameters = funcType(colo_named_parameters, module)
module._colo_visited = True
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
@ -24,8 +82,11 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
The function to call at the end of the constructor of each module.
FIXME(fjr) The module may be passed to this function multiple times?
"""
if hasattr(module, '_colo_visited'):
return
name_list = []
for name, param in module.named_parameters():
for name, param in module.named_parameters(recurse=False):
if isinstance(param, ColoTensor):
continue
name_list.append((name, param))
@ -35,3 +96,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
delattr(module, name)
setattr(module, name,
ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
ColoModulize(module)

2
tests/test_tensor/test_model.py

@ -48,7 +48,7 @@ def run_1d_row_tp():
model_torch = model_torch.cuda()
# A naive way to set spec for all weights in Linear
for name, p in named_params_with_colotensor(model):
for name, p in model.colo_named_parameters():
if not isinstance(p, ColoTensor):
continue
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:

Loading…
Cancel
Save