diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py
index f58976231..844439cde 100644
--- a/colossalai/nn/parallel/utils.py
+++ b/colossalai/nn/parallel/utils.py
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 
 from colossalai.gemini.chunk import Chunk
+from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
 
 
@@ -19,3 +20,30 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
 
     return total_temp
+
+
+def _add_param(model, name, param):
+    name_list = name.split('.')
+    module = model._modules[name_list[0]]
+    for i in range(1, len(name_list) - 1):
+        module = module._modules[name_list[i]]
+    module._parameters[name_list[-1]] = param
+
+
+def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
+    """convert_to_torch_module
+
+    Args:
+        gemini_ddp_model (GeminiDDP): a gemini ddp model
+
+    Returns:
+        torch.nn.Module: a torch model contains the params of gemini_ddp_model
+    """
+    module = gemini_ddp_model.module
+
+    for n, p in module.named_parameters():
+        if isinstance(p, ColoTensor):
+            p.to_replicate_()
+            _add_param(module, n, p.data)
+
+    return module
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index c9e48a453..7ecb407b5 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -103,7 +103,6 @@ class ColoTensor(torch.Tensor):
             self.process_group = spec.pg
 
         self._type = TensorType.NONMODEL
-        self._graph_node = None
 
     def has_compute_spec(self) -> bool:
         return self.compute_spec is not None
diff --git a/tests/test_gemini/update/test_convert_torch_module.py b/tests/test_gemini/update/test_convert_torch_module.py
new file mode 100644
index 000000000..c0fd94b40
--- /dev/null
+++ b/tests/test_gemini/update/test_convert_torch_module.py
@@ -0,0 +1,48 @@
+from functools import partial
+
+import pytest
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.nn.parallel.utils import convert_to_torch_module
+from colossalai.tensor import ColoTensor
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+
+@parameterize('model_name', ['resnet18', 'bert'])
+def run_convert_torch_module(model_name: str):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, _, _, _, _ = get_components_func()
+
+    with ColoInitContext(device='cpu'):
+        model = model_builder(checkpoint=False)
+
+    from colossalai.nn.parallel import GeminiDDP
+    model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
+
+    pytorch_model = convert_to_torch_module(model)
+
+    for n, p in pytorch_model.named_parameters():
+        assert not isinstance(p, ColoTensor)
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_convert_torch_module()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_convert_torch_module(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_convert_torch_module(2)
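
Reviewer note: below is a minimal usage sketch (not part of the patch) of the new `convert_to_torch_module` API, turning a GeminiDDP-wrapped model back into a plain `torch.nn.Module`, e.g. to write a vanilla PyTorch checkpoint. It assumes the distributed environment has already been initialized via `colossalai.launch(...)`; `Net` and `save_torch_checkpoint` are hypothetical names used only for illustration.

```python
import torch

from colossalai.nn.parallel import GeminiDDP
from colossalai.nn.parallel.utils import convert_to_torch_module
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


class Net(torch.nn.Module):
    """Hypothetical toy model used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 16)
        self.fc2 = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def save_torch_checkpoint(path: str):
    # parameters created under ColoInitContext are ColoTensors
    with ColoInitContext(device='cpu'):
        model = Net()

    # let Gemini manage parameter placement and sharding
    model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)

    # ... training with the GeminiDDP-wrapped model would happen here ...

    # replicate the ColoTensors and swap plain torch tensors back into the module,
    # so the result can be consumed by ordinary PyTorch tooling
    torch_model = convert_to_torch_module(model)
    torch.save(torch_model.state_dict(), path)
```

Internally the conversion relies on the new `_add_param` helper, which walks the dotted parameter name (e.g. `'fc1.weight'` goes through `module._modules['fc1']._parameters['weight']`) and overwrites each `ColoTensor` parameter with its replicated `.data`, which is what the new test asserts.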