mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] distributed view supports inter-process hybrid parallel (#1169)
parent 9e1daa63d2
commit aa7bef73d4
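In brief, the hunks below rename the TensorSpec predicate is_gathered() to is_replicate() across the tensor ops, add raw view_base()/size_base() helpers to ColoTensor together with view()/size() overrides that understand the distributed spec, route ColoParameter back to those raw helpers, and add a _run_view test. The standalone sketch below (illustrative only, not ColossalAI code; global_size is a made-up helper) shows the shape arithmetic the new ColoTensor.size() performs for a sharded tensor:

import torch

def global_size(local_tensor: torch.Tensor, dims, num_partitions) -> torch.Size:
    # mirror of the new size() override: scale every sharded dim of the local
    # payload by its partition count to recover the logical global shape
    size_list = list(local_tensor.size())
    for dim, num_partition in zip(dims, num_partitions):
        size_list[dim] *= num_partition
    return torch.Size(size_list)

local = torch.randn(2, 5)                                  # each rank holds a [2, 5] shard
print(global_size(local, dims=[0], num_partitions=[2]))    # torch.Size([4, 5])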
@@ -2,3 +2,5 @@
 # -*- encoding: utf-8 -*-

 from colossalai.context.parallel_context import global_context
+
+__all__ = ['global_context']

@@ -68,11 +68,11 @@ def colo_addmm(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not mat2.has_compute_spec():    # No Model Parallel Applied
-        assert mat2.tensor_spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
-        assert input_tensor.tensor_spec.is_gathered(), 'Invalid input spec for native addmm op'
+        assert mat2.tensor_spec.is_replicate(), 'Invalid mat2 spec for native addmm op'
+        assert input_tensor.tensor_spec.is_replicate(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
     elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_gathered():
+        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_replicate():
             mode = 'row'
         elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
                                                or input_tensor.tensor_spec.is_1D_row()):

@@ -51,7 +51,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))

     tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    num_embeddings_per_partition = weight.size(0)
+    num_embeddings_per_partition = weight.size_base(0)
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition

@@ -115,7 +115,7 @@ def colo_embedding(input_tensor: GeneralTensor,
     # Handle differen parallel actions.

     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding(input_tensor,
                         weight,

@@ -90,7 +90,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
     # Handle differen parallel actions.

     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding_bag(input_tensor,
                             weight,

@@ -67,17 +67,17 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native Linear op'
-        assert bias is None or bias.tensor_spec.is_gathered(), 'Invalid bias spec for native Linear op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native Linear op'
+        assert bias is None or bias.tensor_spec.is_replicate(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_gathered()):
+        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_replicate()):
             mode = 'row'
         elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
                                                  or bias.tensor_spec.is_1D_col()):
             mode = 'col'
         else:
-            raise NotImplementedError
+            raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight.tensor_spec}, bias {bias}")
         ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
     else:
         raise NotImplementedError

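A condensed sketch of the branch structure colo_linear_imp follows after the rename (FakeSpec and pick_linear_mode are made-up stand-ins for TensorSpec and the real dispatch, shown only to make the row/col split explicit):

from typing import Optional

class FakeSpec:
    # stand-in for TensorSpec with only the predicates used below
    def __init__(self, kind: str):
        self.kind = kind
    def is_replicate(self) -> bool:
        return self.kind == 'replicate'
    def is_1D_col(self) -> bool:
        return self.kind == '1d_col'
    def is_1D_row(self) -> bool:
        return self.kind == '1d_row'

def pick_linear_mode(weight: FakeSpec, bias: Optional[FakeSpec]) -> str:
    if weight.is_replicate():
        return 'native'    # plain F.linear, no communication needed
    if weight.is_1D_col() and (bias is None or bias.is_replicate()):
        return 'row'       # row-parallel 1D linear
    if weight.is_1D_row() and (bias is None or bias.is_1D_row() or bias.is_1D_col()):
        return 'col'       # column-parallel 1D linear
    raise RuntimeError(f"invalid weight/bias spec: {weight.kind}")

assert pick_linear_mode(FakeSpec('1d_col'), None) == 'row'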
@@ -18,7 +18,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                        label_smoothing: float = 0.0):
     input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))

-    if input_tensor.tensor_spec.is_gathered():    # Input is gathered
+    if input_tensor.tensor_spec.is_replicate():    # Input is gathered
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,

@@ -114,7 +114,7 @@ class Chunk:
         # if the process owns the rank, then copy the tensor to its chunk buffer
         # otherwise set its storage size to 0 to reduce memory consumption
         if self.is_src_rank:
-            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
+            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
             tensor_state = TensorState.HOLD
             tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
         else:

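The only change in the Chunk hunk above swaps tensor.view(-1) for tensor.flatten(). As a plain-torch aside (an illustration of the general difference between the two calls, not taken from the commit): view(-1) needs a stride-compatible layout, while flatten() falls back to a copy.

import torch

x = torch.randn(4, 5).t()          # transpose -> non-contiguous layout
try:
    x.view(-1)                     # view cannot re-stride this tensor
except RuntimeError as err:
    print('view(-1) failed:', err)
print(x.flatten().shape)           # torch.Size([20]); flatten() copies when needed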
@@ -101,3 +101,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         # TODO(jzy) we don't support object reflection now.
         # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
         raise NotImplementedError
+
+    #### the ColoParameter should use the torch.Tensor's builtin methodes ###
+
+    def view(self, *args) -> 'ColoTensor':
+        return super().view_base(*args)
+
+    def size(self, *args, **kwargs) -> torch.Size:
+        # import inspect
+        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
+        return super().size_base(*args, **kwargs)

@@ -8,6 +8,7 @@ from colossalai.tensor import TensorSpec
 from colossalai.tensor import distspec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec
+from typing import Optional


 def _convert_output(output):

@@ -60,6 +61,12 @@ class ColoTensor(torch.Tensor):
     def tensor_spec(self) -> TensorSpec:
         return self._tensor_spec

+    @tensor_spec.setter
+    def tensor_spec(self, tenseor_spec: TensorSpec):
+        spec = copy(spec)
+        self._convert_to_dist_spec(spec.dist_spec)
+        self._tensor_spec = spec
+
     def set_tensor_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
         self._convert_to_dist_spec(spec.dist_spec)

@@ -136,4 +143,52 @@ class ColoTensor(torch.Tensor):
         data = self.data.clone()
         tensor = ColoTensor(data, spec=copy(self.tensor_spec))
         memo[id(self)] = tensor
         return tensor
+
+    ##### override builtin functions which must use tensor in replicate placement ####
+
+    def view_base(self, *args) -> 'ColoTensor':
+        return super().view(*args)
+
+    def size_base(self, *args, **kwargs) -> torch.Size:
+        return super().size(*args, **kwargs)
+
+    def view(self, *args) -> 'ColoTensor':
+        """override the torch buildin view()
+        the args passed in must be in a replicate placement.
+        Returns:
+            ColoTensor: a tensor after viewed.
+        """
+        if self.tensor_spec.is_replicate():
+            return super().view(*args)
+        # TODO(jiaruifang) check why this not work
+        # self.data = self.to_replicate()
+        self.data = DistSpecManager.handle_trans_spec(self.data, self.tensor_spec.dist_spec, distspec.replicate())
+        self._tensor_spec.dist_spec = distspec.replicate()
+        return super().view(*args)
+
+    def size(self, args: Optional[int] = None):
+        """override the torch buildin size()
+        the shape passed in must be in a replicate placement.
+        Returns:
+            ColoTensor: a tensor after viewed.
+        """
+        if self.tensor_spec.is_replicate():
+            if args is not None:
+                return super().size(args)
+            else:
+                return super().size()
+
+        spec = self.tensor_spec.dist_spec
+        dims = spec.dims
+        num_partitions = spec.num_partitions
+        # import inspect
+        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
+
+        size_list = list(super().size())
+        for dim, num_partition in zip(dims, num_partitions):
+            size_list[dim] *= num_partition
+        if args is not None:
+            return size_list[args]
+        else:
+            return torch.Size(size_list)

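The view() override above first converts a sharded tensor to the replicate placement and only then applies the requested shape. A single-process illustration of that behaviour (assumed shapes matching the new _run_view test; torch.cat stands in for the cross-rank gather done by DistSpecManager.handle_trans_spec):

import torch

world_size = 2
shards = [torch.randn(4, 5) for _ in range(world_size)]    # each rank holds a [4, 5] shard

replicated = torch.cat(shards, dim=0)                      # logical [8, 5] tensor after "to replicate"
flat = replicated.view(4 * 5 * world_size)                 # same call as in _run_view
assert flat.shape == torch.Size([4 * 5 * world_size])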
@@ -68,6 +68,7 @@ class DistSpecManager:
         num_parts = prod(dist_spec.num_partitions)
         for i, dim in enumerate(dist_spec.dims):
             num_parts //= dist_spec.num_partitions[i]
+
             chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
             chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
             idx %= num_parts

@@ -26,7 +26,7 @@ class TensorSpec(object):
     def get_placement(self):
         return self.dist_spec.placement

-    def is_gathered(self):
+    def is_replicate(self):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
             or (len(self.dist_spec.num_partitions) == 1
                 and self.dist_spec.num_partitions[0] == 1) \

@@ -101,4 +101,4 @@ def test_gpt(world_size, use_ddp):


 if __name__ == '__main__':
-    test_gpt(4, False)
+    test_gpt(4, True)

@@ -60,6 +60,19 @@ def test_operand():
 #### Test Distributed init a Colotensor


+def _run_view(world_size):
+    t_ref = torch.randn(4, 5)
+    t = ColoTensor.from_torch_tensor(
+        t_ref, TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[2])))
+
+    assert t.size()[0] == 4 * world_size
+    assert t.size(1) == 5
+    assert t.size() == torch.Size([4 * world_size, 5])
+
+    t = t.view(4 * 5 * world_size)
+    assert t.shape == torch.Size([4 * 5 * world_size])
+
+
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     print(gpc.get_group(ParallelMode.DATA).size())

@@ -77,20 +90,21 @@ def _run_tensor_replicated_init(world_size):
     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"


-def run_tensor_init(rank, world_size, port):
+def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
     _run_tensor_replicated_init(world_size)
+    _run_view(world_size)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
-def _test_dist_init(world_size):
-    run_func = partial(run_tensor_init, world_size=world_size, port=free_port())
+def _test_dist_cases(world_size):
+    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    # _test_dist_init(4)
     test_new()
+    _test_dist_cases(2)