mirror of https://github.com/hpcaitech/ColossalAI

[Tensor] simplify named param (#928)

* simplify ColoModulize
* simplify ColoModulize
* polish
* polish

parent 32a45cd7ef
commit dfc88b85ea
@@ -90,56 +90,28 @@ def ColoModulize(module):
     Replacing the parameters() and named_parameters() with our customized ones
     """

-    def named_params_with_colotensor(
-            module: nn.Module,
-            prefix: str = '',
-            recurse: bool = True,
-    ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
-        modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
-
-        memo = set()
-        for mod_prefix, mod in modules:
-            # find all colotensors tensor params
-            for name, val in vars(mod).items():
-                if isinstance(val, ColoTensor) and val not in memo:
-                    memo.add(val)
-                    name = mod_prefix + ('.' if mod_prefix else '') + name
-                    yield name, val
-
-        # find all nn.Parameters
-        for name, val in module.old_named_parameters(recurse=recurse):
-            yield name, val
-
     def fake_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
+        for p in module.old_parameters(*args, **kargs):
             if isinstance(p, ColoTensor):
                 yield p.torch_tensor()
             elif isinstance(p, torch.Tensor):
                 yield p

     def fake_named_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
+        for name, p in module.old_named_parameters(*args, **kargs):
             if isinstance(p, ColoTensor):
                 yield name, p.torch_tensor()
             elif isinstance(p, torch.Tensor):
                 yield name, p

-    def colo_parameters(self, *args, **kargs):
-        for _, p in named_params_with_colotensor(self, *args, **kargs):
-            yield p
-
-    def colo_named_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
-            yield name, p
-
     module.old_named_parameters = module.named_parameters
     module.old_parameters = module.parameters

     funcType = types.MethodType
     module.parameters = funcType(fake_parameters, module)
     module.named_parameters = funcType(fake_named_parameters, module)
-    module.colo_parameters = funcType(colo_parameters, module)
-    module.colo_named_parameters = funcType(colo_named_parameters, module)
+    module.colo_parameters = module.old_parameters
+    module.colo_named_parameters = module.old_named_parameters
     module._colo_visited = True

 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
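A note on the technique: the patch binds fake_parameters and fake_named_parameters to each module instance with types.MethodType, so only modules visited by ColoModulize see the overridden parameters()/named_parameters(); the class-level methods on nn.Module stay untouched. A minimal standalone sketch of that instance-level binding, with an illustrative patched_parameters and an nn.Linear standing in for the real module (not part of the patch):

    import types
    import torch.nn as nn

    def patched_parameters(self, *args, **kwargs):
        # Delegate to the saved original bound method; the real
        # fake_parameters additionally unwraps ColoTensor here.
        yield from self.old_parameters(*args, **kwargs)

    m = nn.Linear(2, 2)
    m.old_parameters = m.parameters                         # keep the original bound method
    m.parameters = types.MethodType(patched_parameters, m)  # instance-level override
    assert all(a is b for a, b in zip(m.parameters(), m.old_parameters()))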
@@ -154,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device

-        # TODO(jzy) replace it with old __setattr__ in the exit() of context?
         torch.nn.Module.__setattr__ = _setattr_with_colotensor
         torch.nn.Module.register_parameter = _register_parameter_with_colotensor

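The deleted TODO(jzy) flags a caveat worth spelling out: the context patches torch.nn.Module.__setattr__ and register_parameter for every module class and, as written, never restores them once the context exits. A hedged sketch of the save-and-restore shape the TODO asks for; PatchSetattrContext and _tracing_setattr are hypothetical stand-ins, not the actual InsertPostInitMethodToModuleSubClasses hooks:

    import torch.nn as nn

    _old_setattr = None

    def _tracing_setattr(module, name, value):
        # Stand-in for _setattr_with_colotensor: inspect the assignment,
        # then delegate to the saved class-wide __setattr__.
        _old_setattr(module, name, value)

    class PatchSetattrContext:
        def __enter__(self):
            global _old_setattr
            _old_setattr = nn.Module.__setattr__      # save before patching
            nn.Module.__setattr__ = _tracing_setattr
            return self

        def __exit__(self, *exc):
            nn.Module.__setattr__ = _old_setattr      # restore on exit
            return False

    with PatchSetattrContext():
        layer = nn.Linear(4, 4)   # weight/bias assignments route through the patch
    # after the with-block the original __setattr__ is back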
@@ -1,3 +1,4 @@
+from colossalai.tensor.colo_parameter import ColoParameter
 from tests.components_to_test.registry import non_distributed_component_funcs

 import colossalai
@@ -371,7 +372,13 @@ def _run_pretrain_load():
     dict_col = {}
     for name, param in model_pretrained.named_parameters():
         dict_pretrained[name] = param
-    for name, param in model.named_parameters():
+    c1 = 0
+    c2 = 0
+    for name, param in model.colo_named_parameters():
+        if isinstance(param, ColoParameter):
+            c1 = c1 + 1
+        else:
+            c2 = c2 + 1
         dict_col[name] = param

     for name, param in dict_pretrained.items():
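The reworked loop in _run_pretrain_load walks colo_named_parameters() (the original, unpatched iterator that ColoModulize now aliases to old_named_parameters) and tallies how many parameters are ColoParameter versus plain ones. The same check, packaged as a hypothetical helper for a model built under ColoInitContext (illustrative, not part of the patch):

    from colossalai.tensor.colo_parameter import ColoParameter

    def count_param_kinds(model):
        # Mirrors the c1/c2 bookkeeping in the hunk above.
        n_colo = sum(1 for _, p in model.colo_named_parameters()
                     if isinstance(p, ColoParameter))
        n_total = sum(1 for _ in model.colo_named_parameters())
        return n_colo, n_total - n_colo   # (c1, c2)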
@@ -416,4 +423,5 @@ if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
     # test_model()
-    _test_pretrain_load(4)
+    # _test_pretrain_load(4)
+    _run_pretrain_load()