[tensor] remove gpc in tensor tests (#1186)

pull/1188/head
Jiarui Fang 2022-06-29 14:08:40 +08:00 committed by GitHub
parent 372f791444
commit c463f8adf9
4 changed files with 26 additions and 20 deletions

View File

@@ -1,3 +1,4 @@
+from .process_group import ProcessGroup
 from .tensor_spec import TensorSpec
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
@@ -6,7 +7,6 @@ from .utils import convert_parameter, named_params_with_colotensor
 from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, ParamOpHookManager
 from . import distspec
-from .process_group import ProcessGroup

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',

View File

@@ -30,7 +30,7 @@ class ColoTensor(torch.Tensor):
     1. directly init.
     >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
     >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
-    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
     >>>                             dims=[0],
     >>>                             num_partitions=[world_size])
     >>> tensor_spec = TensorSpec(shard_spec)
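For readers updating their own code, the following is a minimal, hedged sketch of what the revised docstring example corresponds to. It assumes torch.distributed has already been initialized (e.g. by a colossalai launcher) and that the `tp=world_size` keyword shown in the docstring maps to the `tp_degree` argument declared by ProcessGroup.__init__ in the next file; treat that mapping and the tensor shapes as illustrative assumptions, not part of this commit.

# Sketch only: constructing a replicated and a sharded ColoTensor without gpc.
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, TensorSpec, ProcessGroup, distspec

world_size = dist.get_world_size()

# 1. Directly init a replicated tensor, as in the first docstring example.
colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))

# 2. The shard spec now takes a ProcessGroup instead of gpc.get_group(...).
#    `tp_degree` is assumed to be the argument behind the docstring's `tp=`.
shard_spec = distspec.shard(process_group=ProcessGroup(tp_degree=world_size),
                            dims=[0],
                            num_partitions=[world_size])
tensor_spec = TensorSpec(shard_spec)
colo_t2 = ColoTensor(torch.randn(4, 8), spec=tensor_spec)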

View File

@@ -5,7 +5,7 @@ from typing import List, Optional
 class ProcessGroup:
     """
     Process Group contains group partition for Tensor Parallel and Data Parallel.
-    WARNING, the ProcessGroup must be used after torch.distributed.initialize()
+    NOTE, the ProcessGroup must be used after torch.distributed.initialize()
     args:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
@@ -15,16 +15,24 @@ class ProcessGroup:
     """
     def __init__(self,
-                 rank: int,
-                 ranks: List[int],
+                 rank: Optional[int] = None,
+                 ranks: Optional[List[int]] = None,
                  backend: str = 'nccl',
                  tp_degree: Optional[int] = None,
                  dp_degree: Optional[int] = None) -> None:
-        self._rank = rank
-        self._rank_list = ranks
+        assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
+        if rank is None:
+            self._rank = torch.distributed.get_rank()
+        else:
+            self._rank = rank
+        if ranks is None:
+            self._rank_list = list(range(torch.distributed.get_world_size()))
+        else:
+            self._rank_list = ranks
         self._backend = backend
         self._world_size = len(self._rank_list)
-        assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
         if dp_degree is None and tp_degree is None:
             self._dp_degree = self._world_size
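The relaxed constructor above means callers no longer have to pass rank and ranks explicitly. A small usage sketch follows, assuming the process was started with a standard torch.distributed launcher so that init_process_group can run first; everything here besides the constructor arguments shown in the diff is illustrative.

# Sketch only: the new defaulted arguments fall back to torch.distributed state.
import torch.distributed as dist
from colossalai.tensor import ProcessGroup

dist.init_process_group(backend='nccl')    # must precede ProcessGroup, or the assert fires

# rank defaults to dist.get_rank(), ranks to list(range(dist.get_world_size()))
pg = ProcessGroup(tp_degree=dist.get_world_size())

# the fully explicit form used before this commit still works
pg_explicit = ProcessGroup(rank=dist.get_rank(),
                           ranks=list(range(dist.get_world_size())),
                           tp_degree=dist.get_world_size())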

View File

@@ -11,11 +11,9 @@ from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
     ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
 from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from _utils import tensor_shard_equal, set_seed


 def init_1d_row_linear(weight, pg: ProcessGroup):
@@ -50,7 +48,7 @@ def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    rank = torch.distributed.get_rank()

     set_seed(1)
     with ColoInitContext(device=get_current_device()):
@@ -65,9 +63,9 @@ def run_1d_hybrid_tp(model_name):
     for p1, p2 in zip(model.parameters(), model_torch.parameters()):
         p2.data.copy_(p1.data)

-    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
-    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
     if 'bert' == model_name:
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
@@ -214,14 +212,14 @@ def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    rank = torch.distributed.get_rank()

     set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)

-    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)

     set_seed(1)
@@ -243,8 +241,8 @@ def run_1d_row_tp(model_name: str):
         data = data.to(get_current_device())
         label = label.to(get_current_device())

-        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
+        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
        # Bcast rank0 data to all processes
         if criterion:
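Taken together, the test changes follow one migration pattern: global rank and world size come from torch.distributed rather than gpc, the ProcessGroup is built from the parallel degrees alone, and collectives use the group returned by pg.tp_process_group(). A condensed, hedged sketch of that pattern follows; the tensor values and device handling are placeholders, not code from the test file.

# Sketch only: the gpc-free pattern used throughout the updated test.
import torch
import torch.distributed as dist
from colossalai.tensor import ProcessGroup

rank = dist.get_rank()                    # was gpc.get_local_rank(ParallelMode.GLOBAL)
world_size = dist.get_world_size()        # was gpc.get_world_size(ParallelMode.GLOBAL)
pg = ProcessGroup(tp_degree=world_size)   # was ProcessGroup(rank, list(range(world_size)), ...)

data = torch.randn(4, 8, device='cuda')
# broadcast over the tensor-parallel group instead of gpc.get_group(ParallelMode.PARALLEL_1D)
dist.broadcast(data, src=0, group=pg.tp_process_group())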