diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index 591848e42..62af9e2c2 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -1,3 +1,4 @@
+from .process_group import ProcessGroup
 from .tensor_spec import TensorSpec
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
@@ -6,7 +7,6 @@ from .utils import convert_parameter, named_params_with_colotensor
 from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, ParamOpHookManager
 from . import distspec
-from .process_group import ProcessGroup
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index f36c18313..0d34368ec 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -30,7 +30,7 @@ class ColoTensor(torch.Tensor):
     1. directly init.
     >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
     >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
-    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
     >>>                             dims=[0],
     >>>                             num_partitions=[world_size])
     >>> tensor_spec = TensorSpec(shard_spec)
diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index c60e8127e..1f0a1bf87 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -5,7 +5,7 @@ from typing import List, Optional
 class ProcessGroup:
     """
     Process Group contains group partition for Tensor Parallel and Data Parallel.
-    WARNING, the ProcessGroup must be used after torch.distributed.initialize()
+    NOTE, the ProcessGroup must be used after torch.distributed.initialize()
     args:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
@@ -15,16 +15,24 @@ class ProcessGroup:
     """
 
     def __init__(self,
-                 rank: int,
-                 ranks: List[int],
+                 rank: Optional[int] = None,
+                 ranks: Optional[List[int]] = None,
                  backend: str = 'nccl',
                  tp_degree: Optional[int] = None,
                  dp_degree: Optional[int] = None) -> None:
-        self._rank = rank
-        self._rank_list = ranks
+        assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
+        if rank is None:
+            self._rank = torch.distributed.get_rank()
+        else:
+            self._rank = rank
+
+        if ranks is None:
+            self._rank_list = list(range(torch.distributed.get_world_size()))
+        else:
+            self._rank_list = ranks
+
         self._backend = backend
         self._world_size = len(self._rank_list)
-        assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
 
         if dp_degree is None and tp_degree is None:
             self._dp_degree = self._world_size
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index b5d1bc752..8ed6f34a1 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -11,11 +11,9 @@ from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
     ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
 from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from _utils import tensor_shard_equal, set_seed
 
 
 def init_1d_row_linear(weight, pg: ProcessGroup):
@@ -50,7 +48,7 @@ def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    rank = torch.distributed.get_rank()
 
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
@@ -65,9 +63,9 @@ def run_1d_hybrid_tp(model_name):
         for p1, p2 in zip(model.parameters(), model_torch.parameters()):
             p2.data.copy_(p1.data)
 
-    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
-    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
     if 'bert' == model_name:
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
@@ -214,14 +212,14 @@ def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    rank = torch.distributed.get_rank()
 
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
-    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
 
     set_seed(1)
@@ -243,8 +241,8 @@ def run_1d_row_tp(model_name: str):
             data = data.to(get_current_device())
             label = label.to(get_current_device())
 
-            torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-            torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+            torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
+            torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
 
             # Bcast rank0 data to all processes
             if criterion:
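
Below is a minimal usage sketch of the reworked ProcessGroup constructor and the tp_process_group() handle exercised by the updated test. It assumes torch.distributed is already initialized (the constructor asserts this); the backend choice, tensor shape, and the pg_explicit variable name are illustrative, not part of the patch.

import torch
import torch.distributed as dist

from colossalai.tensor import ProcessGroup

# Assumption: distributed has been set up beforehand, e.g. via
# colossalai.launch or dist.init_process_group(backend='nccl').
assert dist.is_initialized()

# rank/ranks can now be omitted; they default to dist.get_rank() and
# list(range(dist.get_world_size())) inside the constructor.
world_size = dist.get_world_size()
pg = ProcessGroup(tp_degree=world_size)

# The explicit form still used in run_1d_row_tp remains valid:
pg_explicit = ProcessGroup(dist.get_rank(), list(range(world_size)), tp_degree=world_size)

# pg.tp_process_group() replaces gpc.get_group(ParallelMode.PARALLEL_1D)
# as the communicator for tensor-parallel collectives in the tests.
data = torch.randn(4, 8, device='cuda')    # illustrative shape
dist.broadcast(data, 0, group=pg.tp_process_group())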