[DO NOT MERGE] [zero] init fp16 params directly in ZeroInitContext (#808)

* init fp16 param directly

* polish code
pull/810/head
ver217 2022-04-19 16:16:48 +08:00 committed by GitHub
parent 227d1cd4b3
commit dd92b90a68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 11 additions and 4 deletions

View File

@ -23,14 +23,17 @@ def _substitute_init_recursively(cls, func):
class InsertPostInitMethodToModuleSubClasses(object):

    def __init__(self, default_dtype: Optional[torch.dtype] = None):
        self._old_default_dtype = None
        self._default_dtype = default_dtype
    def __enter__(self):
        r"""
        Enter the context scope.
        """
        if self._default_dtype is not None:
            self._old_default_dtype = torch.get_default_dtype()
            torch.set_default_dtype(self._default_dtype)

        def preprocess_after(f):

            @functools.wraps(f)
@ -61,6 +64,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
    def __exit__(self, exc_type, exc_value, traceback):
        if self._default_dtype is not None:
            torch.set_default_dtype(self._old_default_dtype)

        def _disable_class(cls):
            cls.__init__ = cls._old_init
@ -123,6 +128,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        shard_strategy (BaseShardStrategy): Shard strategy instance.
        seed (int, optional): Random seed for weight initialization
        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
        default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
    """
@ -131,9 +137,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 shard_strategy: BaseShardStrategy,
                 seed: int = 2**10 - 1,
                 shard_param: bool = False,
                 default_dtype: Optional[torch.dtype] = None,
                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
        super().__init__(default_dtype=default_dtype)
        self.shard_strategy = shard_strategy
        self.param_list = []
        self.model_numel_tensor = model_numel_tensor