mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] fix the convert_to_torch_module bug (#2269)
parent
879df8b943
commit
af32022f74
|
@ -30,7 +30,7 @@ class GeminiManager:
|
||||||
|
|
||||||
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||||
|
|
||||||
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
assert placement_policy in PlacementPolicyFactory.get_policy_names()
|
||||||
self.policy_name = placement_policy
|
self.policy_name = placement_policy
|
||||||
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
||||||
self._chunk_manager = chunk_manager
|
self._chunk_manager = chunk_manager
|
||||||
|
|
|
@ -236,7 +236,7 @@ class PlacementPolicyFactory:
|
||||||
return PlacementPolicyFactory.policies[policy_name]
|
return PlacementPolicyFactory.policies[policy_name]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_polocy_names():
|
def get_policy_names():
|
||||||
return tuple(PlacementPolicyFactory.policies.keys())
|
return tuple(PlacementPolicyFactory.policies.keys())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -360,6 +360,48 @@ class ZeroDDP(ColoDDP):
|
||||||
destination = hook_result
|
destination = hook_result
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
|
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
|
||||||
|
"""
|
||||||
|
get param content from chunks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_list (_type_): a list of torch.nn.Parameters
|
||||||
|
only_rank_0 (_type_): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: a dict whose key is param name and value is param with correct payload
|
||||||
|
"""
|
||||||
|
# save parameters
|
||||||
|
param_to_save_data = dict()
|
||||||
|
chunk_list = self.chunk_manager.get_chunks(param_list)
|
||||||
|
for chunk in chunk_list:
|
||||||
|
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||||
|
|
||||||
|
for tensor, tensor_info in chunk.tensors_info.items():
|
||||||
|
record_tensor = torch.empty([0])
|
||||||
|
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||||
|
if record_flag:
|
||||||
|
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||||
|
|
||||||
|
assert tensor not in param_to_save_data
|
||||||
|
param_to_save_data[tensor] = record_tensor
|
||||||
|
|
||||||
|
del temp_chunk
|
||||||
|
return param_to_save_data
|
||||||
|
|
||||||
|
def torch_named_parameters(self):
|
||||||
|
"""
|
||||||
|
get named_parameters() of self.module. It is used the same of PyTorch param and returns the real param.data payload.
|
||||||
|
It works the same as torch.Module named_parameters
|
||||||
|
"""
|
||||||
|
params_list = [p for p in self.parameters(recurse=True)]
|
||||||
|
param_to_save_data = self._get_param_to_save_data(params_list, False)
|
||||||
|
for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
|
||||||
|
if p is not None:
|
||||||
|
assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||||
|
record_parameter = param_to_save_data[p]
|
||||||
|
yield name, record_parameter
|
||||||
|
|
||||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||||
r"""Saves module state to `destination` dictionary, containing a state
|
r"""Saves module state to `destination` dictionary, containing a state
|
||||||
of the module, but not its descendants. This is called on every
|
of the module, but not its descendants. This is called on every
|
||||||
|
@ -375,23 +417,7 @@ class ZeroDDP(ColoDDP):
|
||||||
"""
|
"""
|
||||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||||
|
|
||||||
# save parameters
|
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
||||||
param_to_save_data = dict()
|
|
||||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
|
||||||
for chunk in chunk_list:
|
|
||||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
|
||||||
|
|
||||||
for tensor, tensor_info in chunk.tensors_info.items():
|
|
||||||
record_tensor = torch.empty([0])
|
|
||||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
|
||||||
if record_flag:
|
|
||||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
|
||||||
|
|
||||||
assert tensor not in param_to_save_data
|
|
||||||
param_to_save_data[tensor] = record_tensor
|
|
||||||
|
|
||||||
del temp_chunk
|
|
||||||
|
|
||||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||||
if p is not None:
|
if p is not None:
|
||||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||||
|
|
|
@ -2,7 +2,6 @@ import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from colossalai.gemini.chunk import Chunk
|
from colossalai.gemini.chunk import Chunk
|
||||||
from colossalai.tensor import ColoTensor
|
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,6 +21,7 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
||||||
return total_temp
|
return total_temp
|
||||||
|
|
||||||
|
|
||||||
|
# TODO() not work for module where two params share the same tensor.
|
||||||
def _add_param(model, name, param):
|
def _add_param(model, name, param):
|
||||||
name_list = name.split('.')
|
name_list = name.split('.')
|
||||||
module = model._modules[name_list[0]]
|
module = model._modules[name_list[0]]
|
||||||
|
@ -30,7 +30,7 @@ def _add_param(model, name, param):
|
||||||
module._parameters[name_list[-1]] = param
|
module._parameters[name_list[-1]] = param
|
||||||
|
|
||||||
|
|
||||||
def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
|
def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
|
||||||
"""convert_to_torch_module
|
"""convert_to_torch_module
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -39,11 +39,12 @@ def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
|
||||||
Returns:
|
Returns:
|
||||||
torch.nn.Module: a torch model contains the params of gemini_ddp_model
|
torch.nn.Module: a torch model contains the params of gemini_ddp_model
|
||||||
"""
|
"""
|
||||||
|
from colossalai.nn.parallel import GeminiDDP
|
||||||
|
assert isinstance(gemini_ddp_model, GeminiDDP)
|
||||||
module = gemini_ddp_model.module
|
module = gemini_ddp_model.module
|
||||||
|
|
||||||
for n, p in module.named_parameters():
|
# replace ColoTensor to torch.nn.Tensor in module
|
||||||
if isinstance(p, ColoTensor):
|
for n, p in gemini_ddp_model.torch_named_parameters():
|
||||||
p.to_replicate_()
|
_add_param(module, n, p)
|
||||||
_add_param(module, n, p.data)
|
|
||||||
|
|
||||||
return module
|
return module
|
||||||
|
|
Loading…
Reference in New Issue