[Tensor] simplify named param (#928)

* simplify ColoModulize

* simplify ColoModulize

* polish

* polish
Ziyue Jiang committed 3 years ago (via GitHub)
parent 32a45cd7ef
commit dfc88b85ea

@@ -90,56 +90,28 @@ def ColoModulize(module):
     Replacing the parameters() and named_parameters() with our customized ones
     """
-    def named_params_with_colotensor(
-        module: nn.Module,
-        prefix: str = '',
-        recurse: bool = True,
-    ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
-        modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
-        memo = set()
-        for mod_prefix, mod in modules:
-            # find all colotensors tensor params
-            for name, val in vars(mod).items():
-                if isinstance(val, ColoTensor) and val not in memo:
-                    memo.add(val)
-                    name = mod_prefix + ('.' if mod_prefix else '') + name
-                    yield name, val
-        # find all nn.Parameters
-        for name, val in module.old_named_parameters(recurse=recurse):
-            yield name, val
     def fake_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
+        for p in module.old_parameters(*args, **kargs):
             if isinstance(p, ColoTensor):
                 yield p.torch_tensor()
             elif isinstance(p, torch.Tensor):
                 yield p
     def fake_named_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
+        for name, p in module.old_named_parameters(*args, **kargs):
             if isinstance(p, ColoTensor):
                 yield name, p.torch_tensor()
             elif isinstance(p, torch.Tensor):
                 yield name, p
-    def colo_parameters(self, *args, **kargs):
-        for _, p in named_params_with_colotensor(self, *args, **kargs):
-            yield p
-    def colo_named_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
-            yield name, p
     module.old_named_parameters = module.named_parameters
     module.old_parameters = module.parameters
     funcType = types.MethodType
     module.parameters = funcType(fake_parameters, module)
     module.named_parameters = funcType(fake_named_parameters, module)
-    module.colo_parameters = funcType(colo_parameters, module)
-    module.colo_named_parameters = funcType(colo_named_parameters, module)
+    module.colo_parameters = module.old_parameters
+    module.colo_named_parameters = module.old_named_parameters
     module._colo_visited = True
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
@@ -154,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device
         # TODO(jzy) replace it with old __setattr__ in the exit() of context?
         torch.nn.Module.__setattr__ = _setattr_with_colotensor
         torch.nn.Module.register_parameter = _register_parameter_with_colotensor

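For readers skimming the hunks above: the simplified ColoModulize keeps the stock accessors under old_parameters/old_named_parameters, rebinds parameters()/named_parameters() to generators that unwrap ColoTensors, and aliases colo_parameters/colo_named_parameters to the originals. The standalone sketch below reproduces that monkey-patching pattern against plain PyTorch; it is illustrative only, and WrappedTensor and patch_parameters are hypothetical names, not ColossalAI APIs.

# Illustrative sketch, not part of the diff; assumes only torch is installed.
import types

import torch
import torch.nn as nn


class WrappedTensor(torch.Tensor):
    """Hypothetical stand-in for ColoTensor: can hand back a plain torch.Tensor."""

    def torch_tensor(self) -> torch.Tensor:
        return self.as_subclass(torch.Tensor)


def patch_parameters(module: nn.Module) -> None:
    """Mirror the pattern in the hunk above: keep the original accessors under
    old_* names, rebind parameters()/named_parameters() to unwrapping
    generators, and alias the colo_* accessors to the originals."""
    module.old_parameters = module.parameters
    module.old_named_parameters = module.named_parameters

    def fake_parameters(self, *args, **kwargs):
        for p in module.old_parameters(*args, **kwargs):
            if isinstance(p, WrappedTensor):
                yield p.torch_tensor()
            elif isinstance(p, torch.Tensor):
                yield p

    def fake_named_parameters(self, *args, **kwargs):
        for name, p in module.old_named_parameters(*args, **kwargs):
            if isinstance(p, WrappedTensor):
                yield name, p.torch_tensor()
            elif isinstance(p, torch.Tensor):
                yield name, p

    # Bind the generators as instance methods, as the diff does with
    # types.MethodType; the instance attributes shadow the nn.Module methods
    # for this module only.
    module.parameters = types.MethodType(fake_parameters, module)
    module.named_parameters = types.MethodType(fake_named_parameters, module)
    module.colo_parameters = module.old_parameters
    module.colo_named_parameters = module.old_named_parameters


if __name__ == '__main__':
    m = nn.Linear(4, 2)
    patch_parameters(m)
    assert [n for n, _ in m.named_parameters()] == ['weight', 'bias']

Because the generators land in the instance __dict__, only the patched module is affected; every other nn.Module keeps its class-level parameters()/named_parameters().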
@@ -1,3 +1,4 @@
+from colossalai.tensor.colo_parameter import ColoParameter
 from tests.components_to_test.registry import non_distributed_component_funcs
 import colossalai
@@ -371,7 +372,13 @@ def _run_pretrain_load():
     dict_col = {}
     for name, param in model_pretrained.named_parameters():
         dict_pretrained[name] = param
-    for name, param in model.named_parameters():
+    c1 = 0
+    c2 = 0
+    for name, param in model.colo_named_parameters():
+        if isinstance(param, ColoParameter):
+            c1 = c1 + 1
+        else:
+            c2 = c2 + 1
         dict_col[name] = param
     for name, param in dict_pretrained.items():
@@ -416,4 +423,5 @@ if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
     # test_model()
-    _test_pretrain_load(4)
+    # _test_pretrain_load(4)
+    _run_pretrain_load()
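A note on the test change above: colo_named_parameters() aliases the module's original named_parameters(), so it yields the registered ColoParameter objects themselves, whereas the patched named_parameters() unwraps ColoTensors via torch_tensor(). The isinstance(param, ColoParameter) counter therefore only makes sense on the colo_* accessor (assuming ColoParameter subclasses ColoTensor, as the unwrapping branch in the first hunk suggests). A minimal sketch of that counting logic, with count_param_kinds as a hypothetical helper not present in the test:

from colossalai.tensor.colo_parameter import ColoParameter


def count_param_kinds(model):
    # colo_named_parameters() yields the wrapped ColoParameter objects; the
    # patched named_parameters() would already have unwrapped them, and the
    # isinstance() check below would never match there.
    n_colo, n_plain = 0, 0
    for _, param in model.colo_named_parameters():
        if isinstance(param, ColoParameter):
            n_colo += 1
        else:
            n_plain += 1
    return n_colo, n_plain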
