[Gemini] fix the convert_to_torch_module bug (#2269)

pull/2284/head
Jiarui Fang 2 years ago committed by GitHub
parent 879df8b943
commit af32022f74

@@ -30,7 +30,7 @@ class GeminiManager:
     def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
-        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
+        assert placement_policy in PlacementPolicyFactory.get_policy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager

@ -236,7 +236,7 @@ class PlacementPolicyFactory:
return PlacementPolicyFactory.policies[policy_name] return PlacementPolicyFactory.policies[policy_name]
@staticmethod @staticmethod
def get_polocy_names(): def get_policy_names():
return tuple(PlacementPolicyFactory.policies.keys()) return tuple(PlacementPolicyFactory.policies.keys())
@staticmethod @staticmethod
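The hunks above fix the misspelled get_polocy_names helper and its call site in GeminiManager. A minimal sketch of the corrected behaviour, assuming the import path below matches the repository layout and with the policy names given only as examples:

from colossalai.gemini.placement_policy import PlacementPolicyFactory  # import path assumed

# the renamed helper returns the registered placement policy names
# (e.g. 'cpu', 'cuda', 'auto' in this version of the factory)
print(PlacementPolicyFactory.get_policy_names())

# GeminiManager.__init__ validates its placement_policy argument against this tuple,
# so a mistyped policy name now fails fast with an AssertionError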

@@ -360,24 +360,20 @@ class ZeroDDP(ColoDDP):
             destination = hook_result
         return destination

-    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
-        r"""Saves module state to `destination` dictionary, containing a state
-        of the module, but not its descendants. This is called on every
-        submodule in :meth:`~torch.nn.Module.state_dict`.
-
-        In rare cases, subclasses can achieve class-specific behavior by
-        overriding this method with custom logic.
-
-        Args:
-            destination (dict): a dict where state will be stored
-            prefix (str): the prefix for parameters and buffers used in this
-                module
-        """
-        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
-
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
+        """Get the param content from chunks.
+
+        Args:
+            param_list (List[torch.nn.Parameter]): a list of torch.nn.Parameters
+            only_rank_0 (bool): whether the data should only be gathered on rank 0
+
+        Returns:
+            Dict: a dict whose key is a parameter and whose value is a tensor with the correct payload
+        """
         # save parameters
         param_to_save_data = dict()
-        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
+        chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
             temp_chunk = get_temp_total_chunk_on_cuda(chunk)
@@ -391,7 +387,37 @@ class ZeroDDP(ColoDDP):
                 param_to_save_data[tensor] = record_tensor
             del temp_chunk
+        return param_to_save_data
+
+    def torch_named_parameters(self):
+        """Iterate over the named parameters of self.module, like torch.nn.Module.named_parameters(),
+        but yield tensors that carry the real param.data payload gathered from the chunks.
+        """
+        params_list = [p for p in self.parameters(recurse=True)]
+        param_to_save_data = self._get_param_to_save_data(params_list, False)
+        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
+            if p is not None:
+                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+                record_parameter = param_to_save_data[p]
+                yield name, record_parameter
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+
+        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
+
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
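The new torch_named_parameters generator exposes the Gemini-managed parameters with their real payloads. A hedged usage sketch, assuming zero_ddp_model is an already initialized ZeroDDP (or GeminiDDP) instance:

# iterate (name, tensor) pairs gathered from the chunks, mirroring
# torch.nn.Module.named_parameters() but with the full param.data payload
for name, tensor in zero_ddp_model.torch_named_parameters():
    print(name, tuple(tensor.shape), tensor.dtype)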

@@ -2,7 +2,6 @@ import torch
 import torch.distributed as dist

 from colossalai.gemini.chunk import Chunk
-from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
@@ -22,6 +21,7 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     return total_temp

+# TODO: this does not work for modules where two params share the same tensor.
 def _add_param(model, name, param):
     name_list = name.split('.')
     module = model._modules[name_list[0]]
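_add_param (now annotated with the shared-tensor TODO) rewrites a leaf entry of _parameters by walking the module tree along the dotted parameter name. A standalone sketch of the same traversal on a toy PyTorch model, with all names below chosen for illustration:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
new_weight = torch.nn.Parameter(torch.zeros(4, 4))

# same dotted-name traversal that _add_param performs, for the name '0.weight'
name_list = '0.weight'.split('.')
module = model._modules[name_list[0]]
for name in name_list[1:-1]:
    module = module._modules[name]
module._parameters[name_list[-1]] = new_weight

assert model[0].weight is new_weight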
@@ -30,7 +30,7 @@ def _add_param(model, name, param):
     module._parameters[name_list[-1]] = param

-def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
+def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
     """convert_to_torch_module

     Args:
@@ -39,11 +39,12 @@ def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
     Returns:
         torch.nn.Module: a torch model that contains the params of gemini_ddp_model
     """
+    from colossalai.nn.parallel import GeminiDDP
+    assert isinstance(gemini_ddp_model, GeminiDDP)
+
     module = gemini_ddp_model.module
-    for n, p in module.named_parameters():
-        if isinstance(p, ColoTensor):
-            p.to_replicate_()
-            _add_param(module, n, p.data)
+
+    # replace each ColoTensor in module with a plain torch.Tensor payload
+    for n, p in gemini_ddp_model.torch_named_parameters():
+        _add_param(module, n, p)

     return module
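With this change, convert_to_torch_module pulls plain tensors through torch_named_parameters instead of replicating ColoTensors itself. A hedged usage sketch, assuming gemini_model is a trained GeminiDDP instance and that the function lives in colossalai.nn.parallel.utils (path assumed); the checkpoint filename is illustrative:

import torch
from colossalai.nn.parallel.utils import convert_to_torch_module  # import path assumed

# unwrap into a plain torch.nn.Module whose parameters hold ordinary torch.Tensor payloads
torch_model = convert_to_torch_module(gemini_model)

# the unwrapped model can then be handled by vanilla PyTorch tooling
torch.save(torch_model.state_dict(), 'model_ckpt.pth')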
