[tensor] refactor parallel action (#1007)

* refactor parallel action

* polish unit tests
ver217 2022-05-20 20:19:58 +08:00 committed by GitHub
parent 9e3d602dba
commit a3b66f6def
12 changed files with 45 additions and 77 deletions

View File

@@ -3,12 +3,12 @@ from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
from colossalai.tensor import distspec
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
@@ -18,7 +18,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# Output:P
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
output = reduce_input(partial_output, parallel_action.parallel_mode)
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
# input
assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor + alpha * output
@@ -29,13 +29,13 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
parallel_action = mat2.spec.parallel_action
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
mat1 = reduce_grad(mat1, parallel_action.parallel_mode)
mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
ParallelAction(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if parallel_action.gather_out:
# All-Gather(Output)
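
The shorthand comments above (mat1:S[1] x mat2:S[0] = Output:P, then beta * input + alpha * All-Reduce(Output)) describe the 1Drow compute pattern. Below is a minimal single-process sketch of that arithmetic in plain PyTorch, not the ColossalAI implementation; the world size and shapes are illustrative, and the sum over partial products stands in for reduce_input.

import torch

world_size = 4                                   # plays the role of the 1D tensor-parallel group
m, k, n = 8, 16, 12
input_tensor = torch.randn(m, n)
mat1, mat2 = torch.randn(m, k), torch.randn(k, n)
beta, alpha = 1.0, 1.0

mat1_shards = mat1.chunk(world_size, dim=-1)     # mat1: S[1]
mat2_shards = mat2.chunk(world_size, dim=0)      # mat2: S[0]

# Each "rank" computes a partial product (Output:P); the sum emulates the All-Reduce.
partials = [torch.mm(a, b) for a, b in zip(mat1_shards, mat2_shards)]
output = torch.stack(partials).sum(dim=0)
result = beta * input_tensor + alpha * output

# The sharded computation reproduces the dense addmm up to float rounding.
assert torch.allclose(result, torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha), atol=1e-4)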

View File

@@ -1,10 +1,10 @@
import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -17,7 +17,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
output_parallel = F.embedding(input_tensor,
@@ -29,7 +28,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse=sparse)
output_spec = TensorSpec(
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
ParallelAction(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
@@ -45,10 +44,9 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
num_embeddings_per_partition = weight.size(0)
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition
@@ -72,7 +70,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Mask the output embedding.
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, parallel_action.parallel_mode)
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
return output
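
The 1Drow embedding above splits the lookup table along num_embeddings, masks indices that fall outside the local shard, and all-reduces the partial outputs. A minimal single-process sketch of that mask-and-reduce idea in plain PyTorch (illustrative sizes, not the ColossalAI code):

import torch
import torch.nn.functional as F

world_size = 4
num_embeddings, embedding_dim = 16, 8
weight = torch.randn(num_embeddings, embedding_dim)
input_ids = torch.randint(0, num_embeddings, (5,))

num_embeddings_per_partition = num_embeddings // world_size
partials = []
for rank in range(world_size):
    vocab_start_index = rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition
    shard = weight[vocab_start_index:vocab_end_index]          # local slice of the lookup table
    input_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
    masked_input = input_ids - vocab_start_index
    masked_input[input_mask] = 0                               # any in-range index works here
    partial = F.embedding(masked_input, shard)
    partial[input_mask, :] = 0.                                # mask rows this shard does not own
    partials.append(partial)

# Summing the partial outputs stands in for reduce_input(partial_output, ParallelMode.PARALLEL_1D).
output = torch.stack(partials).sum(dim=0)
assert torch.allclose(output, F.embedding(input_ids, weight))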

View File

@@ -1,16 +1,14 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, convert_to_colo_tensor
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
@@ -20,7 +18,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# Output:P
partial_output = F.linear(input_tensor, weight)
# Reduce(Output)
output = reduce_input(partial_output, parallel_action.parallel_mode)
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
# Bias
if bias is not None:
assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
@@ -34,15 +32,16 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
parallel_action = weight.spec.parallel_action
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(
output_parallel,
spec=TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
output = ColoTensor.from_torch_tensor(output_parallel,
spec=TensorSpec(
distspec.shard(weight.spec.get_process_group(), [-1],
[weight.spec.get_process_group_size()]),
ParallelAction(ComputePattern.TP1D)))
if parallel_action.gather_out:
# All-Gather(Output)
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
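
In the 1Dcol branch above, F.linear's weight (shape [out_features, in_features]) is sharded along out_features, the input stays replicated (reduce_grad only matters in the backward pass), and each rank produces a slice of the output; gather_out then decides whether those slices are all-gathered back into a replicated tensor. A single-process sketch with illustrative sizes, where concatenation stands in for the All-Gather:

import torch
import torch.nn.functional as F

world_size = 4
batch, in_features, out_features = 5, 16, 12
x = torch.randn(batch, in_features)              # Input:B (replicated)
weight = torch.randn(out_features, in_features)
bias = torch.randn(out_features)

w_shards = weight.chunk(world_size, dim=0)       # split along out_features
b_shards = bias.chunk(world_size, dim=0)

# Each "rank" computes its own slice of the output (Output:S[1]).
out_shards = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]

# gather_out=True: gather the slices into a replicated output.
gathered = torch.cat(out_shards, dim=-1)
assert torch.allclose(gathered, F.linear(x, weight, bias), atol=1e-5)

# gather_out=False: each rank keeps only its [batch, out_features // world_size] slice.
assert out_shards[0].shape == (batch, out_features // world_size)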

View File

@@ -28,7 +28,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
reduction=reduction,
label_smoothing=label_smoothing)
return ColoTensor.from_torch_tensor(output)
elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied
elif input_tensor.has_spec(): # Single Model Parallel Applied
if input_tensor.spec.is_1D_col():
output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
return ColoTensor.from_torch_tensor(output)

View File

@@ -33,3 +33,6 @@ class ColoParameter(ColoTensor):
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor
def __repr__(self):
return f'ColoParameter: {torch.Tensor.__repr__(self)}'

View File

@@ -45,7 +45,7 @@ class ColoTensor(torch.Tensor):
self._spec = spec
def has_spec(self) -> bool:
return self._spec.num_action > 0
return self._spec.parallel_action is not None
def is_model_data(self) -> bool:
return self._type == TensorType.MODEL

View File

@@ -1,26 +1,21 @@
import torch.distributed as dist
from enum import Enum
from typing import List
from colossalai.context.parallel_mode import ParallelMode
from typing import List, Optional
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
class ComputePattern(Enum):
TP1D = 0
ZeRO = 1
DP = 2
TP2D = 1
TP2P5D = 2
TP3D = 3
class ParallelAction(object):
def __init__(self,
priority=0,
compute_pattern=ComputePattern.DP,
parallel_mode=ParallelMode.DATA,
gather_out=True) -> None:
self.priority = priority
def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
assert isinstance(compute_pattern, ComputePattern)
self.compute_pattern = compute_pattern
self.parallel_mode = parallel_mode
self.gather_out = gather_out
@@ -48,32 +43,9 @@ class TensorSpec(object):
# We perform Linear Op according to compute pattern of TP1D_Linear.
# After Linear Op, we split the tensors according to ZeRO.
def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []):
self._parallel_action_list = parallel_action_list
def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
self.parallel_action = parallel_action
self.dist_spec = dist_spec
self.sort()
@property
def parallel_action_list(self):
return self._parallel_action_list
@property
def num_action(self):
return len(self._parallel_action_list)
@property
def compute_patterns(self):
return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
def sort(self):
if len(self._parallel_action_list) > 0:
self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)
def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
for parallel_action in self._parallel_action_list:
if parallel_action.compute_pattern == compute_pattern:
return parallel_action
return None
def get_process_group(self):
return self.dist_spec.process_group
@@ -99,4 +71,4 @@ class TensorSpec(object):
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
def has_compute_pattern(self, compute_pattern: ComputePattern):
return self.get_action_by_compute_pattern(compute_pattern) is not None
return self.parallel_action.compute_pattern == compute_pattern
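
After this refactor a TensorSpec holds at most one ParallelAction, and a ParallelAction is built from a ComputePattern (plus gather_out); priority and an explicit parallel_mode are gone. The sketch below mirrors the call sites in this commit rather than documenting a stable API, and it assumes a ColossalAI 1D tensor-parallel context has already been launched (e.g. via colossalai.launch with a 1D parallel config); the tensor shape is illustrative.

import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ColoTensor, ComputePattern, ParallelAction, TensorSpec, distspec

pg = gpc.get_group(ParallelMode.PARALLEL_1D)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)

# Old: TensorSpec(dist_spec, [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D,
#                                            parallel_mode=ParallelMode.PARALLEL_1D)])
# New: a single ParallelAction; the parallel mode is implied by the compute pattern.
spec = TensorSpec(distspec.shard(pg, [0], [world_size]), ParallelAction(ComputePattern.TP1D))

# Attach the spec to a local shard, as the op implementations above do.
weight = ColoTensor.from_torch_tensor(torch.randn(16, 8), spec=spec)
assert weight.has_spec()                                     # spec.parallel_action is not None
assert weight.spec.has_compute_pattern(ComputePattern.TP1D)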

View File

@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
def init_1d_row(weight, bias):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)

View File

@@ -18,7 +18,7 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -26,7 +26,7 @@ def init_1d_row(weight):
def init_1d_col(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)

View File

@@ -16,7 +16,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
@@ -26,7 +26,7 @@ def init_1d_row_spec(model):
def init_1d_col_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):

View File

@@ -19,7 +19,7 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, bias):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)

View File

@@ -20,19 +20,15 @@ from _utils import set_seed
def init_1d_row_linear(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
def init_1d_col_linear(weight, gather_out=True):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1D,
parallel_mode=ParallelMode.PARALLEL_1D,
gather_out=gather_out)
])
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -40,7 +36,7 @@ def init_1d_col_linear(weight, gather_out=True):
def init_1d_row_embedding(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -48,7 +44,7 @@ def init_1d_row_embedding(weight):
def init_1d_col_embedding(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)