From af32022f740c96d469e0970f54894e95cfefafde Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 3 Jan 2023 15:55:35 +0800
Subject: [PATCH] [Gemini] fix the convert_to_torch_module bug (#2269)

---
 colossalai/gemini/gemini_mgr.py         |  2 +-
 colossalai/gemini/placement_policy.py   |  2 +-
 colossalai/nn/parallel/data_parallel.py | 60 ++++++++++++++++++-------
 colossalai/nn/parallel/utils.py         | 13 +++---
 4 files changed, 52 insertions(+), 25 deletions(-)

diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py
index 541762a72..08961b958 100644
--- a/colossalai/gemini/gemini_mgr.py
+++ b/colossalai/gemini/gemini_mgr.py
@@ -30,7 +30,7 @@ class GeminiManager:
 
     def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
 
-        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
+        assert placement_policy in PlacementPolicyFactory.get_policy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py
index 50004ec35..fed1cc298 100644
--- a/colossalai/gemini/placement_policy.py
+++ b/colossalai/gemini/placement_policy.py
@@ -236,7 +236,7 @@ class PlacementPolicyFactory:
         return PlacementPolicyFactory.policies[policy_name]
 
     @staticmethod
-    def get_polocy_names():
+    def get_policy_names():
         return tuple(PlacementPolicyFactory.policies.keys())
 
     @staticmethod
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 8bd91050f..cbef6f532 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -360,6 +360,48 @@ class ZeroDDP(ColoDDP):
                 destination = hook_result
         return destination
 
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
+        """
+        Get the parameter payloads from the chunks.
+
+        Args:
+            param_list (List[torch.nn.Parameter]): the parameters whose payloads should be collected
+            only_rank_0 (bool): if True, only rank 0 of each chunk's process group records the payload
+
+        Returns:
+            Dict: a dict mapping each parameter to a CPU tensor holding its payload
+        """
+        # save parameters
+        param_to_save_data = dict()
+        chunk_list = self.chunk_manager.get_chunks(param_list)
+        for chunk in chunk_list:
+            temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+
+            for tensor, tensor_info in chunk.tensors_info.items():
+                record_tensor = torch.empty([0])
+                record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
+                if record_flag:
+                    record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
+
+                assert tensor not in param_to_save_data
+                param_to_save_data[tensor] = record_tensor
+
+            del temp_chunk
+        return param_to_save_data
+
+    def torch_named_parameters(self):
+        """
+        Get the named parameters of self.module with their real param.data payloads.
+        It works the same as torch.nn.Module.named_parameters().
+        """
+        params_list = [p for p in self.parameters(recurse=True)]
+        param_to_save_data = self._get_param_to_save_data(params_list, False)
+        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
+            if p is not None:
+                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+                record_parameter = param_to_save_data[p]
+                yield name, record_parameter
+
     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every
@@ -375,23 +417,7 @@ class ZeroDDP(ColoDDP):
         """
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
 
-        # save parameters
-        param_to_save_data = dict()
-        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
-        for chunk in chunk_list:
-            temp_chunk = get_temp_total_chunk_on_cuda(chunk)
-
-            for tensor, tensor_info in chunk.tensors_info.items():
-                record_tensor = torch.empty([0])
-                record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
-                if record_flag:
-                    record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
-
-                assert tensor not in param_to_save_data
-                param_to_save_data[tensor] = record_tensor
-
-            del temp_chunk
-
+        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py
index 844439cde..e514146ce 100644
--- a/colossalai/nn/parallel/utils.py
+++ b/colossalai/nn/parallel/utils.py
@@ -2,7 +2,6 @@ import torch
 import torch.distributed as dist
 
 from colossalai.gemini.chunk import Chunk
-from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
 
 
@@ -22,6 +21,7 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     return total_temp
 
 
+# TODO: this does not work for modules where two params share the same tensor.
 def _add_param(model, name, param):
     name_list = name.split('.')
     module = model._modules[name_list[0]]
@@ -30,7 +30,7 @@ def _add_param(model, name, param):
     module._parameters[name_list[-1]] = param
 
 
-def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
+def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
     """convert_to_torch_module
 
     Args:
         gemini_ddp_model (GeminiDDP): a gemini ddp model
 
     Returns:
         torch.nn.Module: a torch model contains the params of gemini_ddp_model
     """
+    from colossalai.nn.parallel import GeminiDDP
+    assert isinstance(gemini_ddp_model, GeminiDDP)
     module = gemini_ddp_model.module
-    for n, p in module.named_parameters():
-        if isinstance(p, ColoTensor):
-            p.to_replicate_()
-            _add_param(module, n, p.data)
+    # replace each ColoTensor in the module with a plain torch.Tensor payload
+    for n, p in gemini_ddp_model.torch_named_parameters():
+        _add_param(module, n, p)
     return module
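
Usage sketch (illustrative only, not part of the diff above): a minimal example of how the fixed
convert_to_torch_module is expected to be called, assuming a model that has already been wrapped
in GeminiDDP elsewhere; the export_state_dict helper and the checkpoint path are hypothetical
names, not APIs introduced by this change.

    import torch

    from colossalai.nn.parallel.utils import convert_to_torch_module

    def export_state_dict(gemini_model, path="model.pt"):
        # `gemini_model` is assumed to be a GeminiDDP instance. After this patch,
        # convert_to_torch_module() pulls the parameter payloads via
        # GeminiDDP.torch_named_parameters(), so the returned torch.nn.Module
        # carries the real gathered parameter data instead of Gemini's
        # chunk-managed placeholders.
        torch_module = convert_to_torch_module(gemini_model)
        torch.save(torch_module.state_dict(), path)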