From 12eff9eb4cb74e0c20525aa30dc64375832afe45 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Wed, 19 Apr 2023 11:01:48 +0800
Subject: [PATCH] [gemini] state dict supports fp16 (#3590)

* [gemini] save state dict support fp16

* [gemini] save state dict shard support fp16

* [gemini] fix state dict

* [gemini] fix state dict
---
 colossalai/zero/gemini/gemini_ddp.py | 29 +++++++++++++++++++---------
 1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 9a193310b..e151f1aef 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -202,7 +202,12 @@ class ZeroDDP(ColoDDP):
             for tensor in chunk.get_tensors():
                 self.grads_device[tensor] = device
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+    def state_dict(self,
+                   destination=None,
+                   prefix='',
+                   keep_vars=False,
+                   only_rank_0: bool = True,
+                   dtype: torch.dtype = torch.float16):
         """Returns a dictionary containing a whole state of the module.
 
         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -221,7 +226,7 @@ class ZeroDDP(ColoDDP):
             destination = OrderedDict()
             destination._metadata = OrderedDict()
         destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
-        self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
+        self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype)
 
         for hook in self._state_dict_hooks.values():
             hook_result = hook(self, destination, prefix, local_metadata)
@@ -229,7 +234,7 @@ class ZeroDDP(ColoDDP):
                 destination = hook_result
         return destination
 
-    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict:
         """
         get gathered chunk content.
 
@@ -243,6 +248,8 @@ class ZeroDDP(ColoDDP):
         # save parameters
         chunk_to_save_data = dict()
         temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+        if torch.is_floating_point(temp_chunk):
+            temp_chunk = temp_chunk.to(dtype)
         for tensor, tensor_info in chunk.tensors_info.items():
             record_tensor = torch.empty([0])
             record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
@@ -255,7 +262,8 @@ class ZeroDDP(ColoDDP):
         del temp_chunk
         return chunk_to_save_data
 
-    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool,
+                                dtype: torch.dtype) -> Dict:
         """
         get param content from chunks.
 
@@ -270,10 +278,10 @@ class ZeroDDP(ColoDDP):
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
         return param_to_save_data
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
+    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every
         submodule in :meth:`~torch.nn.Module.state_dict`.
@@ -289,7 +297,8 @@ class ZeroDDP(ColoDDP):
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
 
         # get copies of fp32 parameters in CPU
-        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
+        # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
+        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype)
         # get the mapping between copies and fp16 parameters
         p_mapping = dict()
         for p, fp32_p in zip(self.fp16_params, self.fp32_params):
@@ -574,7 +583,8 @@ class ZeroDDP(ColoDDP):
                          prefix: str = '',
                          keep_vars: bool = False,
                          max_shard_size: int = 1024,
-                         only_rank_0: bool = True) -> Iterator[OrderedDict]:
+                         only_rank_0: bool = True,
+                         dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
 
         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -607,10 +617,11 @@ class ZeroDDP(ColoDDP):
                 # deal with ddp ignored parameters
                 gathered_param = param if keep_vars else param.detach()
             else:
+                # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
                 fp32_param = fp16_to_fp32[param]
                 if fp32_param not in gathered_param_buffer:
                     chunk = self.chunk_manager.get_chunk(fp32_param)
-                    gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                    gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                 gathered_param = gathered_param_buffer.pop(fp32_param)
 
             block = sharder.append(prefix + name, gathered_param)
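
Usage note (illustrative, not part of the commit): a minimal sketch of how the new `dtype` keyword introduced by this patch might be used when checkpointing a ZeroDDP-wrapped model. The helper name `save_gemini_checkpoint` and the surrounding setup are assumptions for illustration; only `state_dict(..., dtype=...)` comes from the patch itself.

    # Illustrative sketch only -- `model` is assumed to be a ZeroDDP instance
    # exposing the `dtype` keyword added by this patch.
    import torch
    import torch.distributed as dist

    def save_gemini_checkpoint(model, path: str, dtype: torch.dtype = torch.float16) -> None:
        # Gather the full state dict, casting floating-point tensors to the
        # requested dtype (fp16 by default); non-float tensors are untouched.
        state_dict = model.state_dict(only_rank_0=True, dtype=dtype)
        # With only_rank_0=True only rank 0 holds the gathered tensors,
        # so only rank 0 writes the checkpoint file.
        if dist.get_rank() == 0:
            torch.save(state_dict, path)

Sharded saving follows the same pattern: `model.state_dict_shard(max_shard_size=1024, dtype=torch.float16)` yields shards whose floating-point chunks are cast to the given dtype, while the `torch.is_floating_point` guard leaves non-floating-point chunks unchanged.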