import torch
import torch.distributed as dist

from colossalai.gemini.chunk import Chunk
from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device


def get_temp_total_chunk_on_cuda(chunk: Chunk):
    """Return the full (un-sharded) contents of ``chunk`` on the current device.

    If the chunk is already gathered, its existing ``cuda_global_chunk`` is
    returned as-is. Otherwise a temporary buffer of ``chunk.chunk_size``
    elements is allocated on the current device and filled by an all-gather
    of the local shard over ``chunk.torch_pg``.

    Args:
        chunk (Chunk): the chunk whose full payload is requested.

    Returns:
        torch.Tensor: ``chunk.cuda_global_chunk`` when the chunk is gathered,
        otherwise a newly allocated 1-D tensor holding the gathered data.
    """
    if chunk.is_gathered:
        return chunk.cuda_global_chunk

    # Use the device-resident shard when there is one; otherwise copy the CPU
    # shard over to the current device for the collective.
    if chunk.cuda_shard is not None:
        shard_temp = chunk.cuda_shard
    else:
        shard_temp = chunk.cpu_shard.to(get_current_device())

    total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
    # torch.chunk returns views into total_temp, so all_gather writes every
    # rank's shard straight into the output buffer.
    gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
    dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)

    return total_temp


def _add_param(model, name, param):
    """Store ``param`` in ``model`` under the dotted parameter name ``name``.

    All components of ``name`` except the last are followed through
    ``_modules`` starting at ``model``; the last component is the key written
    in the owning module's ``_parameters`` dict.

    Args:
        model: root module to start the lookup from.
        name (str): dotted name, e.g. ``"layer.0.weight"`` or just ``"weight"``.
        param: value to store in ``_parameters``.

    Raises:
        KeyError: if an intermediate component is not a registered submodule.
    """
    *module_path, param_name = name.split('.')

    # Start at the root so that a name without dots (a parameter owned by the
    # root module itself) is written to model._parameters instead of being
    # looked up as a submodule.
    module = model
    for child_name in module_path:
        module = module._modules[child_name]
    module._parameters[param_name] = param


def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
    """Return the module wrapped by a Gemini DDP model with ColoTensor params unwrapped.

    For every parameter of ``gemini_ddp_model.module`` that is a ``ColoTensor``,
    ``to_replicate_()`` is called on it (presumably switching it to a
    replicated layout -- confirm against ColoTensor) and the corresponding
    ``_parameters`` entry is overwritten with ``p.data``. The wrapped module
    is modified in place.

    Args:
        gemini_ddp_model (GeminiDDP): a gemini ddp model

    Returns:
        torch.nn.Module: ``gemini_ddp_model.module``, holding the params of
        gemini_ddp_model
    """
    module = gemini_ddp_model.module

    # Snapshot the (name, param) pairs first so the _parameters dicts are not
    # rewritten while named_parameters() is still iterating over them.
    for n, p in list(module.named_parameters()):
        if isinstance(p, ColoTensor):
            p.to_replicate_()
            _add_param(module, n, p.data)

    return module