2023-01-06 05:41:19 +00:00
|
|
|
from collections import OrderedDict
|
|
|
|
from copy import copy
|
|
|
|
from typing import Optional, Set
|
|
|
|
|
2022-12-12 07:39:31 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
2023-01-06 05:41:19 +00:00
|
|
|
import torch.nn as nn
|
2022-12-12 07:39:31 +00:00
|
|
|
|
2024-01-09 02:20:05 +00:00
|
|
|
from colossalai.accelerator import get_accelerator
|
2022-12-12 07:39:31 +00:00
|
|
|
|
2023-04-04 05:48:16 +00:00
|
|
|
from .chunk import Chunk
|
|
|
|
|
2022-12-12 07:39:31 +00:00
|
|
|
|
2023-10-12 02:39:08 +00:00
|
|
|
def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
    """Materialize the full content of ``chunk`` on the current accelerator device as ``dtype``.

    If the chunk is already gathered, its ``cuda_global_chunk`` is returned, cast to
    ``dtype``. ``Tensor.to`` returns the very same tensor when the dtype already
    matches, so in that case the result still aliases the chunk's storage.

    Otherwise the local shard is taken: the device shard if present, else the CPU
    shard copied to the device. It is cast to ``dtype`` and all-gathered across
    ``chunk.torch_pg`` into a freshly allocated zero buffer of ``chunk.chunk_size``
    elements.

    Args:
        chunk (Chunk): the chunk whose data should be materialized.
        dtype (torch.dtype): dtype of the returned tensor.

    Returns:
        torch.Tensor: the chunk payload in ``dtype`` on the current device.
    """
    if chunk.is_gathered:
        # Honour `dtype` on this path as well; previously the chunk's own dtype leaked through.
        return chunk.cuda_global_chunk.to(dtype)

    device = get_accelerator().get_current_device()

    # Use the shard already on the device when available, otherwise bring the CPU shard over.
    if chunk.cuda_shard is not None:
        shard_temp = chunk.cuda_shard
    else:
        shard_temp = chunk.cpu_shard.to(device)
    shard_temp = shard_temp.to(dtype)

    total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=device)
    # Split the output into pg_size views so all_gather writes each rank's shard straight into
    # its slot. NOTE(review): assumes chunk_size is a multiple of pg_size — confirm in Chunk.
    gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
    dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)

    return total_temp
|
2022-12-20 02:19:36 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ""):
    """Yield ``(qualified_name, submodule)`` for ``module`` and every module below it.

    The traversal is a depth-first post-order over each module's ``_modules`` in
    registration order. Every child is yielded before its parent, and ``module``
    itself comes last under the name ``prefix``. Children registered as ``None``
    are skipped. A module reachable through several paths is yielded only the
    first time it is reached.

    Args:
        module: root of the traversal.
        memo: set of modules already yielded; shared by the recursive calls.
        prefix: dotted name of ``module`` relative to the outermost root.
    """
    visited = set() if memo is None else memo
    if module in visited:
        return

    for child_name, child in module._modules.items():
        if child is None:
            continue
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        yield from _get_dfs_module_list(child, visited, child_prefix)

    # Mark only after the children are done, so parents always follow their children.
    visited.add(module)
    yield prefix, module
|
2022-12-20 02:19:36 +00:00
|
|
|
|
|
|
|
|
2023-01-06 05:41:19 +00:00
|
|
|
def _get_shallow_copy_model(model: nn.Module):
    """Get a shallow copy of the given model. Each submodule is different from the original submodule.

    But the new submodule and the old submodule share all attributes.

    Each module in the tree is replaced by ``copy.copy`` of itself. Only the copy's
    ``_modules`` dict is rebuilt, pointing at the copied children; every other
    attribute value is shared with the original module. A module reachable through
    several paths maps to a single copy. Children registered as ``None`` are kept as
    ``None`` entries, so the copy exposes the same child names as the original.

    Args:
        model: the module tree to copy.

    Returns:
        nn.Module: the copy corresponding to ``model``.
    """
    old_to_new = {}
    # `_get_dfs_module_list` is post-order, so every child is copied before its parent needs it.
    for _, module in _get_dfs_module_list(model):
        new_module = copy(module)
        new_module._modules = OrderedDict()
        for subname, submodule in module._modules.items():
            if submodule is None:
                # Keep the `None` registration in `_modules` (as on the original). Dropping it made
                # `new_module.<subname>` raise AttributeError instead of returning None.
                new_module._modules[subname] = None
            else:
                setattr(new_module, subname, old_to_new[submodule])
        old_to_new[module] = new_module
    return old_to_new[model]
|
2023-01-06 05:41:19 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def get_static_torch_model(
    zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True
) -> torch.nn.Module:
    """Get a static torch.nn.Module model from the given GeminiDDP module.

    You should notice that the original GeminiDDP model is not modified.
    Thus, you can use the original model in further training.
    But you should not use the returned torch model to train, this can cause unexpected errors.

    Parameters are rebuilt from ``zero_ddp_model.state_dict``. Buffers are copied from
    the wrapped module. Both end up on ``device``. Parameters and floating-point
    buffers are cast to ``dtype``; other buffers (e.g. integer ones) keep their dtype,
    mirroring ``nn.Module.to``.

    Args:
        zero_ddp_model (GeminiDDP): a zero ddp model
        device (torch.device): the device of the final torch model
        dtype (torch.dtype): the dtype of the final torch model
        only_rank_0 (bool): if True, only rank0 has the converted torch model

    Returns:
        torch.nn.Module: a static torch model used for saving checkpoints or numeric checks.
        When ``only_rank_0`` is True, other ranks get the unconverted shallow copy, whose
        parameters and buffers are still those of the wrapped module.
    """
    from colossalai.zero.gemini.gemini_ddp import GeminiDDP

    assert isinstance(zero_ddp_model, GeminiDDP)

    # Called on every rank, before the rank check below.
    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
    colo_model = zero_ddp_model.module
    torch_model = _get_shallow_copy_model(colo_model)

    if not only_rank_0 or dist.get_rank() == 0:
        # Both walks visit structurally identical trees in the same order, so zip pairs them up.
        for (name, colo_module), (_, torch_module) in zip(
            _get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)
        ):
            # The shallow copy still shares `_parameters` / `_buffers` with the original module,
            # so assign fresh containers instead of mutating the shared ones.
            torch_module._parameters = OrderedDict()

            # Assign the dict directly: `register_buffer` would also touch the shared
            # `_non_persistent_buffers_set` of the original module.
            new_buffers = OrderedDict()
            for buf_name, buf in colo_module._buffers.items():
                if buf is None:
                    new_buffers[buf_name] = None
                    continue
                buf_dtype = dtype if buf.is_floating_point() else buf.dtype
                # copy=True so the static model never aliases the training model's buffers.
                new_buffers[buf_name] = buf.detach().to(device=device, dtype=buf_dtype, copy=True)
            torch_module._buffers = new_buffers

            for suffix_param_name, param in colo_module.named_parameters(recurse=False):
                # get the full name of the parameter
                full_param_name = name + ("." if name else "") + suffix_param_name
                assert (
                    full_param_name in state_dict
                ), f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
                state_param = state_dict[full_param_name]
                torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))

                setattr(torch_module, suffix_param_name, torch_param)

    dist.barrier()

    return torch_model
|