[tensor] derive compute pattern from dist spec (#971)

* derive compute pattern from dist spec

* polish code
ver217 2022-05-16 14:58:08 +08:00 committed by GitHub
parent 46bc95708f
commit c2fdc6a011
10 changed files with 79 additions and 65 deletions

View File

@@ -11,7 +11,7 @@ from colossalai.tensor import dist_spec
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
alpha: Union[int, float]) -> ColoTensor:
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
@ -32,7 +32,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
alpha: Union[int, float]) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
mat1.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
mat1_torch_tensor = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)
@@ -71,16 +71,16 @@ def colo_addmm(types, args, kwargs, pg):
# Add communication logic before and after linear call.
ret_tensor = None
if not mat2.has_spec(): # No Model Parallel Applied
assert not input_tensor.has_spec(), 'Invalid input spec for native addmm op'
assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.init_from_torch_tensor(
torch.addbmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
elif mat2.spec.num_action == 1: # Single Model Parallel Applied
torch.addmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
elif mat2.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
spec = TensorSpec(dist_spec.replicate(mat2.spec.get_process_group()))
mat1 = args[1] if isinstance(args[1], ColoTensor) else ColoTensor.init_from_torch_tensor(args[1], spec=spec)
compute_patterns = mat2.spec.compute_patterns
if ComputePattern.TP1DRow in compute_patterns:
if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
elif ComputePattern.TP1DCol in compute_patterns:
elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
else:
raise NotImplementedError
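The comments in the row branch ("mat1:S[1] x mat2:S[0] = Output:P", then an all-reduce) describe a simple sharding identity. A single-process sketch of it in plain torch (my illustration, not part of the commit; P and the shapes are arbitrary):

import torch

P = 2                                   # pretend 1D tensor-parallel degree
inp = torch.randn(4, 8)                 # addmm's input term
mat1 = torch.randn(4, 6)
mat2 = torch.randn(6, 8)
mat1_shards = mat1.chunk(P, dim=1)      # mat1:S[1]
mat2_shards = mat2.chunk(P, dim=0)      # mat2:S[0]
partial = sum(a @ b for a, b in zip(mat1_shards, mat2_shards))   # Output:P, summed as an all-reduce would
full = torch.addmm(inp, mat1, mat2, beta=1.0, alpha=1.0)
assert torch.allclose(inp + partial, full, atol=1e-5)            # beta * input + alpha * All-Reduce(Output)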

View File

@@ -12,7 +12,7 @@ from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, Parall
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs)
@@ -28,7 +28,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
@@ -71,16 +71,17 @@ def colo_embedding(types, args, kwargs, pg):
weight = ColoTensor.init_from_torch_tensor(weight)
# Handle different parallel actions.
if not weight.has_spec(): # No Model Parallel Applied
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
return ColoTensor.init_from_torch_tensor(output)
elif weight.spec.num_action == 1: # Single Model Parallel Applied
compute_patterns = weight.spec.compute_patterns
if ComputePattern.TP1DRow in compute_patterns:
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_row():
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
elif ComputePattern.TP1DCol in compute_patterns:
elif weight.spec.is_1D_col():
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
else:
raise NotImplementedError
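The 1D-row embedding comments above (shard the lookup table over num_embeddings, mask indices owned by other ranks, then reduce) can be checked on a single process. A sketch of that logic (my illustration, not from the commit; shapes and the loop over ranks stand in for the real distributed call):

import torch

P, num_embeddings, dim = 2, 8, 4
weight = torch.randn(num_embeddings, dim)          # full lookup table
ids = torch.tensor([0, 3, 5, 7])
full = torch.nn.functional.embedding(ids, weight)

per_rank = num_embeddings // P
partials = []
for rank in range(P):                              # each iteration plays one TP rank
    shard = weight[rank * per_rank:(rank + 1) * per_rank]            # weight:S[0]
    mask = (ids < rank * per_rank) | (ids >= (rank + 1) * per_rank)  # indices not in this shard
    local_ids = (ids - rank * per_rank).clamp(0, per_rank - 1)
    out = torch.nn.functional.embedding(local_ids, shard)
    out[mask] = 0.0                                                  # mask rows held elsewhere
    partials.append(out)
assert torch.allclose(sum(partials), full)                           # the "Reduce all" step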

View File

@@ -9,7 +9,7 @@ from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
@@ -33,11 +33,12 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode)
output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor())
if bias is not None:
bias = bias.torch_tensor()
output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias)
output = ColoTensor.init_from_torch_tensor(
output_parallel,
@@ -83,16 +84,17 @@ def colo_linear(types, args, kwargs, pg):
# Add communication logic before and after linear call.
ret_tensor = None
if not weight.has_spec(): # No Model Parallel Applied
assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
bias = bias.torch_tensor()
if bias is not None:
bias = bias.torch_tensor()
ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
elif weight.spec.num_action == 1: # Single Model Parallel Applied
compute_patterns = weight.spec.compute_patterns
if ComputePattern.TP1DRow in compute_patterns:
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
elif ComputePattern.TP1DCol in compute_patterns:
elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
else:
raise NotImplementedError
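Note how the dispatch above reads: the row-parallel kernel runs when weight.spec.is_1D_col() and the column-parallel kernel when weight.spec.is_1D_row(). That is because torch.nn.Linear stores its weight as (out_features, in_features), so splitting the reduction (input-feature) dimension shards the weight tensor along its last dim. A single-process check of the identity behind colo_linear_1Drow (my sketch, not part of the commit):

import torch

P, batch, in_f, out_f = 2, 4, 6, 8
x = torch.randn(batch, in_f)
w = torch.randn(out_f, in_f)                         # nn.Linear layout: (out, in)
b = torch.randn(out_f)
x_shards = x.chunk(P, dim=-1)                        # Input:S[1]
w_shards = w.chunk(P, dim=-1)                        # shard in_features -> spec.is_1D_col()
partial = sum(torch.nn.functional.linear(xs, ws) for xs, ws in zip(x_shards, w_shards))
full = torch.nn.functional.linear(x, w, b)
assert torch.allclose(partial + b, full, atol=1e-5)  # all-reduce, then add the gathered bias once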

View File

@@ -4,6 +4,7 @@ from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
@colo_op_impl(torch.nn.functional.cross_entropy)
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
arg_num = len(args)
@@ -27,13 +28,13 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
if isinstance(target, ColoTensor):
target = target.torch_tensor()
if input_tensor.spec.is_gathered(): # Input is gathered
return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
input_tensor.torch_tensor(), target, weight))
if input_tensor.spec.is_gathered(): # Input is gathered
return ColoTensor.init_from_torch_tensor(
torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight))
elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied
if input_tensor.spec.is_1Dcol():
return ColoTensor.init_from_torch_tensor(
VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
if input_tensor.spec.is_1D_col():
return ColoTensor.init_from_torch_tensor(VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(),
target))
else:
raise NotImplementedError
else:

View File

@@ -1,6 +1,7 @@
from enum import Enum
from torch.distributed import ProcessGroup
from typing import Optional, List
from numpy import prod
__all__ = ['replicate', 'shard']
@@ -39,4 +40,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
assert process_group is not None
assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions)
assert prod(num_partitions) == process_group.size()
return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
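The new assertion ties the partitioning to the group size: the per-dimension partition counts must multiply to exactly the number of ranks. A hedged usage sketch (pg is assumed to be an already-initialized 4-rank ProcessGroup; the calls mirror the signature above):

spec_1d = shard(pg, dims=[0], num_partitions=[4])          # prod([4]) == pg.size(), accepted
spec_2d = shard(pg, dims=[0, -1], num_partitions=[2, 2])   # prod([2, 2]) == 4, also accepted
# shard(pg, dims=[0], num_partitions=[2])                  # would now fail the new assert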

View File

@@ -5,17 +5,9 @@ from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
class ComputePattern(Enum):
# TODO (ver217): remove TP1DRow_<ops>
TP1DRow = 0
TP1DCol = 9
TP1DRow_Linear = 1
TP1DCol_Linear = 2
TP1DRow_Embedding = 3
TP1DCol_Embedding = 4
TP1DRow_mm = 5
TP1DCol_mm = 6
ZeRO = 7
DP = 8
TP1D = 0
ZeRO = 1
DP = 2
class ParallelAction(object):
@@ -45,14 +37,14 @@ class TensorSpec(object):
# using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
# parallel_action_list = [
# ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
# ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
# ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
# ]
# When the ColoTensor is initialized,
# we first splitting tensor according to ParallelAction of ZeRO,
# then splitting tensor according to ParallelAction of TP1DRow_Linear.
# then splitting tensor according to ParallelAction of TP1D_Linear.
# During Linear computation
# Before Linear Op, we gather the tensors according to ZeRO.
# We perform Linear Op according to compute pattern of TP1DRow_Linear.
# We perform Linear Op according to compute pattern of TP1D_Linear.
# After Linear Op, we split the tensors according to ZeRO.
def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []):
@@ -90,10 +82,17 @@ class TensorSpec(object):
def is_gathered(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.dist_spec.process_group.size() == 1)
def is_1Dcol(self):
def is_1D_col(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
def is_1D_row(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
def has_compute_pattern(self, compute_pattern: ComputePattern):
return self.get_action_by_compute_pattern(compute_pattern) is not None
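With these helpers, an op only needs the TP1D pattern plus the shard direction read off the dist spec. A usage sketch mirroring the test helpers later in this diff (my example; it assumes an initialized 1D tensor-parallel context, and the gpc/ParallelMode imports follow the test files):

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, ParallelAction, TensorSpec, dist_spec

pg = gpc.get_group(ParallelMode.PARALLEL_1D)
action = [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D,
                         parallel_mode=ParallelMode.PARALLEL_1D)]
row_spec = TensorSpec(dist_spec.shard(pg, [0], [pg.size()]), action)    # shard dim 0
col_spec = TensorSpec(dist_spec.shard(pg, [-1], [pg.size()]), action)   # shard last dim

assert row_spec.is_1D_row() and not row_spec.is_1D_col()
assert col_spec.is_1D_col() and not col_spec.is_1D_row()
assert row_spec.has_compute_pattern(ComputePattern.TP1D)
assert TensorSpec(dist_spec.replicate(pg)).is_gathered()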

View File

@@ -40,7 +40,7 @@ class Conv1D(nn.Module):
def init_1d_row(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -55,7 +55,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)

View File

@@ -17,7 +17,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
def init_1d_row(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -31,7 +31,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight):
def init_1d_col(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)

View File

@@ -18,7 +18,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
def init_1d_row(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -33,7 +33,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)

View File

@@ -86,35 +86,43 @@ def set_seed(seed):
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
def init_1d_row_linear(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
def init_1d_col_linear(weight, gather_out=True):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D, \
gather_out=gather_out)])
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1D,
parallel_mode=ParallelMode.PARALLEL_1D,
gather_out=gather_out)
])
with DistSpecManager.no_grad():
weight.set_spec(spec)
def init_1d_row_embedding(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
def init_1d_col_embedding(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
def run_1d_hybrid_tp(model_name):
# A simple net with two stacked nn.Linear
get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -124,7 +132,7 @@ def run_1d_hybrid_tp(model_name):
set_seed(1)
with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True)
if rank == 0:
model_torch = model_builder(checkpoint=True)
model_torch = model_torch.cuda()
@@ -173,7 +181,7 @@ def run_1d_hybrid_tp(model_name):
if rank == 0:
model_torch.eval()
colo_optimizer_torch.zero_grad()
data = data.to(get_current_device())
label = label.to(get_current_device())
@@ -217,11 +225,11 @@ def run_1d_hybrid_tp(model_name):
assert torch.allclose(p1, p2)
else:
# TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
if p1.size(-1) < p2.size(-1): # col
world_size = p2.size(-1) // p1.size(-1)
split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
elif p1.size(0) < p2.size(0): # row
world_size = p2.size(0) // p1.size(0)
split_p2 = torch.chunk(p2, world_size, dim=0)[0]
@@ -376,7 +384,7 @@ def _run_pretrain_load():
if isinstance(param, ColoParameter):
c1 += 1
else:
c2 +=1
c2 += 1
dict_col[name] = param
assert c_ref == c1
assert c2 == 0
@@ -395,6 +403,7 @@ def run_model_dist(rank, world_size, port):
for name in ['bert', 'simple_net']:
run_1d_hybrid_tp(name)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
# @parameterize('world_size', [1, 4])