[DO NOT MERGE] [zero] init fp16 params directly in ZeroInitContext (#808)

* init fp16 param directly

* polish code
pull/810/head
ver217 2022-04-19 16:16:48 +08:00 committed by GitHub
parent 227d1cd4b3
commit dd92b90a68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 4 deletions

View File

@ -23,14 +23,17 @@ def _substitute_init_recursively(cls, func):
class InsertPostInitMethodToModuleSubClasses(object):
def __init__(self):
pass
def __init__(self, default_dtype: Optional[torch.dtype] = None):
self._old_default_dtype = None
self._default_dtype = default_dtype
def __enter__(self):
r"""
Enter the context scope.
"""
if self._default_dtype is not None:
self._old_default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._default_dtype)
def preprocess_after(f):
@functools.wraps(f)
@ -61,6 +64,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
def __exit__(self, exc_type, exc_value, traceback):
if self._default_dtype is not None:
torch.set_default_dtype(self._old_default_dtype)
def _disable_class(cls):
cls.__init__ = cls._old_init
@ -123,6 +128,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
shard_strategy (BaseShardStrategy): Shard strategy instance.
seed (int, optional): Random seed for weight initialization
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
"""
@ -131,9 +137,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
shard_strategy: BaseShardStrategy,
seed: int = 2**10 - 1,
shard_param: bool = False,
default_dtype: Optional[torch.dtype] = None,
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
super().__init__()
super().__init__(default_dtype=default_dtype)
self.shard_strategy = shard_strategy
self.param_list = []
self.model_numel_tensor = model_numel_tensor