mirror of https://github.com/hpcaitech/ColossalAI
[DO NOT MERGE] [zero] init fp16 params directly in ZeroInitContext (#808)
* init fp16 param directly
* polish code

pull/810/head
parent 227d1cd4b3
commit dd92b90a68
@@ -23,14 +23,17 @@ def _substitute_init_recursively(cls, func):
 class InsertPostInitMethodToModuleSubClasses(object):
 
-    def __init__(self):
-        pass
+    def __init__(self, default_dtype: Optional[torch.dtype] = None):
+        self._old_default_dtype = None
+        self._default_dtype = default_dtype
 
     def __enter__(self):
         r"""
         Enter the context scope.
         """
-
+        if self._default_dtype is not None:
+            self._old_default_dtype = torch.get_default_dtype()
+            torch.set_default_dtype(self._default_dtype)
 
         def preprocess_after(f):
 
             @functools.wraps(f)
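The __enter__ hook relies on PyTorch's process-wide default floating-point dtype, which determines the dtype of newly created floating-point tensors. A quick illustration of that behavior (standard torch API, not part of this diff):

import torch

torch.set_default_dtype(torch.float64)
assert torch.empty(2).dtype == torch.float64        # factory calls follow the default
assert torch.tensor([1.0]).dtype == torch.float64   # so do float literals
torch.set_default_dtype(torch.float32)              # reset for the rest of the process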
@@ -61,6 +64,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
     def __exit__(self, exc_type, exc_value, traceback):
+        if self._default_dtype is not None:
+            torch.set_default_dtype(self._old_default_dtype)
 
         def _disable_class(cls):
             cls.__init__ = cls._old_init
 
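Taken together, the __init__, __enter__, and __exit__ changes above form a context manager that temporarily overrides the global default dtype and restores the saved value on exit, including when the body raises. A minimal, self-contained sketch of the same pattern; the class name here is illustrative, not the library's:

import torch


class DefaultDtypeContext:
    # Illustrative stand-in for the dtype handling added above.

    def __init__(self, default_dtype=None):
        self._old_default_dtype = None
        self._default_dtype = default_dtype

    def __enter__(self):
        # Save the current global default and install the requested one.
        if self._default_dtype is not None:
            self._old_default_dtype = torch.get_default_dtype()
            torch.set_default_dtype(self._default_dtype)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the saved default; __exit__ also runs on exceptions.
        if self._default_dtype is not None:
            torch.set_default_dtype(self._old_default_dtype)


with DefaultDtypeContext(torch.float64):
    w = torch.empty(4, 4)  # no explicit dtype: created as float64
assert w.dtype == torch.float64
assert torch.get_default_dtype() == torch.float32  # restored after exit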
@@ -123,6 +128,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         shard_strategy (BaseShardStrategy): Shard strategy instance.
         seed (int, optional): Random seed for weight initialization
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
         model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
     """
@@ -131,9 +137,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  shard_strategy: BaseShardStrategy,
                  seed: int = 2**10 - 1,
                  shard_param: bool = False,
+                 default_dtype: Optional[torch.dtype] = None,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
 
-        super().__init__()
+        super().__init__(default_dtype=default_dtype)
         self.shard_strategy = shard_strategy
         self.param_list = []
         self.model_numel_tensor = model_numel_tensor
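For context, a hedged usage sketch of the new parameter. The import paths, TensorShardStrategy, and the target_device argument are assumptions drawn from the surrounding codebase, not shown in this diff:

import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Assumed usage: build the model inside the context so its parameters are
# initialized in fp32 (default_dtype) and then converted to fp16, with
# sharding applied on exit because shard_param=True.
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     default_dtype=torch.float32):
    model = torch.nn.Linear(1024, 1024)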