[gemini] state dict supports fp16 (#3590)

* [gemini] save state dict support fp16

* [gemini] save state dict shard support fp16

* [gemini] fix state dict

* [gemini] fix state dict
Hongxin Liu 2023-04-19 11:01:48 +08:00 committed by GitHub
parent d544ed4345
commit 12eff9eb4c
1 changed file with 20 additions and 9 deletions


@@ -202,7 +202,12 @@ class ZeroDDP(ColoDDP):
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+    def state_dict(self,
+                   destination=None,
+                   prefix='',
+                   keep_vars=False,
+                   only_rank_0: bool = True,
+                   dtype: torch.dtype = torch.float16):
         """Returns a dictionary containing a whole state of the module.
 
         Both parameters and persistent buffers (e.g. running averages) are included.
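
The new `dtype` keyword (default `torch.float16`) controls the precision of the tensors placed in the returned state dict. A minimal usage sketch, assuming `model` is a module already wrapped in ZeroDDP (the variable name and output path are illustrative):

    import torch

    # Save a half-precision checkpoint (the new default) or keep full precision.
    fp16_state = model.state_dict(dtype=torch.float16)
    fp32_state = model.state_dict(dtype=torch.float32)
    torch.save(fp16_state, 'model_fp16.pth')
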
@@ -221,7 +226,7 @@ class ZeroDDP(ColoDDP):
             destination = OrderedDict()
             destination._metadata = OrderedDict()
         destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
-        self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
+        self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype)
 
         for hook in self._state_dict_hooks.values():
             hook_result = hook(self, destination, prefix, local_metadata)
@@ -229,7 +234,7 @@ class ZeroDDP(ColoDDP):
                 destination = hook_result
         return destination
 
-    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict:
         """
         get gathered chunk content.
@@ -243,6 +248,8 @@ class ZeroDDP(ColoDDP):
         # save parameters
         chunk_to_save_data = dict()
         temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+        if torch.is_floating_point(temp_chunk):
+            temp_chunk = temp_chunk.to(dtype)
         for tensor, tensor_info in chunk.tensors_info.items():
             record_tensor = torch.empty([0])
             record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
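
The `torch.is_floating_point` guard means only floating-point chunk data is cast to the requested dtype; non-float tensors pass through unchanged. A small standalone sketch of the same pattern, using plain tensors rather than real Chunk objects:

    import torch

    def cast_if_float(t: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
        # Only floating-point data is converted; integer tensors keep their dtype.
        return t.to(dtype) if torch.is_floating_point(t) else t

    print(cast_if_float(torch.randn(4)).dtype)   # torch.float16
    print(cast_if_float(torch.arange(4)).dtype)  # torch.int64, unchanged
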
@@ -255,7 +262,8 @@ class ZeroDDP(ColoDDP):
         del temp_chunk
         return chunk_to_save_data
 
-    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool,
+                                dtype: torch.dtype) -> Dict:
         """
         get param content from chunks.
@@ -270,10 +278,10 @@ class ZeroDDP(ColoDDP):
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
         return param_to_save_data
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
+    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every
         submodule in :meth:`~torch.nn.Module.state_dict`.
@@ -289,7 +297,8 @@ class ZeroDDP(ColoDDP):
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
 
         # get copies of fp32 parameters in CPU
-        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
+        # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
+        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype)
         # get the mapping between copies and fp16 parameters
         p_mapping = dict()
         for p, fp32_p in zip(self.fp16_params, self.fp32_params):
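
The added comment captures the correctness argument: the fp16 working parameters live in memory that may later be reused for gradients, so the state dict is always built from the fp32 master copies and only then cast down. A schematic illustration of that ordering, using plain tensors rather than ColossalAI chunks:

    import torch

    fp32_master = torch.randn(8, dtype=torch.float32)   # reliable source of truth
    fp16_working = fp32_master.half()                    # working buffer that may be reused

    fp16_working.zero_()                                 # simulate the buffer being reused by grads

    # Saving therefore reads the fp32 master copy and casts, never the fp16 buffer.
    saved_param = fp32_master.to(torch.float16)
    print(saved_param)     # original values survive
    print(fp16_working)    # zeros: the reused buffer is no longer trustworthy
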
@@ -574,7 +583,8 @@ class ZeroDDP(ColoDDP):
                          prefix: str = '',
                          keep_vars: bool = False,
                          max_shard_size: int = 1024,
-                         only_rank_0: bool = True) -> Iterator[OrderedDict]:
+                         only_rank_0: bool = True,
+                         dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
 
         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -607,10 +617,11 @@ class ZeroDDP(ColoDDP):
                     # deal with ddp ignored parameters
                     gathered_param = param if keep_vars else param.detach()
                 else:
+                    # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
                     fp32_param = fp16_to_fp32[param]
                     if fp32_param not in gathered_param_buffer:
                         chunk = self.chunk_manager.get_chunk(fp32_param)
-                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                     gathered_param = gathered_param_buffer.pop(fp32_param)
 
                 block = sharder.append(prefix + name, gathered_param)
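
`state_dict_shard` gains the same knob, so sharded checkpoints can also be written in half precision. A hedged usage sketch, assuming `model` is ZeroDDP-wrapped and that each yielded item is one `OrderedDict` shard as the signature above suggests (the output paths are illustrative):

    import torch

    # Iterate over fp16 shards (max_shard_size is passed through as in the
    # signature above; its unit is an assumption here) and persist each one.
    for idx, shard in enumerate(model.state_dict_shard(max_shard_size=1024,
                                                       only_rank_0=True,
                                                       dtype=torch.float16)):
        torch.save(shard, f'model_shard_{idx}.pth')
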