mirror of https://github.com/hpcaitech/ColossalAI
[tensor] remove gpc in tensor tests (#1186)
parent 372f791444
commit c463f8adf9
@@ -1,3 +1,4 @@
+from .process_group import ProcessGroup
 from .tensor_spec import TensorSpec
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
@@ -6,7 +7,6 @@ from .utils import convert_parameter, named_params_with_colotensor
 from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, ParamOpHookManager
 from . import distspec
-from .process_group import ProcessGroup

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
@@ -30,7 +30,7 @@ class ColoTensor(torch.Tensor):
     1. directly init.
     >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
     >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
-    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
    >>>                              dims=[0],
    >>>                              num_partitions=[world_size])
     >>> tensor_spec = TensorSpec(shard_spec)
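For readers following the docstring change, here is a minimal, hedged sketch of the spec construction it now illustrates. It assumes torch.distributed is already initialized and reuses the names from the diff (distspec.shard, TensorSpec, ColoTensor, ProcessGroup); note that the docstring writes ProcessGroup(tp=world_size) while the constructor hunk further down exposes tp_degree, so the keyword below follows the constructor and may differ from the docstring.

# Sketch only: mirrors the pattern in the updated docstring, not a verified script.
import torch
import torch.distributed as dist
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup

world_size = dist.get_world_size()              # requires dist.init_process_group() beforehand
pg = ProcessGroup(tp_degree=world_size)         # keyword taken from the constructor shown below
shard_spec = distspec.shard(process_group=pg,   # the docstring now passes a ProcessGroup here
                            dims=[0],
                            num_partitions=[world_size])
colo_t = ColoTensor(torch.randn(2, 3), spec=TensorSpec(shard_spec))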
@@ -5,7 +5,7 @@ from typing import List, Optional
 class ProcessGroup:
     """
     Process Group contains group partition for Tensor Parallel and Data Parallel.
-    WARNING, the ProcessGroup must be used after torch.distributed.initialize()
+    NOTE, the ProcessGroup must be used after torch.distributed.initialize()
     args:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
@@ -15,16 +15,24 @@ class ProcessGroup:
     """

     def __init__(self,
-                 rank: int,
-                 ranks: List[int],
+                 rank: Optional[int] = None,
+                 ranks: Optional[List[int]] = None,
                  backend: str = 'nccl',
                  tp_degree: Optional[int] = None,
                  dp_degree: Optional[int] = None) -> None:
-        self._rank = rank
-        self._rank_list = ranks
+        assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
+        if rank is None:
+            self._rank = torch.distributed.get_rank()
+        else:
+            self._rank = rank
+
+        if ranks is None:
+            self._rank_list = list(range(torch.distributed.get_world_size()))
+        else:
+            self._rank_list = ranks
+
         self._backend = backend
         self._world_size = len(self._rank_list)
-        assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"

         if dp_degree is None and tp_degree is None:
             self._dp_degree = self._world_size
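A minimal sketch of how the relaxed constructor above is meant to be called: rank and ranks may now be omitted and are filled in from torch.distributed, but the assert requires torch.distributed to be initialized first. The single-process gloo bootstrap below is an assumption for illustration only, not part of the diff.

# Sketch, assuming the ProcessGroup constructor shown in the hunk above.
import os
import torch.distributed as dist
from colossalai.tensor import ProcessGroup

# Hypothetical single-process setup so init_process_group() can succeed locally.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

# After this change, rank/ranks default to dist.get_rank() and range(dist.get_world_size()).
pg = ProcessGroup(tp_degree=dist.get_world_size())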
@@ -11,11 +11,9 @@ from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
     ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
 from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from _utils import tensor_shard_equal, set_seed


 def init_1d_row_linear(weight, pg: ProcessGroup):
@@ -50,7 +48,7 @@ def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    rank = torch.distributed.get_rank()

     set_seed(1)
     with ColoInitContext(device=get_current_device()):
@@ -65,9 +63,9 @@ def run_1d_hybrid_tp(model_name):
     for p1, p2 in zip(model.parameters(), model_torch.parameters()):
         p2.data.copy_(p1.data)

-    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
-    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
     if 'bert' == model_name:
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
@@ -214,14 +212,14 @@ def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    rank = torch.distributed.get_rank()

     set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)

-    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)

     set_seed(1)
@@ -243,8 +241,8 @@ def run_1d_row_tp(model_name: str):
         data = data.to(get_current_device())
         label = label.to(get_current_device())

-        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
+        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

         # Bcast rank0 data to all processes
         if criterion:
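The tests now broadcast through the ProcessGroup's tensor-parallel group instead of gpc. Below is a hedged sketch of that pattern, assuming pg.tp_process_group() returns a torch.distributed group as the updated test implies; it is meant to run under a multi-process distributed launch, not standalone.

# Sketch of the broadcast pattern used in the updated test.
import torch
import torch.distributed as dist
from colossalai.tensor import ProcessGroup

pg = ProcessGroup(tp_degree=dist.get_world_size())
data = torch.randn(4, 8)
# Rank 0's batch is broadcast to every rank in the tensor-parallel group.
dist.broadcast(data, src=0, group=pg.tp_process_group())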